diff --git a/backend/app/rag/embeddings/openai_like_embedding.py b/backend/app/rag/embeddings/openai_like_embedding.py
index 8573fc66..9ba6be17 100644
--- a/backend/app/rag/embeddings/openai_like_embedding.py
+++ b/backend/app/rag/embeddings/openai_like_embedding.py
@@ -34,6 +34,7 @@ def __init__(
             **kwargs,
         )
 
+        self._model_kwargs = kwargs or {}
         self.model = model
         self._client = OpenAI(api_key=api_key, base_url=api_base)
         self._aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
@@ -42,7 +43,7 @@ def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
         """Get embeddings."""
         # Call Zhipu AI Embedding API via OpenAI client
         embedding_objs = self._client.embeddings.create(
-            input=sentences, model=self.model
+            input=sentences, model=self.model, **self._model_kwargs
         ).data
         embeddings = [obj.embedding for obj in embedding_objs]
 
@@ -51,7 +52,7 @@ def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
     async def aget_embeddings(self, sentences: list[str]) -> List[List[float]]:
         """Asynchronously get text embeddings."""
         result = await self._aclient.embeddings.create(
-            input=sentences, model=self.model
+            input=sentences, model=self.model, **self._model_kwargs
         )
         embeddings = [obj.embedding for obj in result.data]
 
diff --git a/backend/app/rag/knowledge_graph/graph_store/tidb_graph_store.py b/backend/app/rag/knowledge_graph/graph_store/tidb_graph_store.py
index 326a7dc7..04c3a678 100644
--- a/backend/app/rag/knowledge_graph/graph_store/tidb_graph_store.py
+++ b/backend/app/rag/knowledge_graph/graph_store/tidb_graph_store.py
@@ -211,6 +211,7 @@ def save(self, chunk_id, entities_df, relationships_df):
                         description=row["description"],
                         metadata=row["meta"],
                     ),
+                    commit=False,
                 )
             )
 
@@ -233,28 +234,36 @@ def _find_or_create_entity_for_relation(
                     description=description,
                     metadata={"status": "need-revised"},
                 ),
+                commit=False,
             )
 
-        for _, row in relationships_df.iterrows():
-            source_entity = _find_or_create_entity_for_relation(
-                row["source_entity"], row["source_entity_description"]
-            )
-            target_entity = _find_or_create_entity_for_relation(
-                row["target_entity"], row["target_entity_description"]
-            )
+        try:
+            for _, row in relationships_df.iterrows():
+                logger.info("save entities for relationship %s -> %s -> %s", row["source_entity"], row["relationship_desc"], row["target_entity"])
+                source_entity = _find_or_create_entity_for_relation(
+                    row["source_entity"], row["source_entity_description"]
+                )
+                target_entity = _find_or_create_entity_for_relation(
+                    row["target_entity"], row["target_entity_description"]
+                )
 
-            self.create_relationship(
-                source_entity,
-                target_entity,
-                Relationship(
-                    source_entity=source_entity.name,
-                    target_entity=target_entity.name,
-                    relationship_desc=row["relationship_desc"],
-                ),
-                relationship_metadata=row["meta"],
-                commit=False,
-            )
-        self._session.commit()
+                self.create_relationship(
+                    source_entity,
+                    target_entity,
+                    Relationship(
+                        source_entity=source_entity.name,
+                        target_entity=target_entity.name,
+                        relationship_desc=row["relationship_desc"],
+                    ),
+                    relationship_metadata=row["meta"],
+                    commit=False,
+                )
+
+            self._session.commit()
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            self._session.rollback()
+            raise e
 
     def create_relationship(
         self,
@@ -283,8 +292,11 @@ def create_relationship(
         self._session.add(relationship_object)
         if commit:
             self._session.commit()
+            self._session.refresh(relationship_object)
+        else:
+            self._session.flush()
 
-    def get_or_create_entity(self, entity: Entity) -> SQLModel:
+    def get_or_create_entity(self, entity: Entity, commit: bool = True) -> SQLModel:
         # using the cosine distance between the description vectors to determine if the entity already exists
         entity_type = (
             EntityType.synopsis
@@ -347,8 +359,11 @@ def get_or_create_entity(self, entity: Entity) -> SQLModel:
                 db_obj.meta_vec = get_entity_metadata_embedding(
                     db_obj.meta, self._embed_model
                 )
-                self._session.commit()
-                self._session.refresh(db_obj)
+                if commit:
+                    self._session.commit()
+                    self._session.refresh(db_obj)
+                else:
+                    self._session.flush()
                 return db_obj
 
         synopsis_info_str = (
@@ -367,8 +382,12 @@ def get_or_create_entity(self, entity: Entity) -> SQLModel:
             entity_type=entity_type,
         )
         self._session.add(db_obj)
-        self._session.commit()
-        self._session.refresh(db_obj)
+        if commit:
+            self._session.commit()
+            self._session.refresh(db_obj)
+        else:
+            self._session.flush()
+
         return db_obj
 
     def _try_merge_entities(self, entities: List[Entity]) -> Entity:
diff --git a/backend/app/tasks/build_index.py b/backend/app/tasks/build_index.py
index df20680d..618187d5 100644
--- a/backend/app/tasks/build_index.py
+++ b/backend/app/tasks/build_index.py
@@ -135,7 +135,7 @@ def build_kg_index_for_chunk(knowledge_base_id: int, chunk_id: UUID):
         with Session(engine) as session:
             error_msg = traceback.format_exc()
             logger.error(
-                f"Failed to build knowledge graph index for chunk #{chunk_id}: {error_msg}"
+                f"Failed to build knowledge graph index for chunk #{chunk_id}", exc_info=True
            )
            db_chunk.index_status = KgIndexStatus.FAILED
            db_chunk.index_result = error_msg
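Reviewer note on the embedding change: the constructor's leftover **kwargs are now kept and splatted into every embeddings.create() call, so OpenAI-compatible servers that need extra request fields (for example dimensions or encoding_format) can be configured per model. Below is a minimal standalone sketch of that pattern using only the openai Python client; the class and parameter names are illustrative, not the repo's API.

from typing import Any, List

from openai import OpenAI


class MiniOpenAILikeEmbedding:
    """Sketch: forward unknown constructor kwargs to every embeddings request."""

    def __init__(self, model: str, api_key: str, api_base: str, **kwargs: Any) -> None:
        self.model = model
        self._model_kwargs = kwargs or {}  # e.g. {"dimensions": 256}
        self._client = OpenAI(api_key=api_key, base_url=api_base)

    def get_embeddings(self, sentences: List[str]) -> List[List[float]]:
        # Extra fields ride along with each call to the OpenAI-compatible endpoint.
        data = self._client.embeddings.create(
            input=sentences, model=self.model, **self._model_kwargs
        ).data
        return [obj.embedding for obj in data]


# Usage (hypothetical values):
# embed = MiniOpenAILikeEmbedding(
#     "text-embedding-3-small", api_key="sk-...", api_base="https://api.openai.com/v1", dimensions=256
# )
# vectors = embed.get_embeddings(["hello", "world"])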
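Reviewer note on the graph-store change: the per-entity and per-relationship commits are replaced by a single transaction per chunk. Callers pass commit=False, the store calls flush() so server-generated keys are assigned without ending the transaction, and save() commits once at the end, rolling back and re-raising on failure. A minimal sketch of that session-handling pattern with SQLModel, assuming an in-memory SQLite engine; the Node table and helper names are illustrative, not the repo's.

import logging
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine

logger = logging.getLogger(__name__)


class Node(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str


def create_node(session: Session, name: str, commit: bool = True) -> Node:
    node = Node(name=name)
    session.add(node)
    if commit:
        session.commit()
        session.refresh(node)  # reload server-generated columns
    else:
        session.flush()  # assign node.id without ending the transaction
    return node


def save_batch(session: Session, names: list[str]) -> None:
    try:
        for name in names:
            create_node(session, name, commit=False)
        session.commit()  # one commit for the whole batch
    except Exception:
        logger.error("failed to save batch", exc_info=True)
        session.rollback()  # leave the session usable after a failure
        raise


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        save_batch(session, ["a", "b", "c"])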