Skip to content

Commit

Permalink
Fix embedding model bug and support ARM platform (#543)
Browse files (browse the repository at this point in the history)
  • Loading branch information
IANTHEREAL authored Dec 25, 2024
1 parent dc6030a commit b7a2022
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 27 deletions.
5 changes: 3 additions & 2 deletions backend/app/rag/embeddings/openai_like_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
**kwargs,
)

self._model_kwargs = kwargs or {}
self.model = model
self._client = OpenAI(api_key=api_key, base_url=api_base)
self._aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
Expand All @@ -42,7 +43,7 @@ def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Get embeddings."""
# Call Zhipu AI Embedding API via OpenAI client
embedding_objs = self._client.embeddings.create(
input=sentences, model=self.model
input=sentences, model=self.model, **self._model_kwargs
).data
embeddings = [obj.embedding for obj in embedding_objs]

Expand All @@ -51,7 +52,7 @@ def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
async def aget_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
result = await self._aclient.embeddings.create(
input=sentences, model=self.model
input=sentences, model=self.model, **self._model_kwargs
)
embeddings = [obj.embedding for obj in result.data]

Expand Down
67 changes: 43 additions & 24 deletions backend/app/rag/knowledge_graph/graph_store/tidb_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def save(self, chunk_id, entities_df, relationships_df):
description=row["description"],
metadata=row["meta"],
),
commit=False,
)
)

Expand All @@ -233,28 +234,36 @@ def _find_or_create_entity_for_relation(
description=description,
metadata={"status": "need-revised"},
),
commit=False,
)

for _, row in relationships_df.iterrows():
source_entity = _find_or_create_entity_for_relation(
row["source_entity"], row["source_entity_description"]
)
target_entity = _find_or_create_entity_for_relation(
row["target_entity"], row["target_entity_description"]
)
try:
for _, row in relationships_df.iterrows():
logger.info("save entities for relationship %s -> %s -> %s", row["source_entity"], row["relationship_desc"], row["target_entity"])
source_entity = _find_or_create_entity_for_relation(
row["source_entity"], row["source_entity_description"]
)
target_entity = _find_or_create_entity_for_relation(
row["target_entity"], row["target_entity_description"]
)

self.create_relationship(
source_entity,
target_entity,
Relationship(
source_entity=source_entity.name,
target_entity=target_entity.name,
relationship_desc=row["relationship_desc"],
),
relationship_metadata=row["meta"],
commit=False,
)
self._session.commit()
self.create_relationship(
source_entity,
target_entity,
Relationship(
source_entity=source_entity.name,
target_entity=target_entity.name,
relationship_desc=row["relationship_desc"],
),
relationship_metadata=row["meta"],
commit=False,
)

self._session.commit()
except Exception as e:
logger.error(e, exc_info=True)
self._session.rollback()
raise e

def create_relationship(
self,
Expand Down Expand Up @@ -283,8 +292,11 @@ def create_relationship(
self._session.add(relationship_object)
if commit:
self._session.commit()
self._session.refresh(relationship_object)
else:
self._session.flush()

def get_or_create_entity(self, entity: Entity) -> SQLModel:
def get_or_create_entity(self, entity: Entity, commit: bool = True) -> SQLModel:
# using the cosine distance between the description vectors to determine if the entity already exists
entity_type = (
EntityType.synopsis
Expand Down Expand Up @@ -347,8 +359,11 @@ def get_or_create_entity(self, entity: Entity) -> SQLModel:
db_obj.meta_vec = get_entity_metadata_embedding(
db_obj.meta, self._embed_model
)
self._session.commit()
self._session.refresh(db_obj)
if commit:
self._session.commit()
self._session.refresh(db_obj)
else:
self._session.flush()
return db_obj

synopsis_info_str = (
Expand All @@ -367,8 +382,12 @@ def get_or_create_entity(self, entity: Entity) -> SQLModel:
entity_type=entity_type,
)
self._session.add(db_obj)
self._session.commit()
self._session.refresh(db_obj)
if commit:
self._session.commit()
self._session.refresh(db_obj)
else:
self._session.flush()

return db_obj

def _try_merge_entities(self, entities: List[Entity]) -> Entity:
Expand Down
2 changes: 1 addition & 1 deletion backend/app/tasks/build_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def build_kg_index_for_chunk(knowledge_base_id: int, chunk_id: UUID):
with Session(engine) as session:
error_msg = traceback.format_exc()
logger.error(
f"Failed to build knowledge graph index for chunk #{chunk_id}: {error_msg}"
f"Failed to build knowledge graph index for chunk #{chunk_id}", exc_info=True
)
db_chunk.index_status = KgIndexStatus.FAILED
db_chunk.index_result = error_msg
Expand Down

0 comments on commit b7a2022

Please sign in to comment.