From 590ad92a3c8651dd08f7234f714ea98e2361029a Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 23 Oct 2024 16:42:37 -0700 Subject: [PATCH 01/12] minor fixes --- py/core/pipelines/search_pipeline.py | 4 ++-- py/core/pipes/kg/community_summary.py | 31 +++++++++++++++++++++----- py/core/pipes/kg/entity_description.py | 5 ++++- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/py/core/pipelines/search_pipeline.py b/py/core/pipelines/search_pipeline.py index 22361b6d2..5df9f6124 100644 --- a/py/core/pipelines/search_pipeline.py +++ b/py/core/pipelines/search_pipeline.py @@ -103,9 +103,9 @@ async def enqueue_requests(): await enqueue_task vector_search_results = ( - await vector_search_task if use_vector_search else None + await vector_search_task if use_vector_search else [] ) - kg_results = await kg_task if do_kg else None + kg_results = await kg_task if do_kg else [] return AggregateSearchResult( vector_search_results=vector_search_results, diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 84b6ee400..354f0b08b 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -198,9 +198,13 @@ async def process_community( break except Exception as e: if attempt == 2: - raise ValueError( - f"Failed to generate a summary for community {community_number} at level {community_level}." 
- ) from e + logger.error( + f"KGCommunitySummaryPipe: Error generating community summary for community {community_number}: {e}" + ) + return { + "community_number": community_number, + "error": str(e), + } community_report = CommunityReport( community_number=community_number, @@ -273,11 +277,28 @@ async def _run_logic( # type: ignore ) ) + total_jobs = len(community_summary_jobs) + total_errors = 0 completed_community_summary_jobs = 0 for community_summary in asyncio.as_completed(community_summary_jobs): + + summary = await community_summary completed_community_summary_jobs += 1 if completed_community_summary_jobs % 50 == 0: logger.info( - f"KGCommunitySummaryPipe: {completed_community_summary_jobs}/{len(community_summary_jobs)} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds" + f"KGCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds" + ) + + if "error" in summary: + logger.error( + f"KGCommunitySummaryPipe: Error generating community summary for community {summary['community_number']}: {summary['error']}" ) - yield await community_summary + total_errors += 1 + continue + + yield summary + + if total_errors > 0: + raise ValueError( + f"KGCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures." 
+ ) diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 7c5bee14d..24de1eaec 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -126,7 +126,10 @@ async def process_entity( .message.content ) - # will do more requests, but it is simpler + if not out_entity.description: + logger.error(f"No description for entity {out_entity.name}") + return out_entity.name + out_entity.description_embedding = ( await self.embedding_provider.async_get_embeddings( [out_entity.description] From 4c271e54db2bf9d21b5b65b242c71b145b21245c Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Sun, 27 Oct 2024 09:38:06 -0700 Subject: [PATCH 02/12] up --- .../main/orchestration/hatchet/kg_workflow.py | 21 +++++- .../main/orchestration/simple/kg_workflow.py | 64 ++++++++++++------- py/core/providers/database/document.py | 5 +- 3 files changed, 65 insertions(+), 25 deletions(-) diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 7080bd35b..51f81632f 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -9,7 +9,7 @@ from core import GenerationConfig from core.base import OrchestrationProvider -from core.base.abstractions import KGExtractionStatus +from core.base.abstractions import KGExtractionStatus, KGEnrichmentStatus from ...services import KgService @@ -291,12 +291,29 @@ async def kg_entity_deduplication_setup( key=f"{i}/{total_workflows}_entity_deduplication_part", ) ) - await asyncio.gather(*workflows) + result = await asyncio.gather(*workflows) + # set status to success + await self.kg_service.providers.database.set_workflow_status( + id=collection_id, + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.SUCCESS, + ) return { "result": f"successfully queued kg entity deduplication for collection {collection_id} with 
{number_of_distinct_entities} distinct entities" } + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + collection_id = context.workflow_input()["request"][ + "collection_id" + ] + await self.kg_service.providers.database.set_workflow_status( + id=collection_id, + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.FAILED, + ) + @orchestration_provider.workflow( name="kg-entity-deduplication-summary", timeout="360m" ) diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index 9f55f857c..58b14eca5 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -3,8 +3,10 @@ import math from core import GenerationConfig +from core import R2RException from ...services import KgService +from core.base.abstractions import KGEnrichmentStatus logger = logging.getLogger() @@ -62,32 +64,50 @@ async def enrich_graph(input_data): input_data = get_input_data_dict(input_data) - num_communities = await service.kg_clustering( - collection_id=input_data["collection_id"], - **input_data["kg_enrichment_settings"], - ) - num_communities = num_communities[0]["num_communities"] - # TODO - Do not hardcode the number of parallel communities, - # make it a configurable parameter at runtime & add server-side defaults - parallel_communities = min(100, num_communities) - - total_workflows = math.ceil(num_communities / parallel_communities) - for i in range(total_workflows): - input_data_copy = input_data.copy() - input_data_copy["offset"] = i * parallel_communities - input_data_copy["limit"] = min( - parallel_communities, - num_communities - i * parallel_communities, + try: + num_communities = await service.kg_clustering( + collection_id=input_data["collection_id"], + **input_data["kg_enrichment_settings"], ) - # running i'th workflow out of total_workflows - logger.info( - f"Running kg community summary for {i+1}'th workflow 
out of total {total_workflows} workflows" + num_communities = num_communities[0]["num_communities"] + # TODO - Do not hardcode the number of parallel communities, + # make it a configurable parameter at runtime & add server-side defaults + parallel_communities = min(100, num_communities) + + total_workflows = math.ceil(num_communities / parallel_communities) + for i in range(total_workflows): + input_data_copy = input_data.copy() + input_data_copy["offset"] = i * parallel_communities + input_data_copy["limit"] = min( + parallel_communities, + num_communities - i * parallel_communities, + ) + # running i'th workflow out of total_workflows + logger.info( + f"Running kg community summary for {i+1}'th workflow out of total {total_workflows} workflows" + ) + await kg_community_summary( + input_data=input_data_copy, + ) + + await service.providers.database.set_workflow_status( + id=input_data["collection_id"], + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.SUCCESS, ) - await kg_community_summary( - input_data=input_data_copy, + return { + "result": "successfully ran kg community summary workflows" + } + + except Exception as e: + + await service.providers.database.set_workflow_status( + id=input_data["collection_id"], + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.FAILED, ) - return {"result": "successfully ran kg community summary workflows"} + raise R2RException(f"Error in enriching graph: {e}") async def kg_community_summary(input_data): diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index 6999bffcb..e586200bc 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -259,7 +259,10 @@ def _get_status_model_and_table_name(self, status_type: str): elif status_type == "kg_extraction_status": return KGExtractionStatus, "document_info" elif status_type == "kg_enrichment_status": - return KGEnrichmentStatus, "collection_info" + return ( + 
KGEnrichmentStatus, + "collections", + ) # TODO: Rename to collection info? else: raise R2RException( status_code=400, message=f"Invalid status type: {status_type}" From 610c1dba6d5f0a3e34f4e5e66b6ad4eb259ed4e0 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Mon, 28 Oct 2024 08:31:20 -0700 Subject: [PATCH 03/12] adding bin support and make it default (#1508) * Feature/tweak actions (#1507) * up * tweak actions * adding bin sup and making it default * tested and vetted --- .../r2r-full-py-integration-tests.yml | 7 + .../r2r-light-py-integration-tests.yml | 7 + py/cli/commands/ingestion.py | 2 + py/core/base/providers/database.py | 3 + .../api/data/ingestion_router_openapi.yml | 1 + py/core/main/api/ingestion_router.py | 5 + py/core/main/services/ingestion_service.py | 1 + py/core/providers/database/postgres.py | 1 + py/core/providers/database/vector.py | 376 +++++++++++++----- py/r2r.toml | 6 +- py/sdk/mixins/ingestion.py | 2 + py/shared/abstractions/search.py | 20 +- py/shared/abstractions/vector.py | 28 +- 13 files changed, 325 insertions(+), 134 deletions(-) diff --git a/.github/workflows/r2r-full-py-integration-tests.yml b/.github/workflows/r2r-full-py-integration-tests.yml index 1f130b1cc..1e10f9414 100644 --- a/.github/workflows/r2r-full-py-integration-tests.yml +++ b/.github/workflows/r2r-full-py-integration-tests.yml @@ -10,6 +10,7 @@ on: branches: - dev - dev-minor + - main workflow_dispatch: jobs: @@ -55,23 +56,29 @@ jobs: - name: Run CLI Ingestion Tests if: matrix.test_category == 'cli-ingestion' uses: ./.github/actions/run-cli-ingestion-tests + continue-on-error: true - name: Run CLI Retrieval Tests if: matrix.test_category == 'cli-retrieval' uses: ./.github/actions/run-cli-retrieval-tests + continue-on-error: true - name: Run SDK Ingestion Tests if: matrix.test_category == 'sdk-ingestion' uses: ./.github/actions/run-sdk-ingestion-tests + continue-on-error: true - name: Run SDK Retrieval Tests if: 
matrix.test_category == 'sdk-retrieval' uses: ./.github/actions/run-sdk-retrieval-tests + continue-on-error: true - name: Run SDK Auth Tests if: matrix.test_category == 'sdk-auth' uses: ./.github/actions/run-sdk-auth-tests + continue-on-error: true - name: Run SDK Collections Tests if: matrix.test_category == 'sdk-collections' uses: ./.github/actions/run-sdk-collections-tests + continue-on-error: true diff --git a/.github/workflows/r2r-light-py-integration-tests.yml b/.github/workflows/r2r-light-py-integration-tests.yml index a0760f492..b4ac669a5 100644 --- a/.github/workflows/r2r-light-py-integration-tests.yml +++ b/.github/workflows/r2r-light-py-integration-tests.yml @@ -12,6 +12,7 @@ on: branches: - dev - dev-minor + - main workflow_dispatch: jobs: @@ -58,23 +59,29 @@ jobs: - name: Run CLI Ingestion Tests if: matrix.test_category == 'cli-ingestion' uses: ./.github/actions/run-cli-ingestion-tests + continue-on-error: true - name: Run CLI Retrieval Tests if: matrix.test_category == 'cli-retrieval' uses: ./.github/actions/run-cli-retrieval-tests + continue-on-error: true - name: Run SDK Ingestion Tests if: matrix.test_category == 'sdk-ingestion' uses: ./.github/actions/run-sdk-ingestion-tests + continue-on-error: true - name: Run SDK Retrieval Tests if: matrix.test_category == 'sdk-retrieval' uses: ./.github/actions/run-sdk-retrieval-tests + continue-on-error: true - name: Run SDK Auth Tests if: matrix.test_category == 'sdk-auth' uses: ./.github/actions/run-sdk-auth-tests + continue-on-error: true - name: Run SDK Collections Tests if: matrix.test_category == 'sdk-collections' uses: ./.github/actions/run-sdk-collections-tests + continue-on-error: true diff --git a/py/cli/commands/ingestion.py b/py/cli/commands/ingestion.py index e6d09e88a..cd4ad137b 100644 --- a/py/cli/commands/ingestion.py +++ b/py/cli/commands/ingestion.py @@ -243,6 +243,7 @@ async def create_vector_index( index_measure, index_arguments, index_name, + index_column, no_concurrent, ): """Create a 
vector index for similarity search.""" @@ -254,6 +255,7 @@ async def create_vector_index( index_measure=index_measure, index_arguments=index_arguments, index_name=index_name, + index_column=index_column, concurrently=not no_concurrent, ) click.echo(json.dumps(response, indent=2)) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 1ff3cd47f..65e6ea6d4 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -571,6 +571,7 @@ async def create_index( Union[IndexArgsIVFFlat, IndexArgsHNSW] ] = None, index_name: Optional[str] = None, + index_column: Optional[str] = None, concurrently: bool = True, ) -> None: pass @@ -1457,6 +1458,7 @@ async def create_index( Union[IndexArgsIVFFlat, IndexArgsHNSW] ] = None, index_name: Optional[str] = None, + index_column: Optional[str] = None, concurrently: bool = True, ) -> None: return await self.vector_handler.create_index( @@ -1465,6 +1467,7 @@ async def create_index( index_method, index_arguments, index_name, + index_column, concurrently, ) diff --git a/py/core/main/api/data/ingestion_router_openapi.yml b/py/core/main/api/data/ingestion_router_openapi.yml index 283ad58da..d494ca1f6 100644 --- a/py/core/main/api/data/ingestion_router_openapi.yml +++ b/py/core/main/api/data/ingestion_router_openapi.yml @@ -172,6 +172,7 @@ create_vector_index: index_method: "The indexing method to use. Options: hnsw, ivfflat, auto. Default: hnsw" index_measure: "Distance measure for vector comparisons. Options: cosine_distance, l2_distance, max_inner_product. Default: cosine_distance" index_name: "Optional custom name for the index. If not provided, one will be auto-generated" + index_column: "The column containing the vectors to index. Default: `vec`, or `vec_binary` when using hamming or jaccard distance." index_arguments: "Configuration parameters for the chosen index method. For HNSW: {m: int, ef_construction: int}. 
For IVFFlat: {n_lists: int}" concurrently: "Whether to create the index concurrently. Default: true" diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py index 76f7b8ae4..b723b6fa2 100644 --- a/py/core/main/api/ingestion_router.py +++ b/py/core/main/api/ingestion_router.py @@ -509,6 +509,10 @@ async def create_vector_index_app( None, description=create_vector_descriptions.get("index_name"), ), + index_column: Optional[str] = Body( + None, + description=create_vector_descriptions.get("index_column"), + ), concurrently: bool = Body( default=True, description=create_vector_descriptions.get("concurrently"), @@ -532,6 +536,7 @@ async def create_vector_index_app( "index_method": index_method, "index_measure": index_measure, "index_name": index_name, + "index_column": index_column, "index_arguments": index_arguments, "concurrently": concurrently, }, diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index ab222737e..72637b1c0 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -660,6 +660,7 @@ def parse_create_vector_index_input(data: dict) -> dict: "index_method": IndexMethod(data["index_method"]), "index_measure": IndexMeasure(data["index_measure"]), "index_name": data["index_name"], + "index_column": data["index_column"], "index_arguments": data["index_arguments"], "concurrently": data["concurrently"], } diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 5db7c0fef..655473192 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -152,6 +152,7 @@ def __init__( self.project_name, self.connection_manager, self.dimension, + self.quantization_type, self.enable_fts, ) self.kg_handler = PostgresKGHandler( diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 6fca3bee2..4280211c0 100644 --- 
a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -6,6 +6,8 @@ from typing import Any, Optional, Tuple, TypedDict, Union from uuid import UUID +import numpy as np + from core.base import ( IndexArgsHNSW, IndexArgsIVFFlat, @@ -33,6 +35,33 @@ def index_measure_to_ops( return _decorate_vector_type(measure.ops, quantization_type) +def quantize_vector_to_binary( + vector: Union[list[float], np.ndarray], threshold: float = 0.0 +) -> bytes: + """ + Quantizes a float vector to a binary vector string for PostgreSQL bit type. + Used when quantization_type is INT1. + + Args: + vector (Union[List[float], np.ndarray]): Input vector of floats + threshold (float, optional): Threshold for binarization. Defaults to 0.0. + + Returns: + str: Binary string representation for PostgreSQL bit type + """ + # Convert input to numpy array if it isn't already + if not isinstance(vector, np.ndarray): + vector = np.array(vector) + + # Convert to binary (1 where value > threshold, 0 otherwise) + binary_vector = (vector > threshold).astype(int) + + # Convert to string of 1s and 0s + # Convert to string of 1s and 0s, then to bytes + binary_string = "".join(map(str, binary_vector)) + return binary_string.encode("ascii") + + class HybridSearchIntermediateResult(TypedDict): semantic_rank: int full_text_rank: int @@ -55,10 +84,12 @@ def __init__( project_name: str, connection_manager: PostgresConnectionManager, dimension: int, + quantization_type: VectorQuantizationType, enable_fts: bool = False, ): super().__init__(project_name, connection_manager) self.dimension = dimension + self.quantization_type = quantization_type self.enable_fts = enable_fts async def create_tables(self): @@ -82,8 +113,12 @@ async def create_tables(self): "your database schema to the new version." 
) - # TODO - Move ids to `UUID` type - # Create the vector table if it doesn't exist + binary_col = ( + "" + if self.quantization_type != VectorQuantizationType.INT1 + else f"vec_binary bit({self.dimension})," + ) + query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} ( extraction_id UUID PRIMARY KEY, @@ -91,6 +126,7 @@ async def create_tables(self): user_id UUID, collection_ids UUID[], vec vector({self.dimension}), + {binary_col} text TEXT, metadata JSONB {",fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED" if self.enable_fts else ""} @@ -108,57 +144,134 @@ async def create_tables(self): await self.connection_manager.execute_query(query) async def upsert(self, entry: VectorEntry) -> None: - query = f""" - INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} - (extraction_id, document_id, user_id, collection_ids, vec, text, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (extraction_id) DO UPDATE SET - document_id = EXCLUDED.document_id, - user_id = EXCLUDED.user_id, - collection_ids = EXCLUDED.collection_ids, - vec = EXCLUDED.vec, - text = EXCLUDED.text, - metadata = EXCLUDED.metadata; """ - await self.connection_manager.execute_query( - query, - ( - entry.extraction_id, - entry.document_id, - entry.user_id, - entry.collection_ids, - str(entry.vector.data), - entry.text, - json.dumps(entry.metadata), - ), - ) + Upsert function that handles vector quantization only when quantization_type is INT1. + Matches the table schema where vec_binary column only exists for INT1 quantization. 
+ """ + # Check the quantization type to determine which columns to use + if self.quantization_type == VectorQuantizationType.INT1: + # For quantized vectors, use vec_binary column + query = f""" + INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} + (extraction_id, document_id, user_id, collection_ids, vec, vec_binary, text, metadata) + VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8) + ON CONFLICT (extraction_id) DO UPDATE SET + document_id = EXCLUDED.document_id, + user_id = EXCLUDED.user_id, + collection_ids = EXCLUDED.collection_ids, + vec = EXCLUDED.vec, + vec_binary = EXCLUDED.vec_binary, + text = EXCLUDED.text, + metadata = EXCLUDED.metadata; + """ + await self.connection_manager.execute_query( + query, + ( + entry.extraction_id, + entry.document_id, + entry.user_id, + entry.collection_ids, + str(entry.vector.data), + quantize_vector_to_binary( + entry.vector.data + ), # Convert to binary + entry.text, + json.dumps(entry.metadata), + ), + ) + else: + # For regular vectors, use vec column only + query = f""" + INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} + (extraction_id, document_id, user_id, collection_ids, vec, text, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (extraction_id) DO UPDATE SET + document_id = EXCLUDED.document_id, + user_id = EXCLUDED.user_id, + collection_ids = EXCLUDED.collection_ids, + vec = EXCLUDED.vec, + text = EXCLUDED.text, + metadata = EXCLUDED.metadata; + """ + + await self.connection_manager.execute_query( + query, + ( + entry.extraction_id, + entry.document_id, + entry.user_id, + entry.collection_ids, + str(entry.vector.data), + entry.text, + json.dumps(entry.metadata), + ), + ) async def upsert_entries(self, entries: list[VectorEntry]) -> None: - query = f""" - INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} - (extraction_id, document_id, user_id, collection_ids, vec, text, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT 
(extraction_id) DO UPDATE SET - document_id = EXCLUDED.document_id, - user_id = EXCLUDED.user_id, - collection_ids = EXCLUDED.collection_ids, - vec = EXCLUDED.vec, - text = EXCLUDED.text, - metadata = EXCLUDED.metadata; """ - params = [ - ( - entry.extraction_id, - entry.document_id, - entry.user_id, - entry.collection_ids, - str(entry.vector.data), - entry.text, - json.dumps(entry.metadata), - ) - for entry in entries - ] - await self.connection_manager.execute_many(query, params) + Batch upsert function that handles vector quantization only when quantization_type is INT1. + Matches the table schema where vec_binary column only exists for INT1 quantization. + """ + if self.quantization_type == VectorQuantizationType.INT1: + # For quantized vectors, use vec_binary column + query = f""" + INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} + (extraction_id, document_id, user_id, collection_ids, vec, vec_binary, text, metadata) + VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8) + ON CONFLICT (extraction_id) DO UPDATE SET + document_id = EXCLUDED.document_id, + user_id = EXCLUDED.user_id, + collection_ids = EXCLUDED.collection_ids, + vec = EXCLUDED.vec, + vec_binary = EXCLUDED.vec_binary, + text = EXCLUDED.text, + metadata = EXCLUDED.metadata; + """ + bin_params = [ + ( + entry.extraction_id, + entry.document_id, + entry.user_id, + entry.collection_ids, + str(entry.vector.data), + quantize_vector_to_binary( + entry.vector.data + ), # Convert to binary + entry.text, + json.dumps(entry.metadata), + ) + for entry in entries + ] + await self.connection_manager.execute_many(query, bin_params) + + else: + # For regular vectors, use vec column only + query = f""" + INSERT INTO {self._get_table_name(PostgresVectorHandler.TABLE_NAME)} + (extraction_id, document_id, user_id, collection_ids, vec, text, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (extraction_id) DO UPDATE SET + document_id = EXCLUDED.document_id, + user_id = 
EXCLUDED.user_id, + collection_ids = EXCLUDED.collection_ids, + vec = EXCLUDED.vec, + text = EXCLUDED.text, + metadata = EXCLUDED.metadata; + """ + params = [ + ( + entry.extraction_id, + entry.document_id, + entry.user_id, + entry.collection_ids, + str(entry.vector.data), + entry.text, + json.dumps(entry.metadata), + ) + for entry in entries + ] + + await self.connection_manager.execute_many(query, params) async def semantic_search( self, query_vector: list[float], search_settings: VectorSearchSettings @@ -177,33 +290,101 @@ async def semantic_search( f"{table_name}.text", ] - # Use cosine distance calculation - distance_calc = f"{table_name}.vec <=> $1::vector" - - if search_settings.include_values: - cols.append(f"({distance_calc}) AS distance") - - if search_settings.include_metadatas: - cols.append(f"{table_name}.metadata") - - select_clause = ", ".join(cols) + params: list[Union[str, int, bytes]] = [] + # For binary vectors (INT1), implement two-stage search + if self.quantization_type == VectorQuantizationType.INT1: + # Convert query vector to binary format + binary_query = quantize_vector_to_binary(query_vector) + # TODO - Put depth multiplier in config / settings + extended_limit = ( + search_settings.search_limit * 20 + ) # Get 20x candidates for re-ranking + + # Use binary column and binary-specific distance measures for first stage + stage1_distance = f"{table_name}.vec_binary {search_settings.index_measure.pgvector_repr} $1::bit({self.dimension})" + stage1_param = binary_query + + cols.append( + f"{table_name}.vec" + ) # Need original vector for re-ranking + if search_settings.include_metadatas: + cols.append(f"{table_name}.metadata") + + select_clause = ", ".join(cols) + where_clause = "" + params.append(stage1_param) + + if search_settings.filters: + where_clause = self._build_filters( + search_settings.filters, params + ) + where_clause = f"WHERE {where_clause}" + + # First stage: Get candidates using binary search + query = f""" + WITH candidates 
AS ( + SELECT {select_clause}, + ({stage1_distance}) as binary_distance + FROM {table_name} + {where_clause} + ORDER BY {stage1_distance} + LIMIT ${len(params) + 1} + OFFSET ${len(params) + 2} + ) + -- Second stage: Re-rank using original vectors + SELECT + extraction_id, + document_id, + user_id, + collection_ids, + text, + {"metadata," if search_settings.include_metadatas else ""} + (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance + FROM candidates + ORDER BY distance + LIMIT ${len(params) + 3} + """ - where_clause = "" - params: list[Union[str, int]] = [str(query_vector)] - if search_settings.filters: - where_clause = self._build_filters(search_settings.filters, params) - where_clause = f"WHERE {where_clause}" + params.extend( + [ + extended_limit, # First stage limit + search_settings.offset, + search_settings.search_limit, # Final limit + str(query_vector), # For re-ranking + ] + ) - query = f""" - SELECT {select_clause} - FROM {table_name} - {where_clause} - ORDER BY {distance_calc} - LIMIT ${len(params) + 1} - OFFSET ${len(params) + 2} - """ + else: + # Standard float vector handling - unchanged from original + distance_calc = f"{table_name}.vec {search_settings.index_measure.pgvector_repr} $1::vector({self.dimension})" + query_param = str(query_vector) + + if search_settings.include_values: + cols.append(f"({distance_calc}) AS distance") + if search_settings.include_metadatas: + cols.append(f"{table_name}.metadata") + + select_clause = ", ".join(cols) + where_clause = "" + params.append(query_param) + + if search_settings.filters: + where_clause = self._build_filters( + search_settings.filters, params + ) + where_clause = f"WHERE {where_clause}" - params.extend([search_settings.search_limit, search_settings.offset]) + query = f""" + SELECT {select_clause} + FROM {table_name} + {where_clause} + ORDER BY {distance_calc} + LIMIT ${len(params) + 1} + OFFSET ${len(params) + 2} + """ + params.extend( + [search_settings.search_limit, 
search_settings.offset] + ) results = await self.connection_manager.fetch_query(query, params) @@ -216,7 +397,7 @@ async def semantic_search( text=result["text"], score=( (1 - float(result["distance"])) - if search_settings.include_values + if "distance" in result else -1 ), metadata=( @@ -237,7 +418,7 @@ async def full_text_search( ) where_clauses = [] - params: list[Union[str, int]] = [query_text] + params: list[Union[str, int, bytes]] = [query_text] if search_settings.filters: filters_clause = self._build_filters( @@ -400,7 +581,7 @@ async def hybrid_search( async def delete( self, filters: dict[str, Any] ) -> dict[str, dict[str, str]]: - params: list[Union[str, int]] = [] + params: list[Union[str, int, bytes]] = [] where_clause = self._build_filters(filters, params) query = f""" @@ -538,6 +719,7 @@ async def create_index( Union[IndexArgsIVFFlat, IndexArgsHNSW] ] = None, index_name: Optional[str] = None, + index_column: Optional[str] = None, concurrently: bool = True, ) -> None: """ @@ -574,7 +756,17 @@ async def create_index( if table_name == VectorTableName.VECTORS: table_name_str = f"{self.project_name}.{VectorTableName.VECTORS}" # TODO - Fix bug in vector table naming convention - col_name = "vec" + if index_column: + col_name = index_column + else: + col_name = ( + "vec" + if ( + index_measure != IndexMeasure.hamming_distance + and index_measure != IndexMeasure.jaccard_distance + ) + else "vec_binary" + ) elif table_name == VectorTableName.ENTITIES_DOCUMENT: table_name_str = ( f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}" @@ -592,6 +784,7 @@ async def create_index( col_name = "embedding" else: raise ArgError("invalid table name") + if index_method not in ( IndexMethod.ivfflat, IndexMethod.hnsw, @@ -634,7 +827,7 @@ async def create_index( index_name = ( index_name - or f"ix_{ops}_{index_method}__{time.strftime('%Y%m%d%H%M%S')}" + or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}" ) create_index_sql = f""" @@ -661,7 +854,7 
@@ async def create_index( return None def _build_filters( - self, filters: dict, parameters: list[Union[str, int]] + self, filters: dict, parameters: list[Union[str, int, bytes]] ) -> str: def parse_condition(key: str, value: Any) -> str: # type: ignore @@ -1003,28 +1196,3 @@ def _get_index_options( return "WITH (m=16, ef_construction=64)" else: return "" # No options for other methods - - def _get_index_type(self, method: IndexMethod) -> str: - if method == IndexMethod.ivfflat: - return "ivfflat" - elif method == IndexMethod.hnsw: - return "hnsw" - elif method == IndexMethod.auto: - # Here you might want to implement logic to choose between ivfflat and hnsw - return "hnsw" - - def _get_index_operator(self, measure: IndexMeasure) -> str: - if measure == IndexMeasure.l2_distance: - return "vector_l2_ops" - elif measure == IndexMeasure.max_inner_product: - return "vector_ip_ops" - elif measure == IndexMeasure.cosine_distance: - return "vector_cosine_ops" - - def _get_distance_function(self, imeasure_obj: IndexMeasure) -> str: - if imeasure_obj == IndexMeasure.cosine_distance: - return "<=>" - elif imeasure_obj == IndexMeasure.l2_distance: - return "l2_distance" - elif imeasure_obj == IndexMeasure.max_inner_product: - return "max_inner_product" diff --git a/py/r2r.toml b/py/r2r.toml index b7fb67555..7ca2c0ac4 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -70,13 +70,13 @@ batch_size = 256 [embedding] provider = "litellm" -base_model = "openai/text-embedding-3-small" -base_dimension = 512 +base_model = "openai/text-embedding-3-large" +base_dimension = 3072 +quantization_settings = { quantization_type = "INT1" } batch_size = 128 add_title_as_prefix = false rerank_model = "None" concurrent_request_limit = 256 -quantization_settings = { quantization_type = "FP32" } [file] provider = "postgres" diff --git a/py/sdk/mixins/ingestion.py b/py/sdk/mixins/ingestion.py index 6c2349ba4..0a4973a0d 100644 --- a/py/sdk/mixins/ingestion.py +++ b/py/sdk/mixins/ingestion.py @@ -205,6 
+205,7 @@ async def create_vector_index( index_measure: IndexMeasure = IndexMeasure.cosine_distance, index_arguments: Optional[dict] = None, index_name: Optional[str] = None, + index_column: Optional[list[str]] = None, concurrently: bool = True, ) -> dict: """ @@ -227,6 +228,7 @@ async def create_vector_index( "index_measure": index_measure, "index_arguments": index_arguments, "index_name": index_name, + "index_column": index_column, "concurrently": concurrently, } return await self._make_request( # type: ignore diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 2e8a2bad6..ad6d8d760 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -6,10 +6,10 @@ from pydantic import Field -from shared.abstractions.graph import EntityLevel - from .base import R2RSerializable +from .graph import EntityLevel from .llm import GenerationConfig +from .vector import IndexMeasure class VectorSearchResult(R2RSerializable): @@ -165,22 +165,6 @@ def as_dict(self) -> dict: } -# TODO - stop duplication of this enum, move collections primitives to 'abstractions' -class IndexMeasure(str, Enum): - """ - An enum representing the types of distance measures available for indexing. - - Attributes: - cosine_distance (str): The cosine distance measure for indexing. - l2_distance (str): The Euclidean (L2) distance measure for indexing. - max_inner_product (str): The maximum inner product measure for indexing. 
- """ - - cosine_distance = "cosine_distance" - l2_distance = "l2_distance" - max_inner_product = "max_inner_product" - - class HybridSearchSettings(R2RSerializable): full_text_weight: float = Field( default=1.0, description="Weight to apply to full text search" diff --git a/py/shared/abstractions/vector.py b/py/shared/abstractions/vector.py index b367edd03..7b7150f1d 100644 --- a/py/shared/abstractions/vector.py +++ b/py/shared/abstractions/vector.py @@ -44,9 +44,12 @@ class IndexMeasure(str, Enum): max_inner_product (str): The maximum inner product measure for indexing. """ - cosine_distance = "cosine_distance" l2_distance = "l2_distance" max_inner_product = "max_inner_product" + cosine_distance = "cosine_distance" + l1_distance = "l1_distance" + hamming_distance = "hamming_distance" + jaccard_distance = "jaccard_distance" def __str__(self) -> str: return self.value @@ -54,9 +57,23 @@ def __str__(self) -> str: @property def ops(self) -> str: return { - IndexMeasure.cosine_distance: "_cosine_ops", IndexMeasure.l2_distance: "_l2_ops", IndexMeasure.max_inner_product: "_ip_ops", + IndexMeasure.cosine_distance: "_cosine_ops", + IndexMeasure.l1_distance: "_l1_ops", + IndexMeasure.hamming_distance: "_hamming_ops", + IndexMeasure.jaccard_distance: "_jaccard_ops", + }[self] + + @property + def pgvector_repr(self) -> str: + return { + IndexMeasure.l2_distance: "<->", + IndexMeasure.max_inner_product: "<#>", + IndexMeasure.cosine_distance: "<=>", + IndexMeasure.l1_distance: "<+>", + IndexMeasure.hamming_distance: "<~>", + IndexMeasure.jaccard_distance: "<%>", }[self] @@ -92,13 +109,6 @@ class IndexArgsHNSW(R2RSerializable): ef_construction: Optional[int] = 64 -INDEX_MEASURE_TO_SQLA_ACC = { - IndexMeasure.cosine_distance: lambda x: x.cosine_distance, - IndexMeasure.l2_distance: lambda x: x.l2_distance, - IndexMeasure.max_inner_product: lambda x: x.max_inner_product, -} - - class VectorTableName(str, Enum): """ This enum represents the different tables where we store vectors. 
From 2f674dd48fe1660364ac8fd95871c32ef92f3d8b Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:40:18 -0700 Subject: [PATCH 04/12] up (#1510) * up * set verification to default false --- py/core/__init__.py | 3 + py/core/base/__init__.py | 3 + py/core/base/providers/__init__.py | 4 + py/core/base/providers/auth.py | 17 +++- py/core/base/providers/email.py | 70 ++++++++++++++ py/core/main/abstractions.py | 3 + py/core/main/app.py | 3 +- py/core/main/assembly/factory.py | 43 ++++++++- py/core/main/config.py | 5 +- py/core/providers/__init__.py | 4 + py/core/providers/auth/r2r_auth.py | 15 ++- py/core/providers/auth/supabase.py | 6 +- py/core/providers/database/vector.py | 11 ++- py/core/providers/email/__init__.py | 4 + py/core/providers/email/console_mock.py | 56 +++++++++++ py/core/providers/email/smtp.py | 118 ++++++++++++++++++++++++ py/poetry.lock | 69 ++++++++------ py/pyproject.toml | 4 +- py/r2r.toml | 3 + 19 files changed, 393 insertions(+), 48 deletions(-) create mode 100644 py/core/base/providers/email.py create mode 100644 py/core/providers/email/__init__.py create mode 100644 py/core/providers/email/console_mock.py create mode 100644 py/core/providers/email/smtp.py diff --git a/py/core/__init__.py b/py/core/__init__.py index 9d8917895..babf7f1f8 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -134,6 +134,9 @@ # Crypto provider "CryptoConfig", "CryptoProvider", + # Email provider + "EmailConfig", + "EmailProvider", # Database providers "DatabaseConfig", "DatabaseProvider", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 7cf762533..81a4cdbef 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -106,6 +106,9 @@ # Crypto provider "CryptoConfig", "CryptoProvider", + # Email provider + "EmailConfig", + "EmailProvider", # Database providers "DatabaseConfig", "DatabaseProvider", diff --git a/py/core/base/providers/__init__.py 
b/py/core/base/providers/__init__.py index 8234f5f35..37af2b8f3 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -16,6 +16,7 @@ UserHandler, VectorHandler, ) +from .email import EmailConfig, EmailProvider from .embedding import EmbeddingConfig, EmbeddingProvider from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider from .llm import CompletionConfig, CompletionProvider @@ -36,6 +37,9 @@ # Crypto provider "CryptoConfig", "CryptoProvider", + # Email provider + "EmailConfig", + "EmailProvider", # Database providers "DatabaseConnectionManager", "DocumentHandler", diff --git a/py/core/base/providers/auth.py b/py/core/base/providers/auth.py index 37fc95b83..6ed1a9338 100644 --- a/py/core/base/providers/auth.py +++ b/py/core/base/providers/auth.py @@ -10,6 +10,8 @@ from ..api.models import UserResponse from .base import Provider, ProviderConfig from .crypto import CryptoProvider +from .database import DatabaseProvider +from .email import EmailProvider logger = logging.getLogger() @@ -33,8 +35,17 @@ def validate_config(self) -> None: class AuthProvider(Provider, ABC): security = HTTPBearer(auto_error=False) - - def __init__(self, config: AuthConfig, crypto_provider: CryptoProvider): + crypto_provider: CryptoProvider + email_provider: EmailProvider + database_provider: DatabaseProvider + + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: DatabaseProvider, + email_provider: EmailProvider, + ): if not isinstance(config, AuthConfig): raise ValueError( "AuthProvider must be initialized with an AuthConfig" @@ -43,6 +54,8 @@ def __init__(self, config: AuthConfig, crypto_provider: CryptoProvider): self.admin_email = config.default_admin_email self.admin_password = config.default_admin_password self.crypto_provider = crypto_provider + self.database_provider = database_provider + self.email_provider = email_provider super().__init__(config) self.config: AuthConfig = 
config # for type hinting diff --git a/py/core/base/providers/email.py b/py/core/base/providers/email.py new file mode 100644 index 000000000..b619c2136 --- /dev/null +++ b/py/core/base/providers/email.py @@ -0,0 +1,70 @@ +# email_provider.py +import logging +from abc import ABC, abstractmethod +from typing import Optional + +from .base import Provider, ProviderConfig + + +class EmailConfig(ProviderConfig): + smtp_server: Optional[str] = None + smtp_port: Optional[int] = None + smtp_username: Optional[str] = None + smtp_password: Optional[str] = None + from_email: Optional[str] = None + use_tls: bool = True + + @property + def supported_providers(self) -> list[str]: + return [ + "smtp", + "console", + ] # Could add more providers like AWS SES, SendGrid etc. + + def validate_config(self) -> None: + if self.provider == "smtp": + if not all( + [ + self.smtp_server, + self.smtp_port, + self.smtp_username, + self.smtp_password, + self.from_email, + ] + ): + raise ValueError("SMTP configuration is incomplete") + + +logger = logging.getLogger(__name__) + + +class EmailProvider(Provider, ABC): + def __init__(self, config: EmailConfig): + if not isinstance(config, EmailConfig): + raise ValueError( + "EmailProvider must be initialized with an EmailConfig" + ) + super().__init__(config) + self.config: EmailConfig = config # for type hinting + + @abstractmethod + async def send_email( + self, + to_email: str, + subject: str, + body: str, + html_body: Optional[str] = None, + ) -> None: + pass + + @abstractmethod + async def send_verification_email( + self, to_email: str, verification_code: str + ) -> None: + pass + + @abstractmethod + async def send_password_reset_email( + self, to_email: str, reset_token: str + ) -> None: + pass diff --git a/py/core/main/abstractions.py b/py/core/main/abstractions.py index 63bba5cd1..0fc3bd2a8 100644 --- a/py/core/main/abstractions.py +++ b/py/core/main/abstractions.py @@ -6,6 +6,8 @@ from core.base.pipes import AsyncPipe from core.pipelines 
import RAGPipeline, SearchPipeline from core.providers import ( + AsyncSMTPEmailProvider, + ConsoleMockEmailProvider, HatchetOrchestrationProvider, LiteLLMCompletionProvider, LiteLLMEmbeddingProvider, @@ -31,6 +33,7 @@ class R2RProviders(BaseModel): HatchetOrchestrationProvider, SimpleOrchestrationProvider ] logging: SqlitePersistentLoggingProvider + email: Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider] class Config: arbitrary_types_allowed = True diff --git a/py/core/main/app.py b/py/core/main/app.py index 287c26238..5fc6ec16c 100644 --- a/py/core/main/app.py +++ b/py/core/main/app.py @@ -1,12 +1,11 @@ from typing import Union from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi +from fastapi.responses import JSONResponse from core.base import R2RException - from core.providers import ( HatchetOrchestrationProvider, SimpleOrchestrationProvider, diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index fede0f34d..caaf0f76f 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -10,6 +10,7 @@ CompletionProvider, CryptoConfig, DatabaseConfig, + EmailConfig, EmbeddingConfig, EmbeddingProvider, IngestionConfig, @@ -24,8 +25,10 @@ logger = logging.getLogger() from core.providers import ( + AsyncSMTPEmailProvider, BCryptConfig, BCryptProvider, + ConsoleMockEmailProvider, HatchetOrchestrationProvider, LiteLLMCompletionProvider, LiteLLMEmbeddingProvider, @@ -49,21 +52,24 @@ def __init__(self, config: R2RConfig): @staticmethod async def create_auth_provider( auth_config: AuthConfig, - database_provider: PostgresDBProvider, crypto_provider: BCryptProvider, + database_provider: PostgresDBProvider, + email_provider: Union[ + AsyncSMTPEmailProvider, ConsoleMockEmailProvider + ], *args, **kwargs, ) -> Union[R2RAuthProvider, SupabaseAuthProvider]: if auth_config.provider == "r2r": 
r2r_auth = R2RAuthProvider( - auth_config, crypto_provider, database_provider + auth_config, crypto_provider, database_provider, email_provider ) await r2r_auth.initialize() return r2r_auth elif auth_config.provider == "supabase": return SupabaseAuthProvider( - auth_config, crypto_provider, database_provider + auth_config, crypto_provider, database_provider, email_provider ) else: raise ValueError( @@ -208,6 +214,23 @@ def create_llm_provider( raise ValueError("Language model provider not found") return llm_provider + @staticmethod + async def create_email_provider( + email_config: Optional[EmailConfig] = None, *args, **kwargs + ) -> Optional[Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]]: + """Creates an email provider based on configuration.""" + if not email_config: + return None + + if email_config.provider == "smtp": + return AsyncSMTPEmailProvider(email_config) + elif email_config.provider == "console_mock": + return ConsoleMockEmailProvider(email_config) + else: + raise ValueError( + f"Email provider {email_config.provider} not supported." 
+ ) + async def create_providers( self, auth_provider_override: Optional[ @@ -215,6 +238,9 @@ async def create_providers( ] = None, crypto_provider_override: Optional[BCryptProvider] = None, database_provider_override: Optional[PostgresDBProvider] = None, + email_provider_override: Optional[ + Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider] + ] = None, embedding_provider_override: Optional[ Union[LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider] ] = None, @@ -259,12 +285,20 @@ async def create_providers( ) ) + email_provider = ( + email_provider_override + or await self.create_email_provider( + self.config.email, crypto_provider, *args, **kwargs + ) + ) + auth_provider = ( auth_provider_override or await self.create_auth_provider( self.config.auth, - database_provider, crypto_provider, + database_provider, + email_provider, *args, **kwargs, ) @@ -287,6 +321,7 @@ async def create_providers( embedding=embedding_provider, ingestion=ingestion_provider, llm=llm_provider, + email=email_provider, orchestration=orchestration_provider, logging=logging_provider, ) diff --git a/py/core/main/config.py b/py/core/main/config.py index 8853b476a..4b914d6da 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -1,7 +1,6 @@ import logging import os from enum import Enum -from pathlib import Path from typing import Any, Optional import toml @@ -14,6 +13,7 @@ from ..base.providers.auth import AuthConfig from ..base.providers.crypto import CryptoConfig from ..base.providers.database import DatabaseConfig +from ..base.providers.email import EmailConfig from ..base.providers.embedding import EmbeddingConfig from ..base.providers.ingestion import IngestionConfig from ..base.providers.llm import CompletionConfig @@ -41,6 +41,7 @@ class R2RConfig: "app": [], "completion": ["provider"], "crypto": ["provider"], + "email": ["provider"], "auth": ["provider"], "embedding": [ "provider", @@ -63,6 +64,7 @@ class R2RConfig: crypto: CryptoConfig database: DatabaseConfig 
embedding: EmbeddingConfig + email: EmailConfig ingestion: IngestionConfig logging: PersistentLoggingConfig agent: AgentConfig @@ -113,6 +115,7 @@ def __init__(self, config_data: dict[str, Any]): self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore self.completion = CompletionConfig.create(**self.completion, app=self.app) # type: ignore self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore + self.email = EmailConfig.create(**self.email, app=self.app) # type: ignore self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore diff --git a/py/core/providers/__init__.py b/py/core/providers/__init__.py index 04520e415..9950961fd 100644 --- a/py/core/providers/__init__.py +++ b/py/core/providers/__init__.py @@ -1,6 +1,7 @@ from .auth import R2RAuthProvider, SupabaseAuthProvider from .crypto import BCryptConfig, BCryptProvider from .database import PostgresDBProvider +from .email import AsyncSMTPEmailProvider, ConsoleMockEmailProvider from .embeddings import LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider from .ingestion import ( # type: ignore R2RIngestionConfig, @@ -32,6 +33,9 @@ # Embeddings "LiteLLMEmbeddingProvider", "OpenAIEmbeddingProvider", + # Email + "AsyncSMTPEmailProvider", + "ConsoleMockEmailProvider", # Orchestration "HatchetOrchestrationProvider", "SimpleOrchestrationProvider", diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index f1fb9f88f..5babfa508 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -12,6 +12,7 @@ AuthProvider, CryptoProvider, DatabaseProvider, + EmailProvider, R2RException, Token, TokenData, @@ -33,11 +34,12 @@ def __init__( config: AuthConfig, crypto_provider: CryptoProvider, database_provider: DatabaseProvider, + 
email_provider: EmailProvider, ): - super().__init__(config, crypto_provider) + super().__init__( + config, crypto_provider, database_provider, email_provider + ) logger.debug(f"Initializing R2RAuthProvider with config: {config}") - self.crypto_provider = crypto_provider - self.database_provider = database_provider self.secret_key = ( config.secret_key or os.getenv("R2R_SECRET_KEY") or DEFAULT_R2R_SK ) @@ -157,7 +159,10 @@ async def register(self, email: str, password: str) -> UserResponse: ) new_user.verification_code_expiry = expiry # TODO - Integrate email provider(s) - # self.providers.email.send_verification_email(new_user.email, verification_code) + + await self.email_provider.send_verification_email( + new_user.email, verification_code + ) else: expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10) @@ -301,7 +306,7 @@ async def request_password_reset(self, email: str) -> Dict[str, str]: ) # TODO: Integrate with email provider to send reset link - # self.email_provider.send_reset_email(email, reset_token) + await self.email_provider.send_reset_email(email, reset_token) return {"message": "If the email exists, a reset link has been sent"} diff --git a/py/core/providers/auth/supabase.py b/py/core/providers/auth/supabase.py index d959f4a33..c5cfbba71 100644 --- a/py/core/providers/auth/supabase.py +++ b/py/core/providers/auth/supabase.py @@ -10,6 +10,7 @@ AuthProvider, CryptoProvider, DatabaseProvider, + EmailProvider, R2RException, Token, TokenData, @@ -29,8 +30,11 @@ def __init__( config: AuthConfig, crypto_provider: CryptoProvider, database_provider: DatabaseProvider, + email_provider: EmailProvider, ): - super().__init__(config, crypto_provider) + super().__init__( + config, crypto_provider, database_provider, email_provider + ) self.supabase_url = config.extra_fields.get( "supabase_url", None ) or os.getenv("SUPABASE_URL") diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 4280211c0..90ca5377b 100644 --- 
a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -299,9 +299,18 @@ async def semantic_search( extended_limit = ( search_settings.search_limit * 20 ) # Get 20x candidates for re-ranking + if ( + imeasure_obj == IndexMeasure.hamming_distance + or imeasure_obj == IndexMeasure.jaccard_distance + ): + binary_search_measure_repr = imeasure_obj.pgvector_repr + else: + binary_search_measure_repr = ( + IndexMeasure.hamming_distance.pgvector_repr + ) # Use binary column and binary-specific distance measures for first stage - stage1_distance = f"{table_name}.vec_binary {search_settings.index_measure.pgvector_repr} $1::bit({self.dimension})" + stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})" stage1_param = binary_query cols.append( diff --git a/py/core/providers/email/__init__.py b/py/core/providers/email/__init__.py new file mode 100644 index 000000000..d70f65330 --- /dev/null +++ b/py/core/providers/email/__init__.py @@ -0,0 +1,4 @@ +from .console_mock import ConsoleMockEmailProvider +from .smtp import AsyncSMTPEmailProvider + +__all__ = ["ConsoleMockEmailProvider", "AsyncSMTPEmailProvider"] diff --git a/py/core/providers/email/console_mock.py b/py/core/providers/email/console_mock.py new file mode 100644 index 000000000..3bab24723 --- /dev/null +++ b/py/core/providers/email/console_mock.py @@ -0,0 +1,56 @@ +import logging +from typing import Optional + +from core.base import EmailProvider + +logger = logging.getLogger() + + +class ConsoleMockEmailProvider(EmailProvider): + """A simple email provider that logs emails to console, useful for testing""" + + async def send_email( + self, + to_email: str, + subject: str, + body: str, + html_body: Optional[str] = None, + ) -> None: + logger.info( + f""" + -------- Email Message -------- + To: {to_email} + Subject: {subject} + Body: + {body} + ----------------------------- + """ + ) + + async def send_verification_email( + self, to_email: str, 
verification_code: str + ) -> None: + logger.info( + f""" + -------- Email Message -------- + To: {to_email} + Subject: Please verify your email address + Body: + Verification code: {verification_code} + ----------------------------- + """ + ) + + async def send_password_reset_email( + self, to_email: str, reset_token: str + ) -> None: + logger.info( + f""" + -------- Email Message -------- + To: {to_email} + Subject: Password Reset Request + Body: + Reset token: {reset_token} + ----------------------------- + """ + ) diff --git a/py/core/providers/email/smtp.py b/py/core/providers/email/smtp.py new file mode 100644 index 000000000..fecbe0008 --- /dev/null +++ b/py/core/providers/email/smtp.py @@ -0,0 +1,118 @@ +import logging +import os +from abc import ABC, abstractmethod +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from typing import Optional + +from aiosmtplib import SMTP + +from core.base import EmailConfig, EmailProvider + +logger = logging.getLogger() + + +class AsyncSMTPEmailProvider(EmailProvider): + def __init__(self, config: EmailConfig): + super().__init__(config) + self.smtp_server = config.smtp_server or os.getenv("R2R_SMTP_SERVER") + if not self.smtp_server: + raise ValueError("SMTP server is required") + + self.smtp_port = config.smtp_port or os.getenv("R2R_SMTP_PORT") + if not self.smtp_port: + raise ValueError("SMTP port is required") + + self.smtp_username = config.smtp_username or os.getenv( + "R2R_SMTP_USERNAME" + ) + if not self.smtp_username: + raise ValueError("SMTP username is required") + + self.smtp_password = config.smtp_password or os.getenv( + "R2R_SMTP_PASSWORD" + ) + if not self.smtp_password: + raise ValueError("SMTP password is required") + + self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL") + if not self.from_email: + raise ValueError("From email is required") + + self.use_tls = ( + config.use_tls + or os.getenv("R2R_SMTP_USE_TLS", "true").lower() == "true" + ) + + async 
def send_email( + self, + to_email: str, + subject: str, + body: str, + html_body: Optional[str] = None, + ) -> None: + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = self.from_email + msg["To"] = to_email + + msg.attach(MIMEText(body, "plain")) + if html_body: + msg.attach(MIMEText(html_body, "html")) + + try: + smtp = SMTP( + hostname=self.smtp_server, + port=self.smtp_port, + use_tls=self.use_tls, + ) + + await smtp.connect() + if self.smtp_username and self.smtp_password: + await smtp.login(self.smtp_username, self.smtp_password) + + await smtp.send_message(msg) + await smtp.quit() + + except Exception as e: + logger.error(f"Failed to send email: {str(e)}") + raise + + async def send_verification_email( + self, to_email: str, verification_code: str + ) -> None: + subject = "Verify Your Email Address" + body = f""" + Thank you for registering! Please verify your email address by entering the following code: + + {verification_code} + + This code will expire in 24 hours. + """ + html_body = f""" +

<h2>Email Verification</h2>
+ <p>Thank you for registering! Please verify your email address by entering the following code:</p>
+ <h3>{verification_code}</h3>
+ <p>This code will expire in 24 hours.</p>

+ """ + await self.send_email(to_email, subject, body, html_body) + + async def send_password_reset_email( + self, to_email: str, reset_token: str + ) -> None: + subject = "Password Reset Request" + body = f""" + We received a request to reset your password. Use the following code to reset your password: + + {reset_token} + + This code will expire in 1 hour. If you didn't request this reset, please ignore this email. + """ + html_body = f""" +

<h2>Password Reset Request</h2>
+ <p>We received a request to reset your password. Use the following code to reset your password:</p>
+ <h3>{reset_token}</h3>
+ <p>This code will expire in 1 hour.</p>
+ <p>If you didn't request this reset, please ignore this email.</p>

+ """ + await self.send_email(to_email, subject, body, html_body) diff --git a/py/poetry.lock b/py/poetry.lock index b191bea0a..e441531f2 100644 --- a/py/poetry.lock +++ b/py/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiofiles" @@ -15,7 +15,7 @@ files = [ name = "aiohappyeyeballs" version = "2.4.3" description = "Happy Eyeballs for asyncio" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiohappyeyeballs-2.4.3-py3-none-any.whl", hash = "sha256:8a7a83727b2756f394ab2895ea0765a0a8c475e3c71e98d43d76f22b4b435572"}, @@ -26,7 +26,7 @@ files = [ name = "aiohttp" version = "3.10.10" description = "Async http client/server framework (asyncio)" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiohttp-3.10.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be7443669ae9c016b71f402e43208e13ddf00912f47f623ee5994e12fc7d4b3f"}, @@ -162,7 +162,7 @@ files = [ name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, @@ -172,6 +172,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosmtplib" +version = "3.0.2" +description = "asyncio SMTP client" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiosmtplib-3.0.2-py3-none-any.whl", hash = "sha256:8783059603a34834c7c90ca51103c3aa129d5922003b5ce98dbaa6d4440f10fc"}, + {file = "aiosmtplib-3.0.2.tar.gz", hash = "sha256:08fd840f9dbc23258025dca229e8a8f04d2ccf3ecb1319585615bfc7933f7f47"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.0.0)", "sphinx-autodoc-typehints 
(>=1.24.0)", "sphinx-copybutton (>=0.5.0)"] +uvloop = ["uvloop (>=0.18)"] + [[package]] name = "aiosqlite" version = "0.20.0" @@ -383,7 +398,7 @@ test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"] name = "attrs" version = "24.2.0" description = "Classes Without Boilerplate" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, @@ -1367,7 +1382,7 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "frozenlist" version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, @@ -1468,7 +1483,7 @@ files = [ name = "fsspec" version = "2024.10.0" description = "File-system specification" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "fsspec-2024.10.0-py3-none-any.whl", hash = "sha256:03b9a6785766a4de40368b88906366755e2819e758b83705c88cd7cb5fe81871"}, @@ -1967,7 +1982,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "huggingface-hub" version = "0.26.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = true +optional = false python-versions = ">=3.8.0" files = [ {file = "huggingface_hub-0.26.1-py3-none-any.whl", hash = "sha256:5927a8fc64ae68859cd954b7cc29d1c8390a5e15caba6d3d349c973be8fdacf3"}, @@ -2057,7 +2072,7 @@ all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2 name = "importlib-metadata" version = "8.5.0" description = "Read metadata from Python packages" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = 
"sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, @@ -2108,7 +2123,7 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "jinja2" version = "3.1.4" description = "A very fast and expressive template engine." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, @@ -2240,7 +2255,7 @@ files = [ name = "jsonschema" version = "4.23.0" description = "An implementation of JSON Schema validation for Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, @@ -2261,7 +2276,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2024.10.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, @@ -2398,7 +2413,7 @@ files = [ name = "litellm" version = "1.50.4" description = "Library to easily interface with LLM API providers" -optional = true +optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ {file = "litellm-1.50.4-py3-none-any.whl", hash = "sha256:cc6992275e24a0bbb4a3b377e6842d45a8510fc85d7f255930a64bb872980a36"}, @@ -2836,7 +2851,7 @@ files = [ name = "multidict" version = "6.1.0" description = "multidict implementation" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, @@ -3599,7 
+3614,7 @@ virtualenv = ">=20.10.0" name = "propcache" version = "0.2.0" description = "Accelerated property cache" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "propcache-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c5869b8fd70b81835a6f187c5fdbe67917a04d7e52b6e7cc4e5fe39d55c39d58"}, @@ -4396,7 +4411,7 @@ websockets = ">=11,<14" name = "referencing" version = "0.35.1" description = "JSON Referencing + Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, @@ -4411,7 +4426,7 @@ rpds-py = ">=0.7.0" name = "regex" version = "2024.9.11" description = "Alternative regular expression module, to replace re." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408"}, @@ -4549,7 +4564,7 @@ requests = ">=2.0.1,<3.0.0" name = "rpds-py" version = "0.20.0" description = "Python bindings to Rust's persistent data structures (rpds)" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, @@ -4696,11 +4711,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", 
hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -5121,7 +5131,7 @@ files = [ name = "tiktoken" version = "0.8.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, @@ -5132,7 +5142,6 @@ files = [ {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, - {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, @@ -5168,7 +5177,7 @@ blobfile = ["blobfile (>=2)"] name = "tokenizers" version = "0.19.0" description = "" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "tokenizers-0.19.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:8e1c2ab2e501d52c39fa61fecb1270ff5ece272beab9b893792176c6e077116a"}, @@ -5731,7 +5740,7 @@ files = [ name = "yarl" version = "1.16.0" description = "Yet another URL library" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "yarl-1.16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32468f41242d72b87ab793a86d92f885355bcf35b3355aa650bfa846a5c60058"}, @@ -5827,7 +5836,7 @@ propcache = ">=0.2.0" name = "zipp" version = "3.20.2" description = "Backport of pathlib-compatible object wrapper for zip files" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, @@ -5843,10 +5852,10 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -core = ["aiosqlite", "apscheduler", "asyncpg", "bcrypt", "boto3", "deepdiff", "fire", "fsspec", "future", "graspologic", "gunicorn", "hatchet-sdk", "litellm", "networkx", 
"ollama", "passlib", "psutil", "pydantic", "pyjwt", "python-multipart", "pyyaml", "sqlalchemy", "supabase", "tokenizers", "unstructured-client", "uvicorn", "vecs"] +core = ["aiosqlite", "apscheduler", "asyncpg", "bcrypt", "boto3", "deepdiff", "fire", "fsspec", "future", "graspologic", "gunicorn", "hatchet-sdk", "networkx", "ollama", "passlib", "psutil", "pydantic", "pyjwt", "python-multipart", "pyyaml", "sqlalchemy", "supabase", "tokenizers", "unstructured-client", "uvicorn", "vecs"] ingestion-bundle = ["aiofiles", "aioshutil", "beautifulsoup4", "bs4", "markdown", "numpy", "openpyxl", "pdf2image", "pypdf", "pypdf2", "python-docx", "python-pptx"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "65a86f61c7efcbc23b9fc221cdd487e8aa4f3b3129444e0b6b75b28868f34654" +content-hash = "076bf72cff07b22d020e62cbe5e477865157be0deba19f817b1b2a41787625df" diff --git a/py/pyproject.toml b/py/pyproject.toml index 4f19cdbd7..e56393c3e 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -29,6 +29,7 @@ asyncclick = "^8.1.7.2" click = "^8.0.0" fastapi = "^0.114.0" httpx = "^0.27.0" +litellm = "^1.42.3" # move back to optional after zerox integration is complete nest-asyncio = "^1.6.0" openai = "^1.11.1" posthog = "^3.5.0" @@ -52,7 +53,6 @@ future = { version = "^1.0.0", optional = true } graspologic = { version = "^3.4.1", optional = true } gunicorn = { version = "^21.2.0", optional = true } hatchet-sdk = { version = "^0.38.0", optional = true } -litellm = { version = "^1.42.3", optional = true } networkx = { version = "^3.3", optional = true } ollama = { version = "^0.3.1", optional = true } passlib = { version = "^1.7.4", optional = true } @@ -80,6 +80,7 @@ pypdf = { version = "^4.2.0", optional = true } pypdf2 = { version = "^3.0.1", optional = true } python-pptx = { version = "^1.0.1", optional = true } python-docx = { version = "^1.1.0", optional = true } +aiosmtplib = "^3.0.2" [tool.poetry.extras] core = [ @@ -95,7 +96,6 @@ core = [ 
"graspologic", "gunicorn", "hatchet-sdk", - "litellm", "networkx", "ollama", "passlib", diff --git a/py/r2r.toml b/py/r2r.toml index 7ca2c0ac4..5ece375f4 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -111,3 +111,6 @@ provider = "simple" [prompt] provider = "r2r" + +[email] +provider = "console_mock" From 255bf73feae93aaa49449f6b16cb01d77b2a61ba Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Mon, 28 Oct 2024 16:20:09 -0700 Subject: [PATCH 05/12] up --- .vscode/launch.json | 26 +++ .../2024-10-28/10-02-34/.hydra/config.yaml | 35 ++++ outputs/2024-10-28/10-02-34/.hydra/hydra.yaml | 159 ++++++++++++++++++ .../2024-10-28/10-02-34/.hydra/overrides.yaml | 1 + outputs/2024-10-28/10-02-34/main.log | 2 + py/core/main/api/management_router.py | 10 +- .../hatchet/ingestion_workflow.py | 21 ++- py/core/main/services/ingestion_service.py | 18 +- py/poetry.lock | 50 +++++- py/pyproject.toml | 1 + 10 files changed, 302 insertions(+), 21 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 outputs/2024-10-28/10-02-34/.hydra/config.yaml create mode 100644 outputs/2024-10-28/10-02-34/.hydra/hydra.yaml create mode 100644 outputs/2024-10-28/10-02-34/.hydra/overrides.yaml create mode 100644 outputs/2024-10-28/10-02-34/main.log diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..7d10ad3be --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,26 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: main.py with RAG", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/py/workspace/evals/main.py", + "console": "integratedTerminal", + "args": ["step=rag"], + "python": "/Users/shreyas/Library/Caches/pypoetry/virtualenvs/r2r-R0EY5voQ-py3.12/bin/python" + }, + { + "name": "Python Debugger: Current File with Arguments", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/py/workspace/evals/main.py", + "console": "integratedTerminal", + "args": "${command:pickArgs}", + "python": "/Users/shreyas/Library/Caches/pypoetry/virtualenvs/r2r-R0EY5voQ-py3.12/bin/python" + } + ] +} diff --git a/outputs/2024-10-28/10-02-34/.hydra/config.yaml b/outputs/2024-10-28/10-02-34/.hydra/config.yaml new file mode 100644 index 000000000..52d5c6a13 --- /dev/null +++ b/outputs/2024-10-28/10-02-34/.hydra/config.yaml @@ -0,0 +1,35 @@ +name: vector_rag +step: ingest +run_settings: + use_graph: false + use_vector: true + use_hybrid: false + contextual_enrichment: false + deduplication: false +server: + type: r2r + settings: + port: 7272 + use_docker: true + config_path: /Users/shreyas/new7/R2R/py/workspace/evals/configs/server/r2r/r2r_full.toml + stdout_path: r2r_stdout.log + stderr_path: r2r_stderr.log + project_name: ${name} + contextual_enrichment: ${run_settings.contextual_enrichment} + r2r_image: r2r-test +ingestion_dataset: + id: 671c1dddea3bd8151d1ceb3c +eval_dataset: + id: 671c1dddea3bd8151d1ceb3c +evaluator: + base_url: https://api.relari.com/v1 + api_key: ek-e581232b34d3d6c3a7af8079b50ab67a + project_id: 671854e7db1ad90003b5c04b + eval_dataset: '{eval_dataset}' +metrics: + metrics_list: + - answer_relevancy + - context_relevancy + - faithfulness + - answer_similarity + - latency diff --git a/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml b/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml new 
file mode 100644 index 000000000..fef701667 --- /dev/null +++ b/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml @@ -0,0 +1,159 @@ +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? 
+ hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: [] + job: + name: main + chdir: true + override_dirname: '' + id: ??? + num: ??? + config_name: config.yaml + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /Users/shreyas/new7/R2R + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /Users/shreyas/new7/R2R/py/workspace/evals/configs + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /Users/shreyas/new7/R2R/outputs/2024-10-28/10-02-34 + choices: + metrics: llm + evaluator: relari + eval_dataset: aristotle + ingestion_dataset: aristotle + server: r2r + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: false diff --git a/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml b/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ 
b/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml @@ -0,0 +1 @@ +[] diff --git a/outputs/2024-10-28/10-02-34/main.log b/outputs/2024-10-28/10-02-34/main.log new file mode 100644 index 000000000..8ee85a065 --- /dev/null +++ b/outputs/2024-10-28/10-02-34/main.log @@ -0,0 +1,2 @@ +[2024-10-28 10:02:37,858][httpx][INFO] - HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK" +[2024-10-28 10:02:38,987][root][INFO] - Posthog telemetry disabled, debug mode off diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index 606c943e9..f4ee8cd5e 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from typing import Any, Optional, Set, Union from uuid import UUID - +import os import psutil from fastapi import Body, Depends, Path, Query from fastapi.responses import StreamingResponse @@ -887,3 +887,11 @@ async def delete_conversation( ) -> WrappedDeleteResponse: await self.service.delete_conversation(conversation_id) return None # type: ignore + + + @self.router.get("/r2r_project_name") + @self.base_endpoint + async def r2r_project_name( + auth_user=Depends(self.service.providers.auth.auth_wrapper), + ) -> dict: + return {"project_name": os.environ["R2R_PROJECT_NAME"]} diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index ab98562b0..49cf12b0d 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -15,7 +15,7 @@ increment_version, ) from core.base.abstractions import DocumentInfo, R2RException -from core.utils import generate_default_user_collection_id +from core.utils import generate_default_user_collection_id, update_settings_from_dict from ...services import IngestionService, IngestionServiceAdapter @@ -163,15 
+163,21 @@ async def parse(self, context: Context) -> dict: document_id=document_info.id, collection_id=collection_id ) - chunk_enrichment_settings = getattr( + # get server chunk enrichment settings and override parts of it if provided in the ingestion config + server_chunk_enrichment_settings = getattr( service.providers.ingestion.config, "chunk_enrichment_settings", None, ) - if chunk_enrichment_settings and getattr( - chunk_enrichment_settings, "enable_chunk_enrichment", False - ): + if server_chunk_enrichment_settings: + chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: logger.info("Enriching document with contextual chunks") @@ -192,6 +198,7 @@ async def parse(self, context: Context) -> dict: await self.ingestion_service.chunk_enrichment( document_id=document_info.id, + chunk_enrichment_settings=chunk_enrichment_settings, ) await self.ingestion_service.update_document_status( @@ -243,8 +250,8 @@ async def on_failure(self, context: Context) -> None: # Update the document status to FAILED if ( - not document_info.ingestion_status - == IngestionStatus.SUCCESS + document_info.ingestion_status + not in [IngestionStatus.SUCCESS, IngestionStatus.ENRICHED] ): await self.ingestion_service.update_document_status( document_info, diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index ab222737e..90ba5c372 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -457,7 +457,17 @@ async def _get_enriched_chunk_text( for neighbor in semantic_neighbors ) - context_chunk_ids = list(set(context_chunk_ids)) + context_chunk_ids = list( + set( + [ + str(context_chunk_id) + for context_chunk_id in context_chunk_ids + ] + ) + ) + context_chunk_ids = [ + UUID(context_chunk_id) for context_chunk_id in context_chunk_ids 
+ ] context_chunk_texts = [ ( @@ -521,13 +531,9 @@ async def _get_enriched_chunk_text( metadata=chunk["metadata"], ) - async def chunk_enrichment(self, document_id: UUID) -> int: + async def chunk_enrichment(self, document_id: UUID, chunk_enrichment_settings: ChunkEnrichmentSettings) -> int: # just call the pipe on every chunk of the document - # TODO: Why is the config not recognized as an ingestionconfig but as a providerconfig? - chunk_enrichment_settings = ( - self.providers.ingestion.config.chunk_enrichment_settings # type: ignore - ) # get all document_chunks document_chunks = ( await self.providers.database.get_document_chunks( diff --git a/py/poetry.lock b/py/poetry.lock index b191bea0a..1b42e6e41 100644 --- a/py/poetry.lock +++ b/py/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiofiles" @@ -234,6 +234,16 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.9.3" +description = "ANTLR 4.9.3 runtime for Python 3.7" +optional = false +python-versions = "*" +files = [ + {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, +] + [[package]] name = "anyio" version = "4.6.2.post1" @@ -1997,6 +2007,22 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr torch = ["safetensors[torch]", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "hydra-core" +version = "1.3.2" +description = "A framework for elegantly configuring complex applications" +optional = false +python-versions = "*" +files = [ + {file = 
"hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824"}, + {file = "hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b"}, +] + +[package.dependencies] +antlr4-python3-runtime = "==4.9.*" +omegaconf = ">=2.2,<2.4" +packaging = "*" + [[package]] name = "hyperframe" version = "6.0.1" @@ -3130,6 +3156,21 @@ files = [ [package.dependencies] httpx = ">=0.27.0,<0.28.0" +[[package]] +name = "omegaconf" +version = "2.3.0" +description = "A flexible configuration library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b"}, + {file = "omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7"}, +] + +[package.dependencies] +antlr4-python3-runtime = "==4.9.*" +PyYAML = ">=5.1.0" + [[package]] name = "openai" version = "1.52.2" @@ -4696,11 +4737,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -5849,4 +5885,4 @@ ingestion-bundle = ["aiofiles", "aioshutil", "beautifulsoup4", "bs4", "markdown" [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "65a86f61c7efcbc23b9fc221cdd487e8aa4f3b3129444e0b6b75b28868f34654" +content-hash = "47faefbe9ad27015be050b5f5c931ded5e4c902a0b8e9cadba415e69dcea72fe" diff --git a/py/pyproject.toml b/py/pyproject.toml index 4f19cdbd7..a847f35fd 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -80,6 +80,7 @@ pypdf = { version = "^4.2.0", optional = true } pypdf2 = { version = "^3.0.1", optional = true } python-pptx = { version = "^1.0.1", optional = true } python-docx = { version = "^1.1.0", optional = true } +hydra-core = "^1.3.2" [tool.poetry.extras] core = [ From 080d8cb0f6e0b3fb3b75e35487c4ed87f01ed39d Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:55:05 -0700 Subject: [PATCH 06/12] cleanup (#1512) * cleanup * cleanup prompt mgmt * up * cleanup printout * cleanup new parser logic, set vlm as default for all providers * allow user to re-override --- 
.../actions/run-script-zerox-tests/action.yml | 14 - .../r2r-full-integration-deep-dive-tests.yml | 3 - docs/cookbooks/graphrag.mdx | 6 - py/core/__init__.py | 4 +- py/core/base/parsers/base_parser.py | 7 +- py/core/base/providers/ingestion.py | 28 +- py/core/base/providers/llm.py | 2 + py/core/configs/full.toml | 3 - .../examples/scripts/run_ingest_with_zerox.py | 15 - py/core/main/assembly/factory.py | 40 +- py/core/parsers/__init__.py | 5 +- py/core/parsers/media/__init__.py | 10 +- py/core/parsers/media/audio_parser.py | 82 ++++- py/core/parsers/media/docx_parser.py | 16 +- py/core/parsers/media/img_parser.py | 128 +++++-- py/core/parsers/media/openai_helpers.py | 63 ---- py/core/parsers/media/pdf_parser.py | 343 +++++++++++------- py/core/parsers/media/ppt_parser.py | 15 +- py/core/parsers/media/pyzerox/__init__.py | 10 - .../media/pyzerox/constants/__init__.py | 9 - .../media/pyzerox/constants/conversion.py | 8 - .../media/pyzerox/constants/messages.py | 56 --- .../media/pyzerox/constants/patterns.py | 6 - .../media/pyzerox/constants/prompts.py | 8 - .../parsers/media/pyzerox/errors/__init__.py | 21 -- py/core/parsers/media/pyzerox/errors/base.py | 21 -- .../media/pyzerox/errors/exceptions.py | 93 ----- .../media/pyzerox/processor/__init__.py | 14 - .../parsers/media/pyzerox/processor/image.py | 27 -- .../parsers/media/pyzerox/processor/pdf.py | 115 ------ .../parsers/media/pyzerox/processor/text.py | 14 - .../parsers/media/pyzerox/processor/utils.py | 52 --- .../media/pyzerox/zerox_core/__init__.py | 5 - .../parsers/media/pyzerox/zerox_core/types.py | 42 --- .../parsers/media/pyzerox/zerox_core/zerox.py | 151 -------- .../media/pyzerox/zerox_models/__init__.py | 7 - .../media/pyzerox/zerox_models/base.py | 43 --- .../pyzerox/zerox_models/modellitellm.py | 169 --------- .../media/pyzerox/zerox_models/types.py | 12 - py/core/parsers/structured/csv_parser.py | 23 +- py/core/parsers/structured/json_parser.py | 15 + py/core/parsers/structured/xlsx_parser.py | 21 
+- py/core/parsers/text/html_parser.py | 15 + py/core/parsers/text/md_parser.py | 16 +- py/core/parsers/text/text_parser.py | 15 + py/core/providers/auth/r2r_auth.py | 2 +- .../database/prompts/vision_img.yaml | 4 + .../database/prompts/vision_pdf.yaml | 42 +++ py/core/providers/email/smtp.py | 18 +- py/core/providers/ingestion/r2r/base.py | 47 ++- .../providers/ingestion/unstructured/base.py | 82 +++-- py/poetry.lock | 13 +- py/pyproject.toml | 1 + py/r2r.toml | 9 +- py/shared/abstractions/document.py | 2 + 55 files changed, 737 insertions(+), 1255 deletions(-) delete mode 100644 .github/actions/run-script-zerox-tests/action.yml delete mode 100644 py/core/examples/scripts/run_ingest_with_zerox.py delete mode 100644 py/core/parsers/media/openai_helpers.py delete mode 100644 py/core/parsers/media/pyzerox/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/constants/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/constants/conversion.py delete mode 100644 py/core/parsers/media/pyzerox/constants/messages.py delete mode 100644 py/core/parsers/media/pyzerox/constants/patterns.py delete mode 100644 py/core/parsers/media/pyzerox/constants/prompts.py delete mode 100644 py/core/parsers/media/pyzerox/errors/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/errors/base.py delete mode 100644 py/core/parsers/media/pyzerox/errors/exceptions.py delete mode 100644 py/core/parsers/media/pyzerox/processor/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/processor/image.py delete mode 100644 py/core/parsers/media/pyzerox/processor/pdf.py delete mode 100644 py/core/parsers/media/pyzerox/processor/text.py delete mode 100644 py/core/parsers/media/pyzerox/processor/utils.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_core/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_core/types.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_core/zerox.py delete mode 100644 
py/core/parsers/media/pyzerox/zerox_models/__init__.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_models/base.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_models/modellitellm.py delete mode 100644 py/core/parsers/media/pyzerox/zerox_models/types.py create mode 100644 py/core/providers/database/prompts/vision_img.yaml create mode 100644 py/core/providers/database/prompts/vision_pdf.yaml diff --git a/.github/actions/run-script-zerox-tests/action.yml b/.github/actions/run-script-zerox-tests/action.yml deleted file mode 100644 index ea15d49e1..000000000 --- a/.github/actions/run-script-zerox-tests/action.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: 'Run SDK Auth Tests' -description: 'Runs SDK authentication tests for R2R' -runs: - using: "composite" - steps: - - name: Ingest zerox document - working-directory: ./py - shell: bash - run: poetry run python core/examples/scripts/run_ingest_with_zerox.py - - - name: Test ingested zerox document - working-directory: ./py - shell: bash - run: poetry run python tests/integration/runner_scripts.py test_ingested_zerox_document diff --git a/.github/workflows/r2r-full-integration-deep-dive-tests.yml b/.github/workflows/r2r-full-integration-deep-dive-tests.yml index 4021478d4..a4391820d 100644 --- a/.github/workflows/r2r-full-integration-deep-dive-tests.yml +++ b/.github/workflows/r2r-full-integration-deep-dive-tests.yml @@ -36,6 +36,3 @@ jobs: - name: Start R2R Full server uses: ./.github/actions/start-r2r-full - - - name: Run Test Zerox - uses: ./.github/actions/run-script-zerox-tests diff --git a/docs/cookbooks/graphrag.mdx b/docs/cookbooks/graphrag.mdx index 0d8437edf..6ebc6a14c 100644 --- a/docs/cookbooks/graphrag.mdx +++ b/docs/cookbooks/graphrag.mdx @@ -99,9 +99,6 @@ excluded_parsers = ["mp4"] semantic_similarity_threshold = 0.7 generation_config = { model = "openai/gpt-4o-mini" } - [ingestion.extra_parsers] - pdf = "zerox" - [database] provider = "postgres" batch_size = 256 @@ -204,9 +201,6 @@ 
max_characters = 1_024 combine_under_n_chars = 128 overlap = 256 - [ingestion.extra_parsers] - pdf = "zerox" - [orchestration] provider = "hatchet" kg_creation_concurrency_lipmit = 32 diff --git a/py/core/__init__.py b/py/core/__init__.py index babf7f1f8..d289188e7 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -195,9 +195,9 @@ "AudioParser", "DOCXParser", "ImageParser", - "PDFParser", + "VLMPDFParser", + "BasicPDFParser", "PDFParserUnstructured", - "PDFParserMarker", "PPTParser", # Structured parsers "CSVParser", diff --git a/py/core/base/parsers/base_parser.py b/py/core/base/parsers/base_parser.py index d0bc8633c..1de600404 100644 --- a/py/core/base/parsers/base_parser.py +++ b/py/core/base/parsers/base_parser.py @@ -3,14 +3,11 @@ from abc import ABC, abstractmethod from typing import AsyncGenerator, Generic, TypeVar -from ..abstractions import DataType - T = TypeVar("T") class AsyncParser(ABC, Generic[T]): + @abstractmethod - async def ingest( - self, data: T, **kwargs - ) -> AsyncGenerator[DataType, None]: + async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]: pass diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index 6e80b51e8..feb0349a5 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -1,10 +1,13 @@ import logging from abc import ABC from enum import Enum +from typing import Optional from core.base.abstractions import ChunkEnrichmentSettings from .base import Provider, ProviderConfig +from .database import DatabaseProvider +from .llm import CompletionProvider logger = logging.getLogger() @@ -15,7 +18,14 @@ class IngestionConfig(ProviderConfig): chunk_enrichment_settings: ChunkEnrichmentSettings = ( ChunkEnrichmentSettings() ) - extra_parsers: dict[str, str] = {} + + audio_transcription_model: str + + vision_img_prompt_name: Optional[str] = None + vision_img_model: str + + vision_pdf_prompt_name: Optional[str] = None + vision_pdf_model: str 
@property def supported_providers(self) -> list[str]: @@ -27,7 +37,21 @@ def validate_config(self) -> None: class IngestionProvider(Provider, ABC): - pass + + config: IngestionConfig + database_provider: DatabaseProvider + llm_provider: CompletionProvider + + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + super().__init__(config) + self.config: IngestionConfig = config + self.llm_provider = llm_provider + self.database_provider = database_provider class ChunkingStrategy(str, Enum): diff --git a/py/core/base/providers/llm.py b/py/core/base/providers/llm.py index 41cb00a21..cd213d208 100644 --- a/py/core/base/providers/llm.py +++ b/py/core/base/providers/llm.py @@ -147,6 +147,8 @@ async def aget_completion( "generation_config": generation_config, "kwargs": kwargs, } + if modalities := kwargs.get("modalities"): + task["modalities"] = modalities response = await self._execute_with_backoff_async(task) return LLMChatCompletion(**response.dict()) diff --git a/py/core/configs/full.toml b/py/core/configs/full.toml index 3d397527e..b6ec46b00 100644 --- a/py/core/configs/full.toml +++ b/py/core/configs/full.toml @@ -7,9 +7,6 @@ max_characters = 1_024 combine_under_n_chars = 128 overlap = 256 - [ingestion.extra_parsers] - pdf = "zerox" - [orchestration] provider = "hatchet" kg_creation_concurrency_lipmit = 32 diff --git a/py/core/examples/scripts/run_ingest_with_zerox.py b/py/core/examples/scripts/run_ingest_with_zerox.py deleted file mode 100644 index 41aba6adf..000000000 --- a/py/core/examples/scripts/run_ingest_with_zerox.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -import time - -from r2r import R2RClient - -if __name__ == "__main__": - client = R2RClient(base_url="http://localhost:7272") - script_path = os.path.dirname(__file__) - sample_file = os.path.join(script_path, "..", "data", "graphrag.pdf") - - ingest_response = client.ingest_files( - file_paths=[sample_file], - 
ingestion_config={"parser_overrides": {"pdf": "zerox"}}, - ) - time.sleep(60) diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index caaf0f76f..461e9abf4 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -89,7 +89,13 @@ def create_crypto_provider( @staticmethod def create_ingestion_provider( - ingestion_config: IngestionConfig, *args, **kwargs + ingestion_config: IngestionConfig, + database_provider: PostgresDBProvider, + llm_provider: Union[ + LiteLLMCompletionProvider, OpenAICompletionProvider + ], + *args, + **kwargs, ) -> Union[R2RIngestionProvider, UnstructuredIngestionProvider]: config_dict = ( @@ -104,7 +110,9 @@ def create_ingestion_provider( r2r_ingestion_config = R2RIngestionConfig( **config_dict, **extra_fields ) - return R2RIngestionProvider(r2r_ingestion_config) + return R2RIngestionProvider( + r2r_ingestion_config, database_provider, llm_provider + ) elif config_dict["provider"] in [ "unstructured_local", "unstructured_api", @@ -114,7 +122,7 @@ def create_ingestion_provider( ) return UnstructuredIngestionProvider( - unstructured_ingestion_config, + unstructured_ingestion_config, database_provider, llm_provider ) else: raise ValueError( @@ -217,10 +225,12 @@ def create_llm_provider( @staticmethod async def create_email_provider( email_config: Optional[EmailConfig] = None, *args, **kwargs - ) -> Optional[Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]]: + ) -> Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]: """Creates an email provider based on configuration.""" if not email_config: - return None + raise ValueError( + f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." 
+ ) if email_config.provider == "smtp": return AsyncSMTPEmailProvider(email_config) @@ -263,21 +273,14 @@ async def create_providers( self.config.embedding, *args, **kwargs ) ) - ingestion_provider = ( - ingestion_provider_override - or self.create_ingestion_provider( - self.config.ingestion, *args, **kwargs - ) - ) - llm_provider = llm_provider_override or self.create_llm_provider( self.config.completion, *args, **kwargs ) + crypto_provider = ( crypto_provider_override or self.create_crypto_provider(self.config.crypto, *args, **kwargs) ) - database_provider = ( database_provider_override or await self.create_database_provider( @@ -285,6 +288,17 @@ async def create_providers( ) ) + ingestion_provider = ( + ingestion_provider_override + or self.create_ingestion_provider( + self.config.ingestion, + database_provider, + llm_provider, + *args, + **kwargs, + ) + ) + email_provider = ( email_provider_override or await self.create_email_provider( diff --git a/py/core/parsers/__init__.py b/py/core/parsers/__init__.py index 0439c320b..2915f1ab2 100644 --- a/py/core/parsers/__init__.py +++ b/py/core/parsers/__init__.py @@ -7,9 +7,10 @@ "AudioParser", "DOCXParser", "ImageParser", - "PDFParser", + "VLMPDFParser", + "BasicPDFParser", "PDFParserUnstructured", - "PDFParserMarker", + "VLMPDFParser", "PPTParser", # Structured parsers "CSVParser", diff --git a/py/core/parsers/media/__init__.py b/py/core/parsers/media/__init__.py index 71c9bf0df..38881e171 100644 --- a/py/core/parsers/media/__init__.py +++ b/py/core/parsers/media/__init__.py @@ -2,10 +2,9 @@ from .docx_parser import DOCXParser from .img_parser import ImageParser from .pdf_parser import ( # type: ignore - PDFParser, - PDFParserMarker, + BasicPDFParser, PDFParserUnstructured, - ZeroxPDFParser, + VLMPDFParser, ) from .ppt_parser import PPTParser @@ -13,9 +12,8 @@ "AudioParser", "DOCXParser", "ImageParser", - "PDFParser", + "VLMPDFParser", + "BasicPDFParser", "PDFParserUnstructured", - "ZeroxPDFParser", - 
"PDFParserMarker", "PPTParser", ] diff --git a/py/core/parsers/media/audio_parser.py b/py/core/parsers/media/audio_parser.py index a8026c1af..1d4421d37 100644 --- a/py/core/parsers/media/audio_parser.py +++ b/py/core/parsers/media/audio_parser.py @@ -1,35 +1,81 @@ +import base64 +import logging import os +import tempfile from typing import AsyncGenerator from core.base.parsers.base_parser import AsyncParser -from core.parsers.media.openai_helpers import process_audio_with_openai +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) + +logger = logging.getLogger() class AudioParser(AsyncParser[bytes]): - """A parser for audio data.""" + """A parser for audio data using Whisper transcription.""" def __init__( - self, api_base: str = "https://api.openai.com/v1/audio/transcriptions" + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, ): - self.api_base = api_base - self.openai_api_key = os.environ.get("OPENAI_API_KEY") + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + try: + from litellm import atranscription + + self.atranscription = atranscription + except ImportError: + logger.error("Failed to import LiteLLM transcription") + raise ImportError( + "Please install the `litellm` package to use the AudioParser." + ) async def ingest( # type: ignore - self, data: bytes, chunk_size: int = 1024, *args, **kwargs + self, data: bytes, **kwargs ) -> AsyncGenerator[str, None]: - """Ingest audio data and yield a transcription.""" - temp_audio_path = "temp_audio.wav" - with open(temp_audio_path, "wb") as f: - f.write(data) + """ + Ingest audio data and yield a transcription using Whisper via LiteLLM. 
+ + Args: + data: Raw audio bytes + chunk_size: Size of text chunks to yield + model: The model to use for transcription (default is whisper-1) + *args, **kwargs: Additional arguments passed to the transcription call + + Yields: + Chunks of transcribed text + """ try: - transcription_text = process_audio_with_openai( - open(temp_audio_path, "rb"), self.openai_api_key # type: ignore + # Create a temporary file to store the audio data + with tempfile.NamedTemporaryFile( + suffix=".wav", delete=False + ) as temp_file: + temp_file.write(data) + temp_file_path = temp_file.name + + # Call Whisper transcription + response = await self.atranscription( + model=self.config.audio_transcription_model, + file=open(temp_file_path, "rb"), + **kwargs, ) - # split text into small chunks and yield them - for i in range(0, len(transcription_text), chunk_size): - text = transcription_text[i : i + chunk_size] - if text and text != "": - yield text + # The response should contain the transcribed text directly + yield response.text + + except Exception as e: + logger.error(f"Error processing audio with Whisper: {str(e)}") + raise + finally: - os.remove(temp_audio_path) + # Clean up the temporary file + try: + os.unlink(temp_file_path) + except Exception as e: + logger.warning(f"Failed to delete temporary file: {str(e)}") diff --git a/py/core/parsers/media/docx_parser.py b/py/core/parsers/media/docx_parser.py index 21272e1b2..86c242115 100644 --- a/py/core/parsers/media/docx_parser.py +++ b/py/core/parsers/media/docx_parser.py @@ -3,12 +3,26 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class DOCXParser(AsyncParser[DataType]): """A parser for DOCX data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = 
database_provider + self.llm_provider = llm_provider + self.config = config + try: from docx import Document diff --git a/py/core/parsers/media/img_parser.py b/py/core/parsers/media/img_parser.py index 206c9160f..d4276b718 100644 --- a/py/core/parsers/media/img_parser.py +++ b/py/core/parsers/media/img_parser.py @@ -1,50 +1,112 @@ import base64 import logging -import os from typing import AsyncGenerator -from core.base.abstractions import DataType +from core.base.abstractions import DataType, GenerationConfig from core.base.parsers.base_parser import AsyncParser -from core.parsers.media.openai_helpers import process_frame_with_openai +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) logger = logging.getLogger() class ImageParser(AsyncParser[DataType]): - """A parser for image data.""" + """A parser for image data using vision models.""" + + DEFAULT_IMG_VISION_PROMPT_NAME = "vision_img" def __init__( self, - model: str = "gpt-4o-mini", - max_tokens: int = 2_048, - api_base: str = "https://api.openai.com/v1/chat/completions", - max_image_size: int = 1 * 1024 * 1024, # 4MB limit + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, ): - self.model = model - self.max_tokens = max_tokens - self.openai_api_key = os.environ.get("OPENAI_API_KEY") - self.api_base = api_base - self.max_image_size = max_image_size + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + self.vision_prompt_text = None + + try: + from litellm import supports_vision + + self.supports_vision = supports_vision + except ImportError: + logger.error("Failed to import LiteLLM vision support") + raise ImportError( + "Please install the `litellm` package to use the ImageParser." 
+ ) async def ingest( # type: ignore - self, data: DataType, chunk_size: int = 1024, *args, **kwargs + self, data: DataType, **kwargs ) -> AsyncGenerator[str, None]: - """Ingest image data and yield a description.""" - - if isinstance(data, bytes): - # Encode to base64 - data = base64.b64encode(data).decode("utf-8") - - openai_text = process_frame_with_openai( - data, # type: ignore - self.openai_api_key, # type: ignore - self.model, - self.max_tokens, - self.api_base, - ) - - # split text into small chunks and yield them - for i in range(0, len(openai_text), chunk_size): - text = openai_text[i : i + chunk_size] - if text and text != "": - yield text + """ + Ingest image data and yield a description using vision model. + + Args: + data: Image data (bytes or base64 string) + chunk_size: Size of text chunks to yield + *args, **kwargs: Additional arguments passed to the completion call + + Yields: + Chunks of image description text + """ + if not self.vision_prompt_text: + self.vision_prompt_text = await self.database_provider.get_prompt( # type: ignore + prompt_name=self.config.vision_img_prompt_name + or self.DEFAULT_IMG_VISION_PROMPT_NAME + ) + try: + # Verify model supports vision + if not self.supports_vision(model=self.config.vision_img_model): + raise ValueError( + f"Model {self.config.vision_img_model} does not support vision" + ) + + # Encode image data if needed + if isinstance(data, bytes): + image_data = base64.b64encode(data).decode("utf-8") + else: + image_data = data + + # Configure the generation parameters + generation_config = GenerationConfig( + model=self.config.vision_img_model, + stream=False, + ) + + # Prepare message with image + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": self.vision_prompt_text}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + } + ] + + # Get completion from LiteLLM provider + response = await self.llm_provider.aget_completion( + 
messages=messages, generation_config=generation_config + ) + + # Extract description from response + if response.choices and response.choices[0].message: + content = response.choices[0].message.content + if not content: + raise ValueError("No content in response") + yield content + else: + raise ValueError("No response content") + + except Exception as e: + logger.error(f"Error processing image with vision model: {str(e)}") + raise diff --git a/py/core/parsers/media/openai_helpers.py b/py/core/parsers/media/openai_helpers.py deleted file mode 100644 index 729426a63..000000000 --- a/py/core/parsers/media/openai_helpers.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Implementations of parsers for different data types.""" - -import logging - -import requests - -logger = logging.getLogger() - - -def process_frame_with_openai( - data: bytes, - api_key: str, - model: str = "gpt-4o", - max_tokens: int = 2_048, - api_base: str = "https://api.openai.com/v1/chat/completions", -) -> str: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - - payload = { - "model": model, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "First, provide a title for the image, then explain everything that you see. Be very thorough in your analysis as a user will need to understand the image without seeing it. If it is possible to transcribe the image to text directly, then do so. 
The more detail you provide, the better the user will understand the image.", - }, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{data}"}, # type: ignore - }, - ], - } - ], - "max_tokens": max_tokens, - } - - response = requests.post(api_base, headers=headers, json=payload) - response_json = response.json() - return response_json["choices"][0]["message"]["content"] - - -def process_audio_with_openai( - audio_file, - api_key: str, - audio_api_base: str = "https://api.openai.com/v1/audio/transcriptions", -) -> str: - headers = {"Authorization": f"Bearer {api_key}"} - - transcription_response = requests.post( - audio_api_base, - headers=headers, - files={"file": audio_file}, - data={"model": "whisper-1"}, - ) - - transcription = transcription_response.json() - - return transcription["text"] diff --git a/py/core/parsers/media/pdf_parser.py b/py/core/parsers/media/pdf_parser.py index 2372c7fc8..61ec1a11b 100644 --- a/py/core/parsers/media/pdf_parser.py +++ b/py/core/parsers/media/pdf_parser.py @@ -1,5 +1,6 @@ # type: ignore import asyncio +import base64 import logging import os import string @@ -7,17 +8,219 @@ from io import BytesIO from typing import AsyncGenerator -from core.base.abstractions import DataType +import aiofiles +from pdf2image import convert_from_path + +from core.base.abstractions import DataType, GenerationConfig from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) logger = logging.getLogger() -ZEROX_DEFAULT_MODEL = "openai/gpt-4o-mini" -class PDFParser(AsyncParser[DataType]): +class VLMPDFParser(AsyncParser[DataType]): + """A parser for PDF documents using vision models for page processing.""" + + DEFAULT_PDF_VISION_PROMPT_NAME = "vision_pdf" + + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + 
self.llm_provider = llm_provider + self.config = config + self.vision_prompt_text = None + + try: + from litellm import supports_vision + + self.supports_vision = supports_vision + except ImportError: + logger.error("Failed to import LiteLLM vision support") + raise ImportError( + "Please install the `litellm` package to use the VLMPDFParser." + ) + + async def convert_pdf_to_images( + self, pdf_path: str, temp_dir: str + ) -> list[str]: + """Convert PDF pages to images asynchronously.""" + options = { + "pdf_path": pdf_path, + "output_folder": temp_dir, + "dpi": 300, # Configurable via config if needed + "fmt": "jpeg", + "thread_count": 4, + "paths_only": True, + } + try: + image_paths = await asyncio.to_thread(convert_from_path, **options) + return image_paths + except Exception as err: + logger.error(f"Error converting PDF to images: {err}") + raise + + async def process_page( + self, image_path: str, page_num: int + ) -> dict[str, str]: + """Process a single PDF page using the vision model.""" + + try: + # Read and encode image + async with aiofiles.open(image_path, "rb") as image_file: + image_data = await image_file.read() + image_base64 = base64.b64encode(image_data).decode("utf-8") + + # Verify model supports vision + if not self.supports_vision(model=self.config.vision_pdf_model): + raise ValueError( + f"Model {self.config.vision_pdf_model} does not support vision" + ) + + # Configure generation parameters + generation_config = GenerationConfig( + model=self.config.vision_pdf_model, + stream=False, + ) + + # Prepare message with image + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": self.vision_prompt_text}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + }, + ], + } + ] + + # Get completion from LiteLLM provider + response = await self.llm_provider.aget_completion( + messages=messages, generation_config=generation_config + ) + + if response.choices and 
response.choices[0].message: + content = response.choices[0].message.content + if not content: + raise ValueError("No content in response") + return {"page": str(page_num), "content": content} + else: + raise ValueError("No response content") + + except Exception as e: + logger.error( + f"Error processing page {page_num} with vision model: {str(e)}" + ) + raise + + async def ingest( + self, data: DataType, maintain_order: bool = False, **kwargs + ) -> AsyncGenerator[dict[str, str], None]: + """ + Ingest PDF data and yield descriptions for each page using vision model. + + Args: + data: PDF file path or bytes + maintain_order: If True, yields results in page order. If False, yields as completed. + **kwargs: Additional arguments passed to the completion call + + Yields: + Dict containing page number and content for each processed page + """ + if not self.vision_prompt_text: + self.vision_prompt_text = await self.database_provider.get_prompt( # type: ignore + prompt_name=self.config.vision_pdf_prompt_name + or self.DEFAULT_PDF_VISION_PROMPT_NAME + ) + + temp_dir = None + try: + # Create temporary directory for image processing + temp_dir = os.path.join(os.getcwd(), "temp_pdf_images") + os.makedirs(temp_dir, exist_ok=True) + + # Handle both file path and bytes input + if isinstance(data, bytes): + pdf_path = os.path.join(temp_dir, "temp.pdf") + async with aiofiles.open(pdf_path, "wb") as f: + await f.write(data) + else: + pdf_path = data + + # Convert PDF to images + image_paths = await self.convert_pdf_to_images(pdf_path, temp_dir) + # Create tasks for all pages + tasks = { + asyncio.create_task( + self.process_page(image_path, page_num) + ): page_num + for page_num, image_path in enumerate(image_paths, 1) + } + + if maintain_order: + # Store results in order + pending = set(tasks.keys()) + results = {} + next_page = 1 + + while pending: + # Get next completed task + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + + # Process 
completed tasks + for task in done: + result = await task + page_num = int(result["page"]) + results[page_num] = result + + # Yield results in order + while next_page in results: + yield results.pop(next_page)["content"] + next_page += 1 + else: + # Yield results as they complete + for coro in asyncio.as_completed(tasks.keys()): + result = await coro + yield result["content"] + + except Exception as e: + logger.error(f"Error processing PDF: {str(e)}") + raise + + finally: + # Cleanup temporary files + if temp_dir and os.path.exists(temp_dir): + for file in os.listdir(temp_dir): + os.remove(os.path.join(temp_dir, file)) + os.rmdir(temp_dir) + + +class BasicPDFParser(AsyncParser[DataType]): """A parser for PDF data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config try: from pypdf import PdfReader @@ -65,54 +268,16 @@ async def ingest( yield page_text -class PDFParserSix(AsyncParser[DataType]): - """A parser for PDF data.""" - - def __init__(self): - try: - from pdfminer.high_level import extract_text_to_fp - from pdfminer.layout import LAParams - - self.extract_text_to_fp = extract_text_to_fp - self.LAParams = LAParams - except ImportError: - raise ValueError( - "Error, `pdfminer.six` is required to run `PDFParser`. Please install it using `pip install pdfminer.six`." 
- ) - - async def ingest(self, data: bytes, **kwargs) -> AsyncGenerator[str, None]: - """Ingest PDF data and yield text from each page.""" - if not isinstance(data, bytes): - raise ValueError("PDF data must be in bytes format.") - - pdf_file = BytesIO(data) - - async def process_page(page_number): - output = BytesIO() - await asyncio.to_thread( - self.extract_text_to_fp, - pdf_file, - output, - page_numbers=[page_number], - laparams=self.LAParams(), - ) - page_text = output.getvalue().decode("utf-8") - return "".join(filter(lambda x: x in string.printable, page_text)) - - from pdfminer.pdfdocument import PDFDocument - from pdfminer.pdfparser import PDFParser as pdfminer_PDFParser - - parser = pdfminer_PDFParser(pdf_file) - document = PDFDocument(parser) - - for page_number in range(len(list(document.get_pages()))): - page_text = await process_page(page_number) - if page_text: - yield page_text - - class PDFParserUnstructured(AsyncParser[DataType]): - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config try: from unstructured.partition.pdf import partition_pdf @@ -141,79 +306,3 @@ async def ingest( ) for element in elements: yield element.text - - -class PDFParserMarker(AsyncParser[DataType]): - model_refs = None - - def __init__(self): - try: - from marker.convert import convert_single_pdf - from marker.models import load_all_models - - self.convert_single_pdf = convert_single_pdf - if PDFParserMarker.model_refs is None: - PDFParserMarker.model_refs = load_all_models() - - except ImportError as e: - raise ValueError( - f"Error, marker is not installed {e}, please install using `pip install marker-pdf` " - ) - - async def ingest( - self, data: DataType, **kwargs - ) -> AsyncGenerator[str, None]: - if isinstance(data, str): - raise ValueError("PDF data must be in bytes 
format.") - - text, _, _ = self.convert_single_pdf( - BytesIO(data), PDFParserMarker.model_refs - ) - yield text - - -class ZeroxPDFParser(AsyncParser[DataType]): - """An advanced PDF parser using zerox.""" - - def __init__(self): - """ - Use the zerox library to parse PDF data. - - Args: - cleanup (bool, optional): Whether to clean up temporary files after processing. Defaults to True. - concurrency (int, optional): The number of concurrent processes to run. Defaults to 10. - file_data (Optional[str], optional): The file data to process. Defaults to an empty string. - maintain_format (bool, optional): Whether to maintain the format from the previous page. Defaults to False. - model (str, optional): The model to use for generating completions. Defaults to "gpt-4o-mini". Refer to LiteLLM Providers for the correct model name, as it may differ depending on the provider. - temp_dir (str, optional): The directory to store temporary files, defaults to some named folder in system's temp directory. If already exists, the contents will be deleted before zerox uses it. - custom_system_prompt (str, optional): The system prompt to use for the model, this overrides the default system prompt of zerox.Generally it is not required unless you want some specific behaviour. When set, it will raise a friendly warning. Defaults to None. - kwargs (dict, optional): Additional keyword arguments to pass to the litellm.completion method. Refer to the LiteLLM Documentation and Completion Input for details. 
- - """ - try: - # from pyzerox import zerox - from .pyzerox import zerox - - self.zerox = zerox - - except ImportError as e: - raise ValueError( - f"Error, zerox installation failed with Error='{e}', please install through the R2R ingestion bundle with `pip install r2r -E ingestion-bundle` " - ) - - async def ingest( - self, data: DataType, **kwargs - ) -> AsyncGenerator[str, None]: - if isinstance(data, str): - raise ValueError("PDF data must be in bytes format.") - - model = kwargs.get("zerox_parsing_model", ZEROX_DEFAULT_MODEL) - model = model.split("/")[-1] # remove the provider prefix - result = await self.zerox( - file_data=data, - model=model, - verbose=True, - ) - - for page in result.pages: - yield page.content diff --git a/py/core/parsers/media/ppt_parser.py b/py/core/parsers/media/ppt_parser.py index 5f19a0171..6fa8f52e9 100644 --- a/py/core/parsers/media/ppt_parser.py +++ b/py/core/parsers/media/ppt_parser.py @@ -3,12 +3,25 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class PPTParser(AsyncParser[DataType]): """A parser for PPT data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config try: from pptx import Presentation diff --git a/py/core/parsers/media/pyzerox/__init__.py b/py/core/parsers/media/pyzerox/__init__.py deleted file mode 100644 index 18cd95ac3..000000000 --- a/py/core/parsers/media/pyzerox/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .constants.prompts import Prompts -from .zerox_core import zerox - -DEFAULT_SYSTEM_PROMPT = Prompts.DEFAULT_SYSTEM_PROMPT - -__all__ = [ - "zerox", - "Prompts", - "DEFAULT_SYSTEM_PROMPT", -] diff --git 
a/py/core/parsers/media/pyzerox/constants/__init__.py b/py/core/parsers/media/pyzerox/constants/__init__.py deleted file mode 100644 index 4378b38e0..000000000 --- a/py/core/parsers/media/pyzerox/constants/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .conversion import PDFConversionDefaultOptions -from .messages import Messages -from .prompts import Prompts - -__all__ = [ - "PDFConversionDefaultOptions", - "Messages", - "Prompts", -] diff --git a/py/core/parsers/media/pyzerox/constants/conversion.py b/py/core/parsers/media/pyzerox/constants/conversion.py deleted file mode 100644 index 4320e3484..000000000 --- a/py/core/parsers/media/pyzerox/constants/conversion.py +++ /dev/null @@ -1,8 +0,0 @@ -class PDFConversionDefaultOptions: - """Default options for converting PDFs to images""" - - DPI = 300 - FORMAT = "png" - SIZE = (None, 1056) - THREAD_COUNT = 4 - USE_PDFTOCAIRO = True diff --git a/py/core/parsers/media/pyzerox/constants/messages.py b/py/core/parsers/media/pyzerox/constants/messages.py deleted file mode 100644 index ffa3f68ec..000000000 --- a/py/core/parsers/media/pyzerox/constants/messages.py +++ /dev/null @@ -1,56 +0,0 @@ -class Messages: - """User-facing messages""" - - MISSING_ENVIRONMENT_VARIABLES = """ - Required environment variable (keys) from the model are Missing. Please set the required environment variables for the model provider. - Refer: https://docs.litellm.ai/docs/providers - """ - - NON_VISION_MODEL = """ - The provided model is not a vision model. Please provide a vision model. - """ - - MODEL_ACCESS_ERROR = """ - Your provided model can't be accessed. Please make sure you have access to the model and also required environment variables are setup correctly including valid api key(s). - Refer: https://docs.litellm.ai/docs/providers - """ - - CUSTOM_SYSTEM_PROMPT_WARNING = """ - Custom system prompt was provided which overrides the default system prompt. We assume that you know what you are doing. 
- """ - - MAINTAIN_FORMAT_SELECTED_PAGES_WARNING = """ - The maintain_format flag is set to True in conjunction with select_pages input given. This may result in unexpected behavior. - """ - - PAGE_NUMBER_OUT_OF_BOUND_ERROR = """ - The page number(s) provided is out of bound. Please provide a valid page number(s). - """ - - NON_200_RESPONSE = """ - Model API returned status code {status_code}: {data} - - Please check the litellm documentation for more information. https://docs.litellm.ai/docs/exception_mapping. - """ - - COMPLETION_ERROR = """ - Error in Completion Response. Error: {0} - Please check the status of your model provider API status. - """ - - PDF_CONVERSION_FAILED = """ - Error during PDF conversion: {0} - Please check the PDF file and try again. For more information: https://github.com/Belval/pdf2image - """ - - FILE_UNREACHAGBLE = """ - File not found or unreachable. Status Code: {0} - """ - - FILE_PATH_MISSING = """ - File path is invalid or missing. - """ - - FAILED_TO_SAVE_FILE = """Failed to save file to local drive""" - - FAILED_TO_PROCESS_IMAGE = """Failed to process image""" diff --git a/py/core/parsers/media/pyzerox/constants/patterns.py b/py/core/parsers/media/pyzerox/constants/patterns.py deleted file mode 100644 index 6be1a77e1..000000000 --- a/py/core/parsers/media/pyzerox/constants/patterns.py +++ /dev/null @@ -1,6 +0,0 @@ -class Patterns: - """Regex patterns for markdown and code blocks""" - - MATCH_MARKDOWN_BLOCKS = r"^```[a-z]*\n([\s\S]*?)\n```$" - - MATCH_CODE_BLOCKS = r"^```\n([\s\S]*?)\n```$" diff --git a/py/core/parsers/media/pyzerox/constants/prompts.py b/py/core/parsers/media/pyzerox/constants/prompts.py deleted file mode 100644 index a59680a37..000000000 --- a/py/core/parsers/media/pyzerox/constants/prompts.py +++ /dev/null @@ -1,8 +0,0 @@ -class Prompts: - """Class for storing prompts for the Zerox system.""" - - DEFAULT_SYSTEM_PROMPT = """ - Convert the following PDF page to markdown. 
- Return only the markdown with no explanation text. - Do not exclude any content from the page. - """ diff --git a/py/core/parsers/media/pyzerox/errors/__init__.py b/py/core/parsers/media/pyzerox/errors/__init__.py deleted file mode 100644 index 7fa7bedd3..000000000 --- a/py/core/parsers/media/pyzerox/errors/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .exceptions import ( - FailedToProcessFile, - FailedToSaveFile, - FileUnavailable, - MissingEnvironmentVariables, - ModelAccessError, - NotAVisionModel, - PageNumberOutOfBoundError, - ResourceUnreachableException, -) - -__all__ = [ - "NotAVisionModel", - "ModelAccessError", - "PageNumberOutOfBoundError", - "MissingEnvironmentVariables", - "ResourceUnreachableException", - "FileUnavailable", - "FailedToSaveFile", - "FailedToProcessFile", -] diff --git a/py/core/parsers/media/pyzerox/errors/base.py b/py/core/parsers/media/pyzerox/errors/base.py deleted file mode 100644 index f1e761141..000000000 --- a/py/core/parsers/media/pyzerox/errors/base.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional - - -class CustomException(Exception): - """ - Base class for custom exceptions - """ - - def __init__( - self, - message: Optional[str] = None, - extra_info: Optional[dict] = None, - ): - self.message = message - self.extra_info = extra_info - super().__init__(self.message) - - def __str__(self): - if self.extra_info: - return f"{self.message} (Extra Info: {self.extra_info})" - return self.message diff --git a/py/core/parsers/media/pyzerox/errors/exceptions.py b/py/core/parsers/media/pyzerox/errors/exceptions.py deleted file mode 100644 index ee90873d6..000000000 --- a/py/core/parsers/media/pyzerox/errors/exceptions.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Dict, Optional - -# Package Imports -from ..constants import Messages -from .base import CustomException - - -class MissingEnvironmentVariables(CustomException): - """Exception raised when the model provider environment variables, API key(s) are 
missing. Refer: https://docs.litellm.ai/docs/providers""" - - def __init__( - self, - message: str = Messages.MISSING_ENVIRONMENT_VARIABLES, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class NotAVisionModel(CustomException): - """Exception raised when the provided model is not a vision model.""" - - def __init__( - self, - message: str = Messages.NON_VISION_MODEL, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class ModelAccessError(CustomException): - """Exception raised when the provided model can't be accessed due to incorrect credentials/keys or incorrect environent variables setup.""" - - def __init__( - self, - message: str = Messages.MODEL_ACCESS_ERROR, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class PageNumberOutOfBoundError(CustomException): - """Exception invalid page number(s) provided.""" - - def __init__( - self, - message: str = Messages.PAGE_NUMBER_OUT_OF_BOUND_ERROR, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class ResourceUnreachableException(CustomException): - """Exception raised when a resource is unreachable.""" - - def __init__( - self, - message: str = Messages.FILE_UNREACHAGBLE, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class FileUnavailable(CustomException): - """Exception raised when a file is unavailable.""" - - def __init__( - self, - message: str = Messages.FILE_PATH_MISSING, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class FailedToSaveFile(CustomException): - """Exception raised when a file fails to save.""" - - def __init__( - self, - message: str = Messages.FAILED_TO_SAVE_FILE, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) - - -class FailedToProcessFile(CustomException): - """Exception raised when a file fails to process.""" - - def __init__( 
- self, - message: str = Messages.FAILED_TO_PROCESS_IMAGE, - extra_info: Optional[Dict] = None, - ): - super().__init__(message, extra_info) diff --git a/py/core/parsers/media/pyzerox/processor/__init__.py b/py/core/parsers/media/pyzerox/processor/__init__.py deleted file mode 100644 index 1124805e8..000000000 --- a/py/core/parsers/media/pyzerox/processor/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .image import encode_image_to_base64, save_image -from .pdf import convert_pdf_to_images, process_page, process_pages_in_batches -from .text import format_markdown -from .utils import download_file - -__all__ = [ - "save_image", - "encode_image_to_base64", - "convert_pdf_to_images", - "format_markdown", - "download_file", - "process_page", - "process_pages_in_batches", -] diff --git a/py/core/parsers/media/pyzerox/processor/image.py b/py/core/parsers/media/pyzerox/processor/image.py deleted file mode 100644 index 8ad973f4f..000000000 --- a/py/core/parsers/media/pyzerox/processor/image.py +++ /dev/null @@ -1,27 +0,0 @@ -import base64 -import io - -import aiofiles - - -async def encode_image_to_base64(image_path: str) -> str: - """Encode an image to base64 asynchronously.""" - async with aiofiles.open(image_path, "rb") as image_file: - image_data = await image_file.read() - return base64.b64encode(image_data).decode("utf-8") - - -async def save_image(image, image_path: str): - """Save an image to a file asynchronously.""" - # Convert PIL Image to BytesIO object - with io.BytesIO() as buffer: - image.save( - buffer, format=image.format - ) # Save the image to the BytesIO object - image_data = ( - buffer.getvalue() - ) # Get the image data from the BytesIO object - - # Write image data to file asynchronously - async with aiofiles.open(image_path, "wb") as f: - await f.write(image_data) diff --git a/py/core/parsers/media/pyzerox/processor/pdf.py b/py/core/parsers/media/pyzerox/processor/pdf.py deleted file mode 100644 index 5bc874382..000000000 --- 
a/py/core/parsers/media/pyzerox/processor/pdf.py +++ /dev/null @@ -1,115 +0,0 @@ -import asyncio -import logging -import os -from typing import TYPE_CHECKING, List, Optional, Tuple - -from pdf2image import convert_from_path - -from ..constants import Messages, PDFConversionDefaultOptions - -if TYPE_CHECKING: - from ..zerox_models import litellmmodel - -# Package Imports -from .image import save_image -from .text import format_markdown - - -async def convert_pdf_to_images(local_path: str, temp_dir: str) -> List[str]: - """Converts a PDF file to a series of images in the temp_dir. Returns a list of image paths in page order.""" - options = { - "pdf_path": local_path, - "output_folder": temp_dir, - "dpi": PDFConversionDefaultOptions.DPI, - "fmt": PDFConversionDefaultOptions.FORMAT, - "size": PDFConversionDefaultOptions.SIZE, - "thread_count": PDFConversionDefaultOptions.THREAD_COUNT, - "use_pdftocairo": PDFConversionDefaultOptions.USE_PDFTOCAIRO, - "paths_only": True, - } - - try: - image_paths = await asyncio.to_thread(convert_from_path, **options) - return image_paths - except Exception as err: - logging.error(f"Error converting PDF to images: {err}") - - -async def process_page( - image: str, - model: "litellmmodel", - temp_directory: str = "", - input_token_count: int = 0, - output_token_count: int = 0, - prior_page: str = "", - semaphore: Optional[asyncio.Semaphore] = None, -) -> Tuple[str, int, int, str]: - """Process a single page of a PDF""" - - # If semaphore is provided, acquire it before processing the page - if semaphore: - async with semaphore: - return await process_page( - image, - model, - temp_directory, - input_token_count, - output_token_count, - prior_page, - ) - - image_path = os.path.join(temp_directory, image) - - # Get the completion from LiteLLM - try: - completion = await model.completion( - image_path=image_path, - maintain_format=True, - prior_page=prior_page, - ) - - formatted_markdown = format_markdown(completion.content) - 
input_token_count += completion.input_tokens - output_token_count += completion.output_tokens - prior_page = formatted_markdown - - return ( - formatted_markdown, - input_token_count, - output_token_count, - prior_page, - ) - - except Exception as error: - logging.error(f"{Messages.FAILED_TO_PROCESS_IMAGE} Error:{error}") - return "", input_token_count, output_token_count, "" - - -async def process_pages_in_batches( - images: List[str], - concurrency: int, - model: "litellmmodel", - temp_directory: str = "", - input_token_count: int = 0, - output_token_count: int = 0, - prior_page: str = "", -): - # Create a semaphore to limit the number of concurrent tasks - semaphore = asyncio.Semaphore(concurrency) - - # Process each page in parallel - tasks = [ - process_page( - image, - model, - temp_directory, - input_token_count, - output_token_count, - prior_page, - semaphore, - ) - for image in images - ] - - # Wait for all tasks to complete - return await asyncio.gather(*tasks) diff --git a/py/core/parsers/media/pyzerox/processor/text.py b/py/core/parsers/media/pyzerox/processor/text.py deleted file mode 100644 index 524033e6e..000000000 --- a/py/core/parsers/media/pyzerox/processor/text.py +++ /dev/null @@ -1,14 +0,0 @@ -import re - -# Package imports -from ..constants.patterns import Patterns - - -def format_markdown(text: str) -> str: - """Format markdown text by removing markdown and code blocks""" - - formatted_markdown = re.sub(Patterns.MATCH_MARKDOWN_BLOCKS, r"\1", text) - formatted_markdown = re.sub( - Patterns.MATCH_CODE_BLOCKS, r"\1", formatted_markdown - ) - return formatted_markdown diff --git a/py/core/parsers/media/pyzerox/processor/utils.py b/py/core/parsers/media/pyzerox/processor/utils.py deleted file mode 100644 index da703240d..000000000 --- a/py/core/parsers/media/pyzerox/processor/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -import re -from typing import Iterable, Optional, Union -from urllib.parse import urlparse - -import aiofiles -import 
aiohttp -from PyPDF2 import PdfReader, PdfWriter - -from ..constants.messages import Messages - -# Package Imports -from ..errors.exceptions import ( - PageNumberOutOfBoundError, - ResourceUnreachableException, -) - - -async def download_file( - file_path: str, - temp_dir: str, -) -> Optional[str]: - """Downloads a file from a URL or local path to a temporary directory.""" - - local_pdf_path = os.path.join(temp_dir, os.path.basename(file_path)) - if is_valid_url(file_path): - async with aiohttp.ClientSession() as session: - async with session.get(file_path) as response: - if response.status != 200: - raise ResourceUnreachableException() - async with aiofiles.open(local_pdf_path, "wb") as f: - await f.write(await response.read()) - else: - async with ( - aiofiles.open(file_path, "rb") as src, - aiofiles.open(local_pdf_path, "wb") as dst, - ): - await dst.write(await src.read()) - return local_pdf_path - - -def is_valid_url(string: str) -> bool: - """Checks if a string is a valid URL.""" - - try: - result = urlparse(string) - return all([result.scheme, result.netloc]) and result.scheme in [ - "http", - "https", - ] - except ValueError: - return False diff --git a/py/core/parsers/media/pyzerox/zerox_core/__init__.py b/py/core/parsers/media/pyzerox/zerox_core/__init__.py deleted file mode 100644 index 825ed3f77..000000000 --- a/py/core/parsers/media/pyzerox/zerox_core/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .zerox import zerox - -__all__ = [ - "zerox", -] diff --git a/py/core/parsers/media/pyzerox/zerox_core/types.py b/py/core/parsers/media/pyzerox/zerox_core/types.py deleted file mode 100644 index 8474a5524..000000000 --- a/py/core/parsers/media/pyzerox/zerox_core/types.py +++ /dev/null @@ -1,42 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Union - - -@dataclass -class ZeroxArgs: - """ - Dataclass to store the arguments for the Zerox class. 
- """ - - file_path: str - cleanup: bool = True - concurrency: int = 10 - maintain_format: bool = False - model: str = ("gpt-4o-mini",) - output_dir: Optional[str] = None - temp_dir: Optional[str] = None - custom_system_prompt: Optional[str] = None - kwargs: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class Page: - """ - Dataclass to store the page content. - """ - - content: str - content_length: int - page: int - - -@dataclass -class ZeroxOutput: - """ - Dataclass to store the output of the Zerox class. - """ - - completion_time: float - input_tokens: int - output_tokens: int - pages: List[Page] diff --git a/py/core/parsers/media/pyzerox/zerox_core/zerox.py b/py/core/parsers/media/pyzerox/zerox_core/zerox.py deleted file mode 100644 index b89a20e54..000000000 --- a/py/core/parsers/media/pyzerox/zerox_core/zerox.py +++ /dev/null @@ -1,151 +0,0 @@ -import asyncio -import os -import tempfile -import warnings -from datetime import datetime -from typing import Iterable, List, Optional, Union - -import aiofiles -import aiofiles.os as async_os -import aioshutil as async_shutil - -from ..constants.messages import Messages -from ..errors import FileUnavailable - -# Package Imports -from ..processor import ( - convert_pdf_to_images, - process_page, - process_pages_in_batches, -) -from ..zerox_models import litellmmodel -from .types import Page, ZeroxOutput - - -async def zerox( - cleanup: bool = True, - concurrency: int = 10, - file_data: Optional[bytes] = None, - maintain_format: bool = False, - model: str = "gpt-4o-mini", - temp_dir: Optional[str] = None, - custom_system_prompt: Optional[str] = None, - **kwargs, -) -> ZeroxOutput: - """ - API to perform OCR to markdown using Vision models. - Please setup the environment variables for the model and model provider before using this API. 
Refer: https://docs.litellm.ai/docs/providers - - :param cleanup: Whether to cleanup the temporary files after processing, defaults to True - :type cleanup: bool, optional - :param concurrency: The number of concurrent processes to run, defaults to 10 - :type concurrency: int, optional - :param file_path: The path or URL to the PDF file to process. - :type file_path: str, optional - :param maintain_format: Whether to maintain the format from the previous page, defaults to False - :type maintain_format: bool, optional - :param model: The model to use for generating completions, defaults to "gpt-4o-mini". Note - Refer: https://docs.litellm.ai/docs/providers to pass correct model name as according to provider it might be different from actual name. - :type model: str, optional - :param temp_dir: The directory to store temporary files, defaults to some named folder in system's temp directory. If already exists, the contents will be deleted for zerox uses it. - :type temp_dir: str, optional - :param custom_system_prompt: The system prompt to use for the model, this overrides the default system prompt of zerox. Generally it is not required unless you want some specific behaviour. When set, it will raise a friendly warning, defaults to None - :type custom_system_prompt: str, optional - - :param kwargs: Additional keyword arguments to pass to the model.completion -> litellm.completion method. Refer: https://docs.litellm.ai/docs/providers and https://docs.litellm.ai/docs/completion/input - :return: The markdown content generated by the model. 
- """ - - input_token_count = 0 - output_token_count = 0 - prior_page = "" - aggregated_markdown: List[str] = [] - start_time = datetime.now() - # File Data Validators - if not file_data: - raise FileUnavailable() - - # Create an instance of the litellm model interface - vision_model = litellmmodel(model=model, **kwargs) - - # override the system prompt if a custom prompt is provided - if custom_system_prompt: - vision_model.system_prompt = custom_system_prompt - - if temp_dir: - if os.path.exists(temp_dir): - await async_shutil.rmtree(temp_dir) - await async_os.makedirs(temp_dir, exist_ok=True) - - # Create a temporary directory to store the PDF and images - with tempfile.TemporaryDirectory() as temp_dir_: - - if temp_dir: - ## use the user provided temp directory - temp_directory = temp_dir - else: - ## use the system temp directory - temp_directory = temp_dir_ - - local_path = os.path.join(temp_directory, "input.pdf") - async with aiofiles.open(local_path, "wb") as f: - await f.write(file_data) - - # Convert the file to a series of images, below function returns a list of image paths in page order - images = await convert_pdf_to_images( - local_path=local_path, temp_dir=temp_directory - ) - - if maintain_format: - for image in images: - result, input_token_count, output_token_count, prior_page = ( - await process_page( - image, - vision_model, - temp_directory, - input_token_count, - output_token_count, - prior_page, - ) - ) - - if result: - aggregated_markdown.append(result) - else: - results = await process_pages_in_batches( - images, - concurrency, - vision_model, - temp_directory, - input_token_count, - output_token_count, - prior_page, - ) - - aggregated_markdown = [ - result[0] for result in results if isinstance(result[0], str) - ] - - ## add token usage - input_token_count += sum([result[1] for result in results]) - output_token_count += sum([result[2] for result in results]) - - # Cleanup the downloaded PDF file - if cleanup and 
os.path.exists(temp_directory): - await async_shutil.rmtree(temp_directory) - - # Format JSON response - end_time = datetime.now() - completion_time = (end_time - start_time).total_seconds() * 1000 - - # Default behavior when no is provided - formatted_pages = [ - Page(content=content, page=i + 1, content_length=len(content)) - for i, content in enumerate(aggregated_markdown) - ] - - return ZeroxOutput( - completion_time=completion_time, - input_tokens=input_token_count, - output_tokens=output_token_count, - pages=formatted_pages, - ) diff --git a/py/core/parsers/media/pyzerox/zerox_models/__init__.py b/py/core/parsers/media/pyzerox/zerox_models/__init__.py deleted file mode 100644 index f19d77392..000000000 --- a/py/core/parsers/media/pyzerox/zerox_models/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .modellitellm import litellmmodel -from .types import CompletionResponse - -__all__ = [ - "litellmmodel", - "CompletionResponse", -] diff --git a/py/core/parsers/media/pyzerox/zerox_models/base.py b/py/core/parsers/media/pyzerox/zerox_models/base.py deleted file mode 100644 index 4e85dc344..000000000 --- a/py/core/parsers/media/pyzerox/zerox_models/base.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar - -if TYPE_CHECKING: - from ..zerox_models import CompletionResponse - -T = TypeVar("T", bound="BaseModel") - - -class BaseModel(ABC): - """ - Base class for all models. 
- """ - - @abstractmethod - async def completion( - self, - ) -> "CompletionResponse": - raise NotImplementedError("Subclasses must implement this method") - - @abstractmethod - def validate_access( - self, - ) -> None: - raise NotImplementedError("Subclasses must implement this method") - - @abstractmethod - def validate_model( - self, - ) -> None: - raise NotImplementedError("Subclasses must implement this method") - - def __init__( - self, - model: Optional[str] = None, - **kwargs, - ): - self.model = model - self.kwargs = kwargs - - ## validations - # self.validate_model() - # self.validate_access() diff --git a/py/core/parsers/media/pyzerox/zerox_models/modellitellm.py b/py/core/parsers/media/pyzerox/zerox_models/modellitellm.py deleted file mode 100644 index 02c5fe792..000000000 --- a/py/core/parsers/media/pyzerox/zerox_models/modellitellm.py +++ /dev/null @@ -1,169 +0,0 @@ -import os -import warnings -from typing import Any, Dict, List, Optional - -import aiohttp -import litellm - -from ..constants.messages import Messages -from ..constants.prompts import Prompts -from ..errors import ( - MissingEnvironmentVariables, - ModelAccessError, - NotAVisionModel, -) -from ..processor.image import encode_image_to_base64 - -# Package Imports -from .base import BaseModel -from .types import CompletionResponse - -DEFAULT_SYSTEM_PROMPT = Prompts.DEFAULT_SYSTEM_PROMPT - - -class litellmmodel(BaseModel): - ## setting the default system prompt - _system_prompt = DEFAULT_SYSTEM_PROMPT - - def __init__( - self, - model: Optional[str] = None, - **kwargs, - ): - """ - Initializes the Litellm model interface. - :param model: The model to use for generating completions, defaults to "gpt-4o-mini". Refer: https://docs.litellm.ai/docs/providers - :type model: str, optional - - :param kwargs: Additional keyword arguments to pass to self.completion -> litellm.completion. 
Refer: https://docs.litellm.ai/docs/providers and https://docs.litellm.ai/docs/completion/input - """ - super().__init__(model=model, **kwargs) - - ## calling custom methods to validate the environment and model - self.validate_environment() - self.validate_model() - self.validate_access() - - @property - def system_prompt(self) -> str: - """Returns the system prompt for the model.""" - return self._system_prompt - - @system_prompt.setter - def system_prompt(self, prompt: str) -> None: - """ - Sets/overrides the system prompt for the model. - Will raise a friendly warning to notify the user. - """ - warnings.warn( - f"{Messages.CUSTOM_SYSTEM_PROMPT_WARNING}. Default prompt for zerox is:\n {DEFAULT_SYSTEM_PROMPT}" - ) - self._system_prompt = prompt - - ## custom method on top of BaseModel - def validate_environment(self) -> None: - """Validates the environment variables required for the model.""" - env_config = litellm.validate_environment(model=self.model) - - if not env_config["keys_in_environment"]: - raise MissingEnvironmentVariables(extra_info=env_config) - - def validate_model(self) -> None: - """Validates the model to ensure it is a vision model.""" - if not litellm.supports_vision(model=self.model): - raise NotAVisionModel(extra_info={"model": self.model}) - - def validate_access(self) -> None: - """Validates access to the model -> if environment variables are set correctly with correct values.""" - if not litellm.check_valid_key(model=self.model, api_key=None): - raise ModelAccessError(extra_info={"model": self.model}) - - async def completion( - self, - image_path: str, - maintain_format: bool, - prior_page: str, - ) -> CompletionResponse: - """LitellM completion for image to markdown conversion. - - :param image_path: Path to the image file. - :type image_path: str - :param maintain_format: Whether to maintain the format from the previous page. - :type maintain_format: bool - :param prior_page: The markdown content of the previous page. 
- :type prior_page: str - - :return: The markdown content generated by the model. - """ - messages = await self._prepare_messages( - image_path=image_path, - maintain_format=maintain_format, - prior_page=prior_page, - ) - - try: - response = await litellm.acompletion( - model=self.model, messages=messages, **self.kwargs - ) - - ## completion response - response = CompletionResponse( - content=response["choices"][0]["message"]["content"], - input_tokens=response["usage"]["prompt_tokens"], - output_tokens=response["usage"]["completion_tokens"], - ) - return response - - except Exception as err: - raise Exception(Messages.COMPLETION_ERROR.format(err)) - - async def _prepare_messages( - self, - image_path: str, - maintain_format: bool, - prior_page: str, - ) -> List[Dict[str, Any]]: - """Prepares the messages to send to the LiteLLM Completion API. - - :param image_path: Path to the image file. - :type image_path: str - :param maintain_format: Whether to maintain the format from the previous page. - :type maintain_format: bool - :param prior_page: The markdown content of the previous page. - :type prior_page: str - """ - # Default system message - messages: List[Dict[str, Any]] = [ - { - "role": "system", - "content": self._system_prompt, - }, - ] - - # If content has already been generated, add it to context. - # This helps maintain the same format across pages. 
- if maintain_format and prior_page: - messages.append( - { - "role": "system", - "content": f'Markdown must maintain consistent formatting with the following page: \n\n """{prior_page}"""', - }, - ) - - # Add Image to request - base64_image = await encode_image_to_base64(image_path) - messages.append( - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{base64_image}" - }, - }, - ], - } - ) - - return messages diff --git a/py/core/parsers/media/pyzerox/zerox_models/types.py b/py/core/parsers/media/pyzerox/zerox_models/types.py deleted file mode 100644 index 0eea3e2ee..000000000 --- a/py/core/parsers/media/pyzerox/zerox_models/types.py +++ /dev/null @@ -1,12 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class CompletionResponse: - """ - A class representing the response of a completion. - """ - - content: str - input_tokens: int - output_tokens: int diff --git a/py/core/parsers/structured/csv_parser.py b/py/core/parsers/structured/csv_parser.py index ab1e55e0d..c8418f5a1 100644 --- a/py/core/parsers/structured/csv_parser.py +++ b/py/core/parsers/structured/csv_parser.py @@ -3,12 +3,26 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class CSVParser(AsyncParser[DataType]): """A parser for CSV data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + import csv from io import StringIO @@ -29,7 +43,12 @@ async def ingest( class CSVParserAdvanced(AsyncParser[DataType]): """A parser for CSV data.""" - def __init__(self): + def __init__( + self, config: IngestionConfig, llm_provider: CompletionProvider + ): + self.llm_provider = llm_provider + 
self.config = config + import csv from io import StringIO diff --git a/py/core/parsers/structured/json_parser.py b/py/core/parsers/structured/json_parser.py index aedb2482c..1efe29c78 100644 --- a/py/core/parsers/structured/json_parser.py +++ b/py/core/parsers/structured/json_parser.py @@ -5,11 +5,26 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class JSONParser(AsyncParser[DataType]): """A parser for JSON data.""" + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + async def ingest( self, data: DataType, *args, **kwargs ) -> AsyncGenerator[str, None]: diff --git a/py/core/parsers/structured/xlsx_parser.py b/py/core/parsers/structured/xlsx_parser.py index 5237439ea..e06a22d73 100644 --- a/py/core/parsers/structured/xlsx_parser.py +++ b/py/core/parsers/structured/xlsx_parser.py @@ -4,12 +4,25 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class XLSXParser(AsyncParser[DataType]): """A parser for XLSX data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config try: from openpyxl import load_workbook @@ -36,7 +49,11 @@ class XLSXParserAdvanced(AsyncParser[DataType]): """A parser for XLSX data.""" # identifies connected components in the excel graph and extracts data from each component - def __init__(self): + def __init__( + self, config: IngestionConfig, llm_provider: 
CompletionProvider + ): + self.llm_provider = llm_provider + self.config = config try: import networkx as nx import numpy as np diff --git a/py/core/parsers/text/html_parser.py b/py/core/parsers/text/html_parser.py index c2e893120..6f3e146c6 100644 --- a/py/core/parsers/text/html_parser.py +++ b/py/core/parsers/text/html_parser.py @@ -5,11 +5,26 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class HTMLParser(AsyncParser[DataType]): """A parser for HTML data.""" + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + async def ingest( self, data: DataType, *args, **kwargs ) -> AsyncGenerator[str, None]: diff --git a/py/core/parsers/text/md_parser.py b/py/core/parsers/text/md_parser.py index 725ae5724..2a181fbf9 100644 --- a/py/core/parsers/text/md_parser.py +++ b/py/core/parsers/text/md_parser.py @@ -5,12 +5,26 @@ from core.base.abstractions import DataType from core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class MDParser(AsyncParser[DataType]): """A parser for Markdown data.""" - def __init__(self): + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + import markdown self.markdown = markdown diff --git a/py/core/parsers/text/text_parser.py b/py/core/parsers/text/text_parser.py index da72ea85f..791f0783c 100644 --- a/py/core/parsers/text/text_parser.py +++ b/py/core/parsers/text/text_parser.py @@ -3,11 +3,26 @@ from core.base.abstractions import DataType from 
core.base.parsers.base_parser import AsyncParser +from core.base.providers import ( + CompletionProvider, + DatabaseProvider, + IngestionConfig, +) class TextParser(AsyncParser[DataType]): """A parser for raw text data.""" + def __init__( + self, + config: IngestionConfig, + database_provider: DatabaseProvider, + llm_provider: CompletionProvider, + ): + self.database_provider = database_provider + self.llm_provider = llm_provider + self.config = config + async def ingest( self, data: DataType, *args, **kwargs ) -> AsyncGenerator[DataType, None]: diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index 5babfa508..3c2fb5642 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -306,7 +306,7 @@ async def request_password_reset(self, email: str) -> Dict[str, str]: ) # TODO: Integrate with email provider to send reset link - await self.email_provider.send_reset_email(email, reset_token) + await self.email_provider.send_password_reset_email(email, reset_token) return {"message": "If the email exists, a reset link has been sent"} diff --git a/py/core/providers/database/prompts/vision_img.yaml b/py/core/providers/database/prompts/vision_img.yaml new file mode 100644 index 000000000..4a1aa4777 --- /dev/null +++ b/py/core/providers/database/prompts/vision_img.yaml @@ -0,0 +1,4 @@ +vision_img: + template: > + First, provide a title for the image, then explain everything that you see. Be very thorough in your analysis as a user will need to understand the image without seeing it. If it is possible to transcribe the image to text directly, then do so. The more detail you provide, the better the user will understand the image. 
+ input_types: {} diff --git a/py/core/providers/database/prompts/vision_pdf.yaml b/py/core/providers/database/prompts/vision_pdf.yaml new file mode 100644 index 000000000..350ead2d9 --- /dev/null +++ b/py/core/providers/database/prompts/vision_pdf.yaml @@ -0,0 +1,42 @@ +vision_pdf: + template: > + Convert this PDF page to markdown format, preserving all content and formatting. Follow these guidelines: + + Text: + - Maintain the original text hierarchy (headings, paragraphs, lists) + - Preserve any special formatting (bold, italic, underline) + - Include all footnotes, citations, and references + - Keep text in its original reading order + + Tables: + - Recreate tables using markdown table syntax + - Preserve all headers, rows, and columns + - Maintain alignment and formatting where possible + - Include any table captions or notes + + Equations: + - Convert mathematical equations using LaTeX notation + - Preserve equation numbers if present + - Include any surrounding context or references + + Images: + - Enclose image descriptions within [FIG] and [/FIG] tags + - Include detailed descriptions of: + * Main subject matter + * Text overlays or captions + * Charts, graphs, or diagrams + * Relevant colors, patterns, or visual elements + - Maintain image placement relative to surrounding text + + Additional Elements: + - Include page numbers if visible + - Preserve headers and footers + - Maintain sidebars or callout boxes + - Keep any special symbols or characters + + Quality Requirements: + - Ensure 100% content preservation + - Maintain logical document flow + - Verify all markdown syntax is valid + - Double-check completeness before submitting + input_types: {} diff --git a/py/core/providers/email/smtp.py b/py/core/providers/email/smtp.py index fecbe0008..da8bf67cd 100644 --- a/py/core/providers/email/smtp.py +++ b/py/core/providers/email/smtp.py @@ -35,7 +35,9 @@ def __init__(self, config: EmailConfig): if not self.smtp_password: raise ValueError("SMTP password is 
required") - self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL") + self.from_email: Optional[str] = config.from_email or os.getenv( + "R2R_FROM_EMAIL" + ) if not self.from_email: raise ValueError("From email is required") @@ -52,8 +54,8 @@ async def send_email( html_body: Optional[str] = None, ) -> None: msg = MIMEMultipart("alternative") - msg["Subject"] = subject - msg["From"] = self.from_email + msg["Subject"] = subject # type: ignore + msg["From"] = self.from_email # type: ignore msg["To"] = to_email msg.attach(MIMEText(body, "plain")) @@ -63,7 +65,7 @@ async def send_email( try: smtp = SMTP( hostname=self.smtp_server, - port=self.smtp_port, + port=int(self.smtp_port) if self.smtp_port else None, use_tls=self.use_tls, ) @@ -84,9 +86,9 @@ async def send_verification_email( subject = "Verify Your Email Address" body = f""" Thank you for registering! Please verify your email address by entering the following code: - + {verification_code} - + This code will expire in 24 hours. """ html_body = f""" @@ -103,9 +105,9 @@ async def send_password_reset_email( subject = "Password Reset Request" body = f""" We received a request to reset your password. Use the following code to reset your password: - + {reset_token} - + This code will expire in 1 hour. If you didn't request this reset, please ignore this email. 
""" html_body = f""" diff --git a/py/core/providers/ingestion/r2r/base.py b/py/core/providers/ingestion/r2r/base.py index 2632644dc..5334ad3b0 100644 --- a/py/core/providers/ingestion/r2r/base.py +++ b/py/core/providers/ingestion/r2r/base.py @@ -19,6 +19,9 @@ from core.base.abstractions import DocumentExtraction from core.utils import generate_extraction_id +from ...database import PostgresDBProvider +from ...llm import LiteLLMCompletionProvider, OpenAICompletionProvider + logger = logging.getLogger() @@ -38,7 +41,7 @@ class R2RIngestionProvider(IngestionProvider): DocumentType.HTM: parsers.HTMLParser, DocumentType.JSON: parsers.JSONParser, DocumentType.MD: parsers.MDParser, - DocumentType.PDF: parsers.PDFParser, + DocumentType.PDF: parsers.VLMPDFParser, DocumentType.PPTX: parsers.PPTParser, DocumentType.TXT: parsers.TextParser, DocumentType.XLSX: parsers.XLSXParser, @@ -47,6 +50,8 @@ class R2RIngestionProvider(IngestionProvider): DocumentType.JPG: parsers.ImageParser, DocumentType.PNG: parsers.ImageParser, DocumentType.SVG: parsers.ImageParser, + DocumentType.WEBP: parsers.ImageParser, + DocumentType.ICO: parsers.ImageParser, DocumentType.MP3: parsers.AudioParser, } @@ -54,23 +59,25 @@ class R2RIngestionProvider(IngestionProvider): DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced}, DocumentType.PDF: { "unstructured": parsers.PDFParserUnstructured, - "zerox": parsers.ZeroxPDFParser, - "marker": parsers.PDFParserMarker, + "basic": parsers.BasicPDFParser, }, DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced}, } - IMAGE_TYPES = { - DocumentType.GIF, - DocumentType.JPG, - DocumentType.JPEG, - DocumentType.PNG, - DocumentType.SVG, - } - - def __init__(self, config: R2RIngestionConfig): - super().__init__(config) + def __init__( + self, + config: R2RIngestionConfig, + database_provider: PostgresDBProvider, + llm_provider: Union[ + LiteLLMCompletionProvider, OpenAICompletionProvider + ], + ): + super().__init__(config, database_provider, llm_provider) 
self.config: R2RIngestionConfig = config # for type hinting + self.database_provider: PostgresDBProvider = database_provider + self.llm_provider: Union[ + LiteLLMCompletionProvider, OpenAICompletionProvider + ] = llm_provider self.parsers: dict[DocumentType, AsyncParser] = {} self.text_splitter = self._build_text_splitter() self._initialize_parsers() @@ -83,10 +90,18 @@ def _initialize_parsers(self): for doc_type, parser in self.DEFAULT_PARSERS.items(): # will choose the first parser in the list if doc_type not in self.config.excluded_parsers: - self.parsers[doc_type] = parser() + self.parsers[doc_type] = parser( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, + ) for doc_type, doc_parser_name in self.config.extra_parsers.items(): - self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = ( - R2RIngestionProvider.EXTRA_PARSERS[doc_type][doc_parser_name]() + self.parsers[ + f"{doc_parser_name}_{str(doc_type)}" + ] = R2RIngestionProvider.EXTRA_PARSERS[doc_type][doc_parser_name]( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, ) def _build_text_splitter( diff --git a/py/core/providers/ingestion/unstructured/base.py b/py/core/providers/ingestion/unstructured/base.py index a1d57af9e..e296782be 100644 --- a/py/core/providers/ingestion/unstructured/base.py +++ b/py/core/providers/ingestion/unstructured/base.py @@ -6,7 +6,7 @@ import time from copy import copy from io import BytesIO -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator, Optional, Union import httpx from unstructured_client import UnstructuredClient @@ -25,6 +25,9 @@ from core.base.providers.ingestion import IngestionConfig, IngestionProvider from core.utils import generate_extraction_id +from ...database import PostgresDBProvider +from ...llm import LiteLLMCompletionProvider, OpenAICompletionProvider + logger = logging.getLogger() @@ -83,6 +86,7 @@ class 
UnstructuredIngestionProvider(IngestionProvider): DocumentType.JPG: [parsers.ImageParser], DocumentType.PNG: [parsers.ImageParser], DocumentType.SVG: [parsers.ImageParser], + DocumentType.PDF: [parsers.VLMPDFParser], DocumentType.MP3: [parsers.AudioParser], DocumentType.JSON: [parsers.JSONParser], # type: ignore DocumentType.HTML: [parsers.HTMLParser], # type: ignore @@ -92,24 +96,27 @@ class UnstructuredIngestionProvider(IngestionProvider): EXTRA_PARSERS = { DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced}, # type: ignore DocumentType.PDF: { - "unstructured": parsers.PDFParserUnstructured, - "zerox": parsers.ZeroxPDFParser, - "marker": parsers.PDFParserMarker, + "basic": parsers.BasicPDFParser, }, DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced}, # type: ignore } - IMAGE_TYPES = { - DocumentType.GIF, - DocumentType.JPG, - DocumentType.JPEG, - DocumentType.PNG, - DocumentType.SVG, - } - - def __init__(self, config: UnstructuredIngestionConfig): - super().__init__(config) + def __init__( + self, + config: UnstructuredIngestionConfig, + database_provider: PostgresDBProvider, + llm_provider: Union[ + LiteLLMCompletionProvider, OpenAICompletionProvider + ], + ): + + super().__init__(config, database_provider, llm_provider) self.config: UnstructuredIngestionConfig = config + self.database_provider: PostgresDBProvider = database_provider + self.llm_provider: Union[ + LiteLLMCompletionProvider, OpenAICompletionProvider + ] = llm_provider + if config.provider == "unstructured_api": try: self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"] @@ -142,25 +149,33 @@ def __init__(self, config: UnstructuredIngestionConfig): self.client = httpx.AsyncClient() - super().__init__(config) + super().__init__(config, database_provider, llm_provider) self.parsers: dict[DocumentType, AsyncParser] = {} self._initialize_parsers() def _initialize_parsers(self): - for doc_type, parser_infos in self.R2R_FALLBACK_PARSERS.items(): - for parser_info in parser_infos: 
+ for doc_type, parsers in self.R2R_FALLBACK_PARSERS.items(): + for parser in parsers: if ( doc_type not in self.config.excluded_parsers and doc_type not in self.parsers ): # will choose the first parser in the list - self.parsers[doc_type] = parser_info() + self.parsers[doc_type] = parser( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, + ) # TODO - Reduce code duplication between Unstructured & R2R for doc_type, doc_parser_name in self.config.extra_parsers.items(): - self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = ( - UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][ - doc_parser_name - ]() + self.parsers[ + f"{doc_parser_name}_{str(doc_type)}" + ] = UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][ + doc_parser_name + ]( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, ) async def parse_fallback( @@ -213,9 +228,25 @@ async def parse( ) elements = [] + # allow user to re-override places where unstructured is overriden above + # e.g. 
+ # "ingestion_config": { + # ..., + # "parser_overrides": { + # "pdf": "unstructured" + # } + # } + reoverride_with_unst = ( + parser_overrides.get(document.document_type.value, None) + == "unstructured" + ) + # TODO - Cleanup this approach to be less hardcoded # TODO - Remove code duplication between Unstructured & R2R - if document.document_type.value in parser_overrides: + if ( + document.document_type.value in parser_overrides + and not reoverride_with_unst + ): logger.info( f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}" ) @@ -226,7 +257,10 @@ async def parse( ): elements.append(element) - elif document.document_type in self.R2R_FALLBACK_PARSERS.keys(): + elif ( + document.document_type in self.R2R_FALLBACK_PARSERS.keys() + and not reoverride_with_unst + ): logger.info( f"Parsing {document.document_type}: {document.id} with fallback parser" ) diff --git a/py/poetry.lock b/py/poetry.lock index e441531f2..50fc2e603 100644 --- a/py/poetry.lock +++ b/py/poetry.lock @@ -5321,6 +5321,17 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "types-aiofiles" +version = "24.1.0.20240626" +description = "Typing stubs for aiofiles" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-aiofiles-24.1.0.20240626.tar.gz", hash = "sha256:48604663e24bc2d5038eac05ccc33e75799b0779e93e13d6a8f711ddc306ac08"}, + {file = "types_aiofiles-24.1.0.20240626-py3-none-any.whl", hash = "sha256:7939eca4a8b4f9c6491b6e8ef160caee9a21d32e18534a57d5ed90aee47c66b4"}, +] + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -5858,4 +5869,4 @@ ingestion-bundle = ["aiofiles", "aioshutil", "beautifulsoup4", "bs4", "markdown" [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "076bf72cff07b22d020e62cbe5e477865157be0deba19f817b1b2a41787625df" +content-hash = 
"fb41515396b9a34291521c668a4d9b889406c781731a00cf6b06ef2e6347b28a" diff --git a/py/pyproject.toml b/py/pyproject.toml index e56393c3e..d8d1866e0 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -81,6 +81,7 @@ pypdf2 = { version = "^3.0.1", optional = true } python-pptx = { version = "^1.0.1", optional = true } python-docx = { version = "^1.1.0", optional = true } aiosmtplib = "^3.0.2" +types-aiofiles = "^24.1.0.20240626" [tool.poetry.extras] core = [ diff --git a/py/r2r.toml b/py/r2r.toml index 5ece375f4..6cb95f422 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -88,6 +88,12 @@ chunk_size = 1_024 chunk_overlap = 512 excluded_parsers = ["mp4"] +audio_transcription_model="openai/whisper-1" +vision_img_model = "gpt-4o-mini" +vision_pdf_model = "gpt-4o-mini" +# vision_img_prompt_name = "vision_img" # optional, default is "vision_img" +# vision_pdf_prompt_name = "vision_pdf" # optional, default is "vision_pdf" + [ingestion.chunk_enrichment_settings] enable_chunk_enrichment = false # disabled by default strategies = ["semantic", "neighborhood"] @@ -97,9 +103,6 @@ excluded_parsers = ["mp4"] semantic_similarity_threshold = 0.7 generation_config = { model = "openai/gpt-4o-mini" } - [ingestion.extra_parsers] - pdf = "zerox" - [logging] provider = "r2r" log_table = "logs" diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 04629ffad..7c4daee3b 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -49,6 +49,8 @@ class DocumentType(str, Enum): TIFF = "tiff" JPG = "jpg" SVG = "svg" + WEBP = "webp" + ICO = "ico" # Markdown MD = "md" From 2b71f74007f6a753126822c635f057e798ee7729 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 30 Oct 2024 09:37:35 -0700 Subject: [PATCH 07/12] cleanup; pre-commit --- .../r2r-full-py-integration-tests.yml | 5 - .../r2r-light-py-integration-tests.yml | 5 - .gitignore | 2 + .../2024-10-28/10-02-34/.hydra/config.yaml | 35 ---- 
outputs/2024-10-28/10-02-34/.hydra/hydra.yaml | 159 ------------------ .../2024-10-28/10-02-34/.hydra/overrides.yaml | 1 - outputs/2024-10-28/10-02-34/main.log | 2 - py/core/base/providers/ingestion.py | 8 - py/core/main/api/management_router.py | 1 - .../hatchet/ingestion_workflow.py | 13 +- py/core/main/services/ingestion_service.py | 6 +- py/r2r.toml | 1 + 12 files changed, 16 insertions(+), 222 deletions(-) delete mode 100644 outputs/2024-10-28/10-02-34/.hydra/config.yaml delete mode 100644 outputs/2024-10-28/10-02-34/.hydra/hydra.yaml delete mode 100644 outputs/2024-10-28/10-02-34/.hydra/overrides.yaml delete mode 100644 outputs/2024-10-28/10-02-34/main.log diff --git a/.github/workflows/r2r-full-py-integration-tests.yml b/.github/workflows/r2r-full-py-integration-tests.yml index bbd349607..55dd6ac7e 100644 --- a/.github/workflows/r2r-full-py-integration-tests.yml +++ b/.github/workflows/r2r-full-py-integration-tests.yml @@ -58,27 +58,22 @@ jobs: - name: Run CLI Ingestion Tests if: matrix.test_category == 'cli-ingestion' uses: ./.github/actions/run-cli-ingestion-tests - continue-on-error: true - name: Run CLI Retrieval Tests if: matrix.test_category == 'cli-retrieval' uses: ./.github/actions/run-cli-retrieval-tests - continue-on-error: true - name: Run SDK Ingestion Tests if: matrix.test_category == 'sdk-ingestion' uses: ./.github/actions/run-sdk-ingestion-tests - continue-on-error: true - name: Run SDK Retrieval Tests if: matrix.test_category == 'sdk-retrieval' uses: ./.github/actions/run-sdk-retrieval-tests - continue-on-error: true - name: Run SDK Auth Tests if: matrix.test_category == 'sdk-auth' uses: ./.github/actions/run-sdk-auth-tests - continue-on-error: true - name: Run SDK Collections Tests if: matrix.test_category == 'sdk-collections' diff --git a/.github/workflows/r2r-light-py-integration-tests.yml b/.github/workflows/r2r-light-py-integration-tests.yml index 01da102f4..eb7d7b04b 100644 --- a/.github/workflows/r2r-light-py-integration-tests.yml +++ 
b/.github/workflows/r2r-light-py-integration-tests.yml @@ -61,27 +61,22 @@ jobs: - name: Run CLI Ingestion Tests if: matrix.test_category == 'cli-ingestion' uses: ./.github/actions/run-cli-ingestion-tests - continue-on-error: true - name: Run CLI Retrieval Tests if: matrix.test_category == 'cli-retrieval' uses: ./.github/actions/run-cli-retrieval-tests - continue-on-error: true - name: Run SDK Ingestion Tests if: matrix.test_category == 'sdk-ingestion' uses: ./.github/actions/run-sdk-ingestion-tests - continue-on-error: true - name: Run SDK Retrieval Tests if: matrix.test_category == 'sdk-retrieval' uses: ./.github/actions/run-sdk-retrieval-tests - continue-on-error: true - name: Run SDK Auth Tests if: matrix.test_category == 'sdk-auth' uses: ./.github/actions/run-sdk-auth-tests - continue-on-error: true - name: Run SDK Collections Tests if: matrix.test_category == 'sdk-collections' diff --git a/.gitignore b/.gitignore index b80511bfd..7fafd2767 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ dist/ *.test go.work go.work.sum + +.vscode/ diff --git a/outputs/2024-10-28/10-02-34/.hydra/config.yaml b/outputs/2024-10-28/10-02-34/.hydra/config.yaml deleted file mode 100644 index 52d5c6a13..000000000 --- a/outputs/2024-10-28/10-02-34/.hydra/config.yaml +++ /dev/null @@ -1,35 +0,0 @@ -name: vector_rag -step: ingest -run_settings: - use_graph: false - use_vector: true - use_hybrid: false - contextual_enrichment: false - deduplication: false -server: - type: r2r - settings: - port: 7272 - use_docker: true - config_path: /Users/shreyas/new7/R2R/py/workspace/evals/configs/server/r2r/r2r_full.toml - stdout_path: r2r_stdout.log - stderr_path: r2r_stderr.log - project_name: ${name} - contextual_enrichment: ${run_settings.contextual_enrichment} - r2r_image: r2r-test -ingestion_dataset: - id: 671c1dddea3bd8151d1ceb3c -eval_dataset: - id: 671c1dddea3bd8151d1ceb3c -evaluator: - base_url: https://api.relari.com/v1 - api_key: ek-e581232b34d3d6c3a7af8079b50ab67a - 
project_id: 671854e7db1ad90003b5c04b - eval_dataset: '{eval_dataset}' -metrics: - metrics_list: - - answer_relevancy - - context_relevancy - - faithfulness - - answer_similarity - - latency diff --git a/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml b/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml deleted file mode 100644 index fef701667..000000000 --- a/outputs/2024-10-28/10-02-34/.hydra/hydra.yaml +++ /dev/null @@ -1,159 +0,0 @@ -hydra: - run: - dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} - sweep: - dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} - subdir: ${hydra.job.num} - launcher: - _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher - sweeper: - _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper - max_batch_size: null - params: null - help: - app_name: ${hydra.job.name} - header: '${hydra.help.app_name} is powered by Hydra. - - ' - footer: 'Powered by Hydra (https://hydra.cc) - - Use --hydra-help to view Hydra specific help - - ' - template: '${hydra.help.header} - - == Configuration groups == - - Compose your configuration from those groups (group=option) - - - $APP_CONFIG_GROUPS - - - == Config == - - Override anything in the config (foo.bar=value) - - - $CONFIG - - - ${hydra.help.footer} - - ' - hydra_help: - template: 'Hydra (${hydra.runtime.version}) - - See https://hydra.cc for more info. - - - == Flags == - - $FLAGS_HELP - - - == Configuration groups == - - Compose your configuration from those groups (For example, append hydra/job_logging=disabled - to command line) - - - $HYDRA_CONFIG_GROUPS - - - Use ''--cfg hydra'' to Show the Hydra config. - - ' - hydra_help: ??? 
- hydra_logging: - version: 1 - formatters: - simple: - format: '[%(asctime)s][HYDRA] %(message)s' - handlers: - console: - class: logging.StreamHandler - formatter: simple - stream: ext://sys.stdout - root: - level: INFO - handlers: - - console - loggers: - logging_example: - level: DEBUG - disable_existing_loggers: false - job_logging: - version: 1 - formatters: - simple: - format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' - handlers: - console: - class: logging.StreamHandler - formatter: simple - stream: ext://sys.stdout - file: - class: logging.FileHandler - formatter: simple - filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log - root: - level: INFO - handlers: - - console - - file - disable_existing_loggers: false - env: {} - mode: RUN - searchpath: [] - callbacks: {} - output_subdir: .hydra - overrides: - hydra: - - hydra.mode=RUN - task: [] - job: - name: main - chdir: true - override_dirname: '' - id: ??? - num: ??? - config_name: config.yaml - env_set: {} - env_copy: [] - config: - override_dirname: - kv_sep: '=' - item_sep: ',' - exclude_keys: [] - runtime: - version: 1.3.2 - version_base: '1.3' - cwd: /Users/shreyas/new7/R2R - config_sources: - - path: hydra.conf - schema: pkg - provider: hydra - - path: /Users/shreyas/new7/R2R/py/workspace/evals/configs - schema: file - provider: main - - path: '' - schema: structured - provider: schema - output_dir: /Users/shreyas/new7/R2R/outputs/2024-10-28/10-02-34 - choices: - metrics: llm - evaluator: relari - eval_dataset: aristotle - ingestion_dataset: aristotle - server: r2r - hydra/env: default - hydra/callbacks: null - hydra/job_logging: default - hydra/hydra_logging: default - hydra/hydra_help: default - hydra/help: default - hydra/sweeper: basic - hydra/launcher: basic - hydra/output: default - verbose: false diff --git a/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml b/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml deleted file mode 100644 index fe51488c7..000000000 --- 
a/outputs/2024-10-28/10-02-34/.hydra/overrides.yaml +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/outputs/2024-10-28/10-02-34/main.log b/outputs/2024-10-28/10-02-34/main.log deleted file mode 100644 index 8ee85a065..000000000 --- a/outputs/2024-10-28/10-02-34/main.log +++ /dev/null @@ -1,2 +0,0 @@ -[2024-10-28 10:02:37,858][httpx][INFO] - HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK" -[2024-10-28 10:02:38,987][root][INFO] - Posthog telemetry disabled, debug mode off diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index ddc3c5e41..72b687f78 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -19,14 +19,6 @@ class IngestionConfig(ProviderConfig): ChunkEnrichmentSettings() ) - audio_transcription_model: str - - vision_img_prompt_name: Optional[str] = None - vision_img_model: str - - vision_pdf_prompt_name: Optional[str] = None - vision_pdf_model: str - audio_transcription_model: str = "openai/whisper-1" vision_img_prompt_name: str = "vision_img" diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index f4ee8cd5e..12c13a288 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -888,7 +888,6 @@ async def delete_conversation( await self.service.delete_conversation(conversation_id) return None # type: ignore - @self.router.get("/r2r_project_name") @self.base_endpoint async def r2r_project_name( diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 49cf12b0d..aa5ff5ab8 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -15,7 +15,10 @@ increment_version, ) from core.base.abstractions import DocumentInfo, R2RException -from core.utils import 
generate_default_user_collection_id, update_settings_from_dict +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) from ...services import IngestionService, IngestionServiceAdapter @@ -249,10 +252,10 @@ async def on_failure(self, context: Context) -> None: document_info = documents_overview[0] # Update the document status to FAILED - if ( - document_info.ingestion_status - not in [IngestionStatus.SUCCESS, IngestionStatus.ENRICHED] - ): + if document_info.ingestion_status not in [ + IngestionStatus.SUCCESS, + IngestionStatus.ENRICHED, + ]: await self.ingestion_service.update_document_status( document_info, status=IngestionStatus.FAILED, diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 356a9b9b6..71f38b121 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -531,7 +531,11 @@ async def _get_enriched_chunk_text( metadata=chunk["metadata"], ) - async def chunk_enrichment(self, document_id: UUID, chunk_enrichment_settings: ChunkEnrichmentSettings) -> int: + async def chunk_enrichment( + self, + document_id: UUID, + chunk_enrichment_settings: ChunkEnrichmentSettings, + ) -> int: # just call the pipe on every chunk of the document # get all document_chunks diff --git a/py/r2r.toml b/py/r2r.toml index 1677bd4bb..7ef4b5bea 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -84,6 +84,7 @@ batch_size = 128 add_title_as_prefix = false rerank_model = "None" concurrent_request_limit = 256 +quantization_settings = { quantization_type = "FP32" } [file] provider = "postgres" From a57dbbb93508a4da277c00507abc964ba763afdb Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 30 Oct 2024 09:39:12 -0700 Subject: [PATCH 08/12] rm launch json --- .vscode/launch.json | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json 
deleted file mode 100644 index 7d10ad3be..000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python: main.py with RAG", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/py/workspace/evals/main.py", - "console": "integratedTerminal", - "args": ["step=rag"], - "python": "/Users/shreyas/Library/Caches/pypoetry/virtualenvs/r2r-R0EY5voQ-py3.12/bin/python" - }, - { - "name": "Python Debugger: Current File with Arguments", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/py/workspace/evals/main.py", - "console": "integratedTerminal", - "args": "${command:pickArgs}", - "python": "/Users/shreyas/Library/Caches/pypoetry/virtualenvs/r2r-R0EY5voQ-py3.12/bin/python" - } - ] -} From 34b9158a9ed955a013acbaa75cfd6dffd60cf949 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 30 Oct 2024 10:48:37 -0700 Subject: [PATCH 09/12] checkin work --- py/core/base/providers/database.py | 5 +- py/core/base/providers/ingestion.py | 2 + .../main/orchestration/hatchet/kg_workflow.py | 45 +++++++------- py/core/main/services/ingestion_service.py | 12 ++-- py/core/providers/database/collection.py | 5 +- py/core/providers/database/document.py | 60 ++++++++++--------- py/shared/abstractions/document.py | 33 ++++++++++ py/shared/api/models/management/responses.py | 1 + 8 files changed, 107 insertions(+), 56 deletions(-) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index a035ea639..569dbc054 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -1130,7 +1130,10 @@ async def get_workflow_status( return await self.document_handler.get_workflow_status(id, status_type) async def 
set_workflow_status( - self, id: Union[UUID, list[UUID]], status_type: str, status: str + self, + id: Union[UUID, list[UUID]], + status_type: str, + status: str, ): return await self.document_handler.set_workflow_status( id, status_type, status diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index 72b687f78..7826c8c84 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -19,6 +19,8 @@ class IngestionConfig(ProviderConfig): ChunkEnrichmentSettings() ) + extra_parsers: dict[str, str] = {} + audio_transcription_model: str = "openai/whisper-1" vision_img_prompt_name: str = "vision_img" diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 51f81632f..15abc4241 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -9,7 +9,7 @@ from core import GenerationConfig from core.base import OrchestrationProvider -from core.base.abstractions import KGExtractionStatus, KGEnrichmentStatus +from core.base.abstractions import KGEnrichmentStatus, KGExtractionStatus from ...services import KgService @@ -291,29 +291,12 @@ async def kg_entity_deduplication_setup( key=f"{i}/{total_workflows}_entity_deduplication_part", ) ) - result = await asyncio.gather(*workflows) - # set status to success - await self.kg_service.providers.database.set_workflow_status( - id=collection_id, - status_type="kg_enrichment_status", - status=KGEnrichmentStatus.SUCCESS, - ) + await asyncio.gather(*workflows) return { "result": f"successfully queued kg entity deduplication for collection {collection_id} with {number_of_distinct_entities} distinct entities" } - @orchestration_provider.failure() - async def on_failure(self, context: Context) -> None: - collection_id = context.workflow_input()["request"][ - "collection_id" - ] - await self.kg_service.providers.database.set_workflow_status( - 
id=collection_id, - status_type="kg_enrichment_status", - status=KGEnrichmentStatus.FAILED, - ) - @orchestration_provider.workflow( name="kg-entity-deduplication-summary", timeout="360m" ) @@ -417,11 +400,31 @@ async def kg_community_summary(self, context: Context) -> dict: key=f"{i}/{total_workflows}_community_summary", ) ) - await asyncio.gather(*workflows) + + results = await asyncio.gather(*workflows) + + # set status to success + await self.kg_service.providers.database.set_workflow_status( + id=collection_id, + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.SUCCESS, + ) + return { - "result": f"Successfully spawned summary workflows for {num_communities} communities." + "result": f"Successfully completed enrichment for collection {collection_id} in {len(results)} workflows." } + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + collection_id = context.workflow_input()["request"][ + "collection_id" + ] + await self.kg_service.providers.database.set_workflow_status( + id=collection_id, + status_type="kg_enrichment_status", + status=KGEnrichmentStatus.FAILED, + ) + @orchestration_provider.workflow( name="kg-community-summary", timeout="360m" ) diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 71f38b121..54832a785 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -457,7 +457,9 @@ async def _get_enriched_chunk_text( for neighbor in semantic_neighbors ) - context_chunk_ids = list( + # weird behavior, sometimes we get UUIDs + # FIXME: figure out why + context_chunk_ids_str = list( set( [ str(context_chunk_id) @@ -465,8 +467,10 @@ async def _get_enriched_chunk_text( ] ) ) - context_chunk_ids = [ - UUID(context_chunk_id) for context_chunk_id in context_chunk_ids + + context_chunk_ids_uuid = [ + UUID(context_chunk_id) + for context_chunk_id in context_chunk_ids_str ] context_chunk_texts = [ @@ 
-476,7 +480,7 @@ async def _get_enriched_chunk_text( "chunk_order" ], ) - for context_chunk_id in context_chunk_ids + for context_chunk_id in context_chunk_ids_uuid ] # sort by chunk_order diff --git a/py/core/providers/database/collection.py b/py/core/providers/database/collection.py index 7dff5d8b6..1e6cf65f4 100644 --- a/py/core/providers/database/collection.py +++ b/py/core/providers/database/collection.py @@ -352,14 +352,14 @@ async def get_collections_overview( """Get an overview of collections, optionally filtered by collection IDs, with pagination.""" query = f""" WITH collection_overview AS ( - SELECT g.collection_id, g.name, g.description, g.created_at, g.updated_at, + SELECT g.collection_id, g.name, g.description, g.created_at, g.updated_at, g.kg_enrichment_status, COUNT(DISTINCT u.user_id) AS user_count, COUNT(DISTINCT d.document_id) AS document_count FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} g LEFT JOIN {self._get_table_name('users')} u ON g.collection_id = ANY(u.collection_ids) LEFT JOIN {self._get_table_name('document_info')} d ON g.collection_id = ANY(d.collection_ids) {' WHERE g.collection_id = ANY($1)' if collection_ids else ''} - GROUP BY g.collection_id, g.name, g.description, g.created_at, g.updated_at + GROUP BY g.collection_id, g.name, g.description, g.created_at, g.updated_at, g.kg_enrichment_status ), counted_overview AS ( SELECT *, COUNT(*) OVER() AS total_entries @@ -393,6 +393,7 @@ async def get_collections_overview( updated_at=row["updated_at"], user_count=row["user_count"], document_count=row["document_count"], + kg_enrichment_status=row["kg_enrichment_status"], ) for row in results ] diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index e586200bc..3074fad69 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -181,7 +181,11 @@ async def delete_from_documents_overview( await self.connection_manager.execute_query(query, 
params) async def _get_status_from_table( - self, ids: list[UUID], table_name: str, status_type: str + self, + ids: list[UUID], + table_name: str, + status_type: str, + column_name: str, ): """ Get the workflow status for a given document or list of documents. @@ -196,7 +200,7 @@ async def _get_status_from_table( """ query = f""" SELECT {status_type} FROM {self._get_table_name(table_name)} - WHERE document_id = ANY($1) + WHERE {column_name} = ANY($1) """ return await self.connection_manager.fetch_query(query, [ids]) @@ -226,7 +230,12 @@ async def _get_ids_from_table( return document_ids async def _set_status_in_table( - self, ids: list[UUID], status: str, table_name: str, status_type: str + self, + ids: list[UUID], + status: str, + table_name: str, + status_type: str, + column_name: str, ): """ Set the workflow status for a given document or list of documents. @@ -236,33 +245,31 @@ async def _set_status_in_table( status (str): The status to set. table_name (str): The table name. status_type (str): The type of status to set. + column_name (str): The column name in the table to update. """ query = f""" UPDATE {self._get_table_name(table_name)} SET {status_type} = $1 - WHERE document_id = Any($2) + WHERE {column_name} = Any($2) """ await self.connection_manager.execute_query(query, [status, ids]) - def _get_status_model_and_table_name(self, status_type: str): + def _get_status_model(self, status_type: str): """ - Get the status model and table name for a given status type. + Get the status model for a given status type. Args: status_type (str): The type of status to retrieve. Returns: - The status model and table name for the given status type. + The status model for the given status type. 
""" if status_type == "ingestion": - return IngestionStatus, "document_info" + return IngestionStatus elif status_type == "kg_extraction_status": - return KGExtractionStatus, "document_info" + return KGExtractionStatus elif status_type == "kg_enrichment_status": - return ( - KGEnrichmentStatus, - "collections", - ) # TODO: Rename to collection info? + return KGEnrichmentStatus else: raise R2RException( status_code=400, message=f"Invalid status type: {status_type}" @@ -282,15 +289,11 @@ async def get_workflow_status( The workflow status for the given document or list of documents. """ ids = [id] if isinstance(id, UUID) else id - out_model, table_name = self._get_status_model_and_table_name( - status_type - ) + out_model = self._get_status_model(status_type) result = list( map( - ( - await self._get_status_from_table( - ids, table_name, status_type - ) + await self._get_status_from_table( + ids, out_model.table_name, out_model.value, out_model.id_column ), out_model, ) @@ -309,11 +312,10 @@ async def set_workflow_status( status (str): The status to set. 
""" ids = [id] if isinstance(id, UUID) else id - out_model, table_name = self._get_status_model_and_table_name( - status_type - ) + out_model = self._get_status_model(status_type) + return await self._set_status_in_table( - ids, status, table_name, status_type + ids, status, out_model.table_name, out_model.value, out_model.id_column ) async def get_document_ids_by_status( @@ -334,11 +336,13 @@ async def get_document_ids_by_status( if isinstance(status, str): status = [status] - out_model, table_name = self._get_status_model_and_table_name( - status_type - ) + out_model = self._get_status_model(status_type) + print(out_model) + print(out_model.table_name) + print(out_model.value) + print(collection_id) result = await self._get_ids_from_table( - status, table_name, status_type, collection_id + status, out_model.table_name, str(out_model), collection_id ) return result diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 7c4daee3b..8a4e015af 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -123,6 +123,19 @@ class IngestionStatus(str, Enum): FAILED = "failed" SUCCESS = "success" + def __str__(self): + return self.value + + @property + def table_name(self) -> str: + """Returns the table name this status applies to.""" + return "document_info" + + @property + def id_column(self) -> str: + """Returns the id column this status applies to.""" + return "document_id" + class KGExtractionStatus(str, Enum): """Status of KG Creation per document.""" @@ -135,6 +148,16 @@ class KGExtractionStatus(str, Enum): def __str__(self): return self.value + @property + def table_name(self) -> str: + """Returns the table name this status applies to.""" + return "document_info" + + @property + def id_column(self) -> str: + """Returns the id column this status applies to.""" + return "document_id" + class KGEnrichmentStatus(str, Enum): """Status of KG Enrichment per collection.""" @@ -147,6 +170,16 @@ class 
KGEnrichmentStatus(str, Enum): def __str__(self): return self.value + @property + def table_name(self) -> str: + """Returns the table name this status applies to.""" + return "collections" + + @property + def id_column(self) -> str: + """Returns the id column this status applies to.""" + return "collection_id" + class DocumentInfo(R2RSerializable): """Base class for document information handling.""" diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index a285af620..9f07d6fc3 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -126,6 +126,7 @@ class CollectionOverviewResponse(BaseModel): updated_at: datetime user_count: int document_count: int + kg_enrichment_status: str class ConversationOverviewResponse(BaseModel): From ce73f2ed32007f005e10bc08dad157ed8fe0f454 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 31 Oct 2024 12:16:31 -0700 Subject: [PATCH 10/12] up --- .../main/orchestration/hatchet/kg_workflow.py | 38 +++++++++++-------- .../main/orchestration/simple/kg_workflow.py | 5 +++ py/core/providers/database/document.py | 17 +++++---- py/shared/abstractions/document.py | 32 +++++++--------- 4 files changed, 51 insertions(+), 41 deletions(-) diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 15abc4241..1e31dc6fa 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -26,6 +26,10 @@ def hatchet_kg_factory( def get_input_data_dict(input_data): for key, value in input_data.items(): + + if key == "collection_id": + input_data[key] = uuid.UUID(value) + if key == "kg_creation_settings": input_data[key] = json.loads(value) input_data[key]["generation_config"] = GenerationConfig( @@ -384,25 +388,29 @@ async def kg_community_summary(self, context: Context) -> dict: for i in range(total_workflows): offset = i 
* parallel_communities workflows.append( - context.aio.spawn_workflow( - "kg-community-summary", - { - "request": { - "offset": offset, - "limit": min( - parallel_communities, - num_communities - offset, - ), - "collection_id": collection_id, - **input_data["kg_enrichment_settings"], - } - }, - key=f"{i}/{total_workflows}_community_summary", - ) + ( + await context.aio.spawn_workflow( + "kg-community-summary", + { + "request": { + "offset": offset, + "limit": min( + parallel_communities, + num_communities - offset, + ), + "collection_id": str(collection_id), + **input_data["kg_enrichment_settings"], + } + }, + key=f"{i}/{total_workflows}_community_summary", + ) + ).result() ) results = await asyncio.gather(*workflows) + logger.info(f"Ran {len(results)} workflows for community summary") + # set status to success await self.kg_service.providers.database.set_workflow_status( id=collection_id, diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index a028de3ef..ff111b448 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -1,6 +1,7 @@ import json import logging import math +import uuid from core import GenerationConfig from core import R2RException @@ -15,6 +16,10 @@ def simple_kg_factory(service: KgService): def get_input_data_dict(input_data): for key, value in input_data.items(): + + if key == "collection_id": + input_data[key] = uuid.UUID(value) + if key == "kg_creation_settings": input_data[key] = json.loads(value) input_data[key]["generation_config"] = GenerationConfig( diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index 3074fad69..db08aaad8 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -293,7 +293,10 @@ async def get_workflow_status( result = list( map( await self._get_status_from_table( - ids, out_model.table_name, out_model.value, 
out_model.id_column + ids, + out_model.table_name(), + status_type, + out_model.id_column(), ), out_model, ) @@ -315,7 +318,11 @@ async def set_workflow_status( out_model = self._get_status_model(status_type) return await self._set_status_in_table( - ids, status, out_model.table_name, out_model.value, out_model.id_column + ids, + status, + out_model.table_name(), + status_type, + out_model.id_column(), ) async def get_document_ids_by_status( @@ -337,12 +344,8 @@ async def get_document_ids_by_status( status = [status] out_model = self._get_status_model(status_type) - print(out_model) - print(out_model.table_name) - print(out_model.value) - print(collection_id) result = await self._get_ids_from_table( - status, out_model.table_name, str(out_model), collection_id + status, out_model.table_name(), status_type, collection_id ) return result diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 8a4e015af..66064d337 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -4,7 +4,7 @@ import logging from datetime import datetime from enum import Enum -from typing import Optional, Union +from typing import Optional, Union, ClassVar from uuid import UUID, uuid4 from pydantic import Field @@ -126,14 +126,12 @@ class IngestionStatus(str, Enum): def __str__(self): return self.value - @property - def table_name(self) -> str: - """Returns the table name this status applies to.""" + @classmethod + def table_name(cls) -> str: return "document_info" - @property - def id_column(self) -> str: - """Returns the id column this status applies to.""" + @classmethod + def id_column(cls) -> str: return "document_id" @@ -148,14 +146,12 @@ class KGExtractionStatus(str, Enum): def __str__(self): return self.value - @property - def table_name(self) -> str: - """Returns the table name this status applies to.""" + @classmethod + def table_name(cls) -> str: return "document_info" - @property - def id_column(self) -> str: - 
"""Returns the id column this status applies to.""" + @classmethod + def id_column(cls) -> str: return "document_id" @@ -170,14 +166,12 @@ class KGEnrichmentStatus(str, Enum): def __str__(self): return self.value - @property - def table_name(self) -> str: - """Returns the table name this status applies to.""" + @classmethod + def table_name(cls) -> str: return "collections" - @property - def id_column(self) -> str: - """Returns the id column this status applies to.""" + @classmethod + def id_column(cls) -> str: return "collection_id" From ef87b47aa79caa1edd092b906074368620aa0b61 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 31 Oct 2024 13:43:52 -0700 Subject: [PATCH 11/12] modify endpoint --- py/core/main/api/management_router.py | 7 ------- py/core/main/services/management_service.py | 6 ++++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index 12c13a288..9fe1d97fe 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -887,10 +887,3 @@ async def delete_conversation( ) -> WrappedDeleteResponse: await self.service.delete_conversation(conversation_id) return None # type: ignore - - @self.router.get("/r2r_project_name") - @self.base_endpoint - async def r2r_project_name( - auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> dict: - return {"project_name": os.environ["R2R_PROJECT_NAME"]} diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 3490fbec0..006d83139 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -4,6 +4,7 @@ from uuid import UUID import toml +import os from core.base import ( AnalysisTypes, @@ -26,6 +27,9 @@ from ..config import R2RConfig from .base import Service +from importlib.metadata import version as get_version + + logger = logging.getLogger() @@ -196,6 +200,8 @@ 
async def app_settings(self, *args: Any, **kwargs: Any): return { "config": config_dict, "prompts": prompts, + "project_name": os.environ["R2R_PROJECT_NAME"], + "r2r_version": get_version("r2r"), } @telemetry_event("UsersOverview") From 24abc6df7173500102a6f14d4696c21dc9b46b56 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 31 Oct 2024 13:52:38 -0700 Subject: [PATCH 12/12] up --- py/core/main/services/management_service.py | 4 ++-- py/shared/api/models/management/responses.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 006d83139..e28712b85 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -200,8 +200,8 @@ async def app_settings(self, *args: Any, **kwargs: Any): return { "config": config_dict, "prompts": prompts, - "project_name": os.environ["R2R_PROJECT_NAME"], - "r2r_version": get_version("r2r"), + "r2r_project_name": os.environ["R2R_PROJECT_NAME"], + # "r2r_version": get_version("r2r"), } @telemetry_event("UsersOverview") diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index 9f07d6fc3..530d07579 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -55,6 +55,8 @@ class AnalyticsResponse(BaseModel): class AppSettingsResponse(BaseModel): config: dict[str, Any] prompts: dict[str, Any] + r2r_project_name: str + # r2r_version: str class ScoreCompletionResponse(BaseModel):