diff --git a/backend/app/api/admin_routes/knowledge_base/graph/models.py b/backend/app/api/admin_routes/knowledge_base/graph/models.py
index 18156c8a..ca7f45a8 100644
--- a/backend/app/api/admin_routes/knowledge_base/graph/models.py
+++ b/backend/app/api/admin_routes/knowledge_base/graph/models.py
@@ -33,6 +33,7 @@ class GraphSearchRequest(BaseModel):
     include_meta: bool = True
     depth: int = 2
     with_degree: bool = True
+    relationship_meta_filters: dict = {}


 class KnowledgeRequest(BaseModel):
diff --git a/backend/app/api/admin_routes/knowledge_base/graph/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/routes.py
index 943681cb..92679d54 100644
--- a/backend/app/api/admin_routes/knowledge_base/graph/routes.py
+++ b/backend/app/api/admin_routes/knowledge_base/graph/routes.py
@@ -207,7 +207,7 @@ def search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest):
         request.include_meta,
         request.with_degree,
         False,
-        {},
+        request.relationship_meta_filters,
     )
     return {
         "entities": entities,
diff --git a/backend/app/api/admin_routes/models.py b/backend/app/api/admin_routes/models.py
index 3319dc08..b1b3acbd 100644
--- a/backend/app/api/admin_routes/models.py
+++ b/backend/app/api/admin_routes/models.py
@@ -42,3 +42,5 @@ class RetrieveRequest(BaseModel):
     query: str
     chat_engine: Optional[str] = "default"
     top_k: Optional[int] = 5
+    similarity_top_k: Optional[int] = None
+    oversampling_factor: Optional[int] = 5
diff --git a/backend/app/api/admin_routes/retrieve.py b/backend/app/api/admin_routes/retrieve.py
index 50d160b3..b96a5b00 100644
--- a/backend/app/api/admin_routes/retrieve.py
+++ b/backend/app/api/admin_routes/retrieve.py
@@ -17,9 +17,16 @@ async def retrieve_documents(
     question: str,
     chat_engine: str = "default",
     top_k: Optional[int] = 5,
+    similarity_top_k: Optional[int] = None,
+    oversampling_factor: Optional[int] = 5,
 ) -> List[Document]:
     retrieve_service = RetrieveService(session, chat_engine)
-    return retrieve_service.retrieve(question, top_k=top_k)
+    return retrieve_service.retrieve(
+        question,
+        top_k=top_k,
+        similarity_top_k=similarity_top_k,
+        oversampling_factor=oversampling_factor,
+    )


 @router.get("/admin/embedding_retrieve")
@@ -29,9 +36,16 @@ async def embedding_retrieve(
     question: str,
     chat_engine: str = "default",
     top_k: Optional[int] = 5,
+    similarity_top_k: Optional[int] = None,
+    oversampling_factor: Optional[int] = 5,
 ) -> List[NodeWithScore]:
     retrieve_service = RetrieveService(session, chat_engine)
-    return retrieve_service._embedding_retrieve(question, top_k=top_k)
+    return retrieve_service._embedding_retrieve(
+        question,
+        top_k=top_k,
+        similarity_top_k=similarity_top_k,
+        oversampling_factor=oversampling_factor,
+    )


 @router.post("/admin/embedding_retrieve")
@@ -41,4 +55,9 @@ async def embedding_search(
     request: RetrieveRequest,
 ) -> List[NodeWithScore]:
     retrieve_service = RetrieveService(session, request.chat_engine)
-    return retrieve_service._embedding_retrieve(request.query, top_k=request.top_k)
+    return retrieve_service._embedding_retrieve(
+        request.query,
+        top_k=request.top_k,
+        similarity_top_k=request.similarity_top_k,
+        oversampling_factor=request.oversampling_factor,
+    )
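The new request knobs surface directly on the admin retrieval endpoints. Below is a minimal client sketch, assuming the backend is served at `http://localhost:8000` and that whatever auth the deployment uses is already handled (both assumptions, not part of this diff; the router prefix may also differ):

```python
# Hypothetical client for the extended endpoints; BASE and the question are
# assumptions made for this sketch, not values from the diff.
import requests

BASE = "http://localhost:8000"  # assumed host; adjust for the real deployment

# GET variant: the new knobs ride along as query parameters.
resp = requests.get(
    f"{BASE}/admin/embedding_retrieve",
    params={
        "question": "How does TiDB replicate data?",
        "top_k": 5,                # final number of chunks to return
        "similarity_top_k": 20,    # widen the candidate pool before reranking
        "oversampling_factor": 5,  # vector-store-level oversampling
    },
)

# POST variant: the same knobs travel in the RetrieveRequest body.
resp = requests.post(
    f"{BASE}/admin/embedding_retrieve",
    json={
        "query": "How does TiDB replicate data?",
        "chat_engine": "default",
        "top_k": 5,
        "similarity_top_k": 20,
        "oversampling_factor": 5,
    },
)
print(resp.json())
```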
diff --git a/backend/app/rag/chat_config.py b/backend/app/rag/chat_config.py
index cb6bf735..029bd0e8 100644
--- a/backend/app/rag/chat_config.py
+++ b/backend/app/rag/chat_config.py
@@ -176,13 +176,18 @@ def get_fast_dspy_lm(self, session: Session) -> dspy.LM:
         llama_llm = self.get_fast_llama_llm(session)
         return get_dspy_lm_by_llama_llm(llama_llm)

-    def get_reranker(self, session: Session) -> Optional[BaseNodePostprocessor]:
+    # FIXME: Reranker top_n should be configurable in the retrieval config.
+    def get_reranker(
+        self, session: Session, top_n: int = None
+    ) -> Optional[BaseNodePostprocessor]:
         if not self._db_reranker:
-            return get_default_reranker_model(session)
+            return get_default_reranker_model(session, top_n)
+
+        top_n = self._db_reranker.top_n if top_n is None else top_n
         return get_reranker_model(
             self._db_reranker.provider,
             self._db_reranker.model,
-            self._db_reranker.top_n,
+            top_n,
             self._db_reranker.config,
             self._db_reranker.credentials,
         )
@@ -427,14 +432,18 @@ def get_reranker_model(
         raise ValueError(f"Got unknown reranker provider: {provider}")


-def get_default_reranker_model(session: Session) -> Optional[BaseNodePostprocessor]:
+# FIXME: Reranker top_n should be configurable in the retrieval config.
+def get_default_reranker_model(
+    session: Session, top_n: int = None
+) -> Optional[BaseNodePostprocessor]:
     db_reranker = reranker_model_repo.get_default(session)
     if not db_reranker:
         return None
+    top_n = db_reranker.top_n if top_n is None else top_n
     return get_reranker_model(
         db_reranker.provider,
         db_reranker.model,
-        db_reranker.top_n,
+        top_n,
         db_reranker.config,
         db_reranker.credentials,
     )
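`get_reranker` and `get_default_reranker_model` now resolve `top_n` the same way: an explicit argument wins, otherwise the value stored on the reranker row applies. A standalone sketch of that precedence rule, using a stand-in dataclass rather than the project's DB model:

```python
# RerankerRow and the provider/model strings are invented for illustration;
# they are not the project's actual model or data.
from dataclasses import dataclass
from typing import Optional

@dataclass
class RerankerRow:
    provider: str
    model: str
    top_n: int

def resolve_top_n(row: RerankerRow, top_n: Optional[int] = None) -> int:
    # Mirrors the diff's rule: top_n = row.top_n if top_n is None else top_n
    return row.top_n if top_n is None else top_n

row = RerankerRow(provider="example", model="example-reranker", top_n=10)
assert resolve_top_n(row) == 10          # no override: stored default applies
assert resolve_top_n(row, top_n=5) == 5  # caller override, e.g. top_n=top_k
```

This is what lets the retrieval code pass `top_n=top_k`, so the reranker trims to exactly the requested number of results rather than the model's configured default.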
diff --git a/backend/app/rag/retrieve.py b/backend/app/rag/retrieve.py
index 5013a34f..e20e9570 100644
--- a/backend/app/rag/retrieve.py
+++ b/backend/app/rag/retrieve.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Type
+from typing import List, Optional, Type

 from llama_index.core import VectorStoreIndex
 from llama_index.core.schema import NodeWithScore
@@ -42,7 +42,6 @@ def __init__(
         self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
         self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()
-        self._reranker = self.chat_engine_config.get_reranker(db_session)

         if self.chat_engine_config.knowledge_base:
             # TODO: Support multiple knowledge base retrieve.
@@ -57,7 +56,13 @@ def __init__(
         else:
             self._embed_model = get_default_embed_model(self.db_session)

-    def retrieve(self, question: str, top_k: int = 10) -> List[DBDocument]:
+    def retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[DBDocument]:
         """
         Retrieve the related documents based on the user question.
         Args:
@@ -69,13 +74,19 @@ def retrieve(self, question: str, top_k: int = 10) -> List[DBDocument]:
             A list of related documents.
         """
         try:
-            return self._retrieve(question, top_k)
+            return self._retrieve(
+                question, top_k, similarity_top_k, oversampling_factor
+            )
         except Exception as e:
             logger.exception(e)

-    def _retrieve(self, question: str, top_k: int) -> List[DBDocument]:
-        # TODO: move to __init__?
-        _llm = self.chat_engine_config.get_llama_llm(self.db_session)
+    def _retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[DBDocument]:
         _fast_llm = self.chat_engine_config.get_fast_llama_llm(self.db_session)
         _fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm(self.db_session)
@@ -135,29 +146,31 @@ def _retrieve(self, question: str, top_k: int) -> List[DBDocument]:
         # 3. Retrieve the related chunks from the vector store
         # 4. Rerank after the retrieval
-        # 5. Generate a response using the refined question and related chunks
-        text_qa_template = get_prompt_by_jinja2_template(
-            self.chat_engine_config.llm.text_qa_prompt,
-            graph_knowledges=graph_knowledges_context,
-        )
-        refine_template = get_prompt_by_jinja2_template(
-            self.chat_engine_config.llm.refine_prompt,
-            graph_knowledges=graph_knowledges_context,
-        )
+
+        # Vector Index
         vector_store = TiDBVectorStore(
-            session=self.db_session, chunk_db_model=self._chunk_model
+            session=self.db_session,
+            chunk_db_model=self._chunk_model,
+            oversampling_factor=oversampling_factor,
         )
         vector_index = VectorStoreIndex.from_vector_store(
             vector_store,
             embed_model=self._embed_model,
         )
+        # Node postprocessors
+        metadata_filter = self.chat_engine_config.get_metadata_filter()
+        reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
+        if reranker:
+            node_postprocessors = [metadata_filter, reranker]
+        else:
+            node_postprocessors = [metadata_filter]
+
+        # Retriever Engine
         retrieve_engine = vector_index.as_retriever(
-            node_postprocessors=[self._reranker],
-            streaming=True,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            similarity_top_k=top_k,
+            similarity_top_k=similarity_top_k or top_k,
+            filters=metadata_filter.filters,
+            node_postprocessors=node_postprocessors,
         )

         node_list: List[NodeWithScore] = retrieve_engine.retrieve(refined_question)
@@ -165,17 +178,36 @@ def _retrieve(self, question: str, top_k: int) -> List[DBDocument]:

         return source_documents

-    def _embedding_retrieve(self, question: str, top_k: int) -> List[NodeWithScore]:
+    def _embedding_retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[NodeWithScore]:
+        # Vector Index
         vector_store = TiDBVectorStore(
-            session=self.db_session, chunk_db_model=self._chunk_model
+            session=self.db_session,
+            chunk_db_model=self._chunk_model,
+            oversampling_factor=oversampling_factor,
         )
         vector_index = VectorStoreIndex.from_vector_store(
-            vector_store, embed_model=self._embed_model
+            vector_store,
+            embed_model=self._embed_model,
         )
+        # Node postprocessors
+        metadata_filter = self.chat_engine_config.get_metadata_filter()
+        reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
+        if reranker:
+            node_postprocessors = [metadata_filter, reranker]
+        else:
+            node_postprocessors = [metadata_filter]
+
+        # Retriever Engine
         retrieve_engine = vector_index.as_retriever(
-            node_postprocessors=[self._reranker],
-            similarity_top_k=top_k,
+            node_postprocessors=node_postprocessors,
+            similarity_top_k=similarity_top_k or top_k,
         )

         node_list: List[NodeWithScore] = retrieve_engine.retrieve(question)
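The retriever wiring in `retrieve.py` now expresses a two-stage sizing: the vector search returns `similarity_top_k or top_k` candidates, the metadata filter prunes them, and the reranker (constructed with `top_n=top_k`) cuts the survivors down to the final `top_k`. A toy, pure-Python illustration of that flow, with no LlamaIndex involved and all names invented for the sketch:

```python
from typing import Callable, List, Optional, Tuple

Node = Tuple[str, float]  # (text, score) stand-in for NodeWithScore
Postprocessor = Callable[[List[Node]], List[Node]]

def retrieve_nodes(
    search: Callable[[int], List[Node]],
    postprocessors: List[Postprocessor],
    top_k: int = 10,
    similarity_top_k: Optional[int] = None,
) -> List[Node]:
    # Stage 1: over-fetch candidates from the vector store.
    nodes = search(similarity_top_k or top_k)
    # Stage 2: run postprocessors in order (metadata filter, then reranker).
    for post in postprocessors:
        nodes = post(nodes)
    return nodes

def fake_store(k: int) -> List[Node]:
    return [(f"chunk-{i}", 1.0 / (i + 1)) for i in range(k)]

def fake_reranker(nodes: List[Node]) -> List[Node]:
    return sorted(nodes, key=lambda n: -n[1])[:5]  # acts like top_n=top_k=5

print(retrieve_nodes(fake_store, [fake_reranker], top_k=5, similarity_top_k=20))
```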
+ """ super().__init__(**kwargs) self._session = session self._owns_session = session is None @@ -64,6 +69,7 @@ def __init__( self._session = Session(engine) self._chunk_db_model = chunk_db_model + self._oversampling_factor = oversampling_factor def ensure_table_schema(self) -> None: inspector = sqlalchemy.inspect(engine) @@ -179,7 +185,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul if query.query_embedding is None: raise ValueError("Query embedding must be provided.") - stmt = select( + subquery = select( self._chunk_db_model.id, self._chunk_db_model.text, self._chunk_db_model.meta, @@ -190,9 +196,14 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul if query.filters: for f in query.filters.filters: - stmt = stmt.where(self._chunk_db_model.meta[f.key] == f.value) + subquery = subquery.where(self._chunk_db_model.meta[f.key] == f.value) - stmt = stmt.order_by(asc("distance")).limit(query.similarity_top_k) + subquery = ( + subquery.order_by(asc("distance")) + .limit(query.similarity_top_k * self._oversampling_factor) + .subquery("sub") + ) + stmt = select(subquery).order_by(asc("distance")).limit(query.similarity_top_k) results = self._session.exec(stmt) nodes = []