feat: support relationship_meta_filters / similarity_top_k / oversampling_factor for retrieve api (#575)

close #570

For the `/admin/knowledge_bases/{kb_id}/graph/search` API, add one parameter:

- `relationship_meta_filters`: the metadata conditions used to filter relationships during the graph search.
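A minimal sketch of calling the endpoint with the new parameter. The body fields follow the `GraphSearchRequest` model in the diff below; the `query` field, base URL, API prefix, and the filter key/value are illustrative assumptions:

```python
# Hypothetical request against a local deployment; the `query` field,
# URL prefix, and the filter key/value are assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:5000/api/v1/admin/knowledge_bases/1/graph/search",
    json={
        "query": "How does TiDB implement vector search?",  # assumed field
        "depth": 2,
        "include_meta": True,
        "with_degree": True,
        # Only relationships whose metadata matches these conditions
        # are considered during the graph search.
        "relationship_meta_filters": {"doc_source": "docs.pingcap.com"},
    },
)
resp.raise_for_status()
print(resp.json()["entities"])
```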

For the `/admin/retrieve/documents` API, add two parameters:
- `similarity_top_k`: controls how many nodes the vector search returns; if not set, it defaults to the value of `top_k`.
- `oversampling_factor`: similar to the `ef_search` parameter of an HNSW index, in that a larger value yields a higher recall rate. Since TiDB does not yet support modifying the value of `ef_search`, the current implementation uses a subquery: the subquery returns `similarity_top_k * oversampling_factor` rows, and the outer query then returns the final `similarity_top_k` rows (see the sketch after this list).
- For now, enabling `metadata_filter` still requires modifying the Chat Engine configuration. This will be fixed by the retrieve API refactor (#573); the new retrieve API will not depend on the Chat Engine configuration.
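To make the subquery approach concrete, here is a standalone SQLAlchemy sketch of the oversample-then-trim query shape. The table and column names are hypothetical; in the real change (see `tidb_vector_store.py` below) the statement is built from the chunk model, and `distance` stands in for the vector-distance expression:

```python
# Standalone sketch of the oversampling subquery (hypothetical schema;
# `distance` stands in for the embedding-distance expression).
from sqlalchemy import Column, Float, Integer, MetaData, Table, asc, select

metadata = MetaData()
chunks = Table(
    "chunks",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("distance", Float),
)

similarity_top_k = 10
oversampling_factor = 5

# Inner query: let the (approximate) vector index return extra candidates.
subquery = (
    select(chunks.c.id, chunks.c.distance)
    .order_by(asc(chunks.c.distance))
    .limit(similarity_top_k * oversampling_factor)
    .subquery("sub")
)

# Outer query: re-sort the oversampled rows and trim to the final top k.
stmt = select(subquery).order_by(asc(subquery.c.distance)).limit(similarity_top_k)
print(stmt)  # inspect the generated SQL without connecting to a database
```

The inner `LIMIT` widens the candidate pool the index must return, compensating for the recall loss of approximate search; the outer `LIMIT` keeps the contract of returning exactly `similarity_top_k` rows.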
Mini256 authored Jan 7, 2025
1 parent 414d9c3 commit b5436f6
Showing 7 changed files with 113 additions and 39 deletions.
@@ -33,6 +33,7 @@ class GraphSearchRequest(BaseModel):
     include_meta: bool = True
     depth: int = 2
     with_degree: bool = True
+    relationship_meta_filters: dict = {}
 
 
 class KnowledgeRequest(BaseModel):
@@ -207,7 +207,7 @@ def search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest):
         request.include_meta,
         request.with_degree,
         False,
-        {},
+        request.relationship_meta_filters,
     )
     return {
         "entities": entities,
2 changes: 2 additions & 0 deletions backend/app/api/admin_routes/models.py
@@ -42,3 +42,5 @@ class RetrieveRequest(BaseModel):
     query: str
     chat_engine: Optional[str] = "default"
     top_k: Optional[int] = 5
+    similarity_top_k: Optional[int] = None
+    oversampling_factor: Optional[int] = 5
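For illustration, a request exercising the new fields of `RetrieveRequest` might look as follows. The endpoint comes from the routes in the diff below; the base URL and API prefix are assumptions, and auth is omitted:

```python
# Hypothetical POST to the embedding retrieve endpoint; the URL prefix
# is an assumption, and authentication is omitted for brevity.
import requests

resp = requests.post(
    "http://localhost:5000/api/v1/admin/embedding_retrieve",
    json={
        "query": "How does oversampling improve recall?",
        "chat_engine": "default",
        "top_k": 5,                # final number of results
        "similarity_top_k": 20,    # candidates fetched by the vector search
        "oversampling_factor": 5,  # inner query scans 20 * 5 = 100 rows
    },
)
print(resp.json())
```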
25 changes: 22 additions & 3 deletions backend/app/api/admin_routes/retrieve.py
@@ -17,9 +17,16 @@ async def retrieve_documents(
     question: str,
     chat_engine: str = "default",
     top_k: Optional[int] = 5,
+    similarity_top_k: Optional[int] = None,
+    oversampling_factor: Optional[int] = 5,
 ) -> List[Document]:
     retrieve_service = RetrieveService(session, chat_engine)
-    return retrieve_service.retrieve(question, top_k=top_k)
+    return retrieve_service.retrieve(
+        question,
+        top_k=top_k,
+        similarity_top_k=similarity_top_k,
+        oversampling_factor=oversampling_factor,
+    )
 
 
 @router.get("/admin/embedding_retrieve")
@@ -29,9 +36,16 @@ async def embedding_retrieve(
     question: str,
     chat_engine: str = "default",
     top_k: Optional[int] = 5,
+    similarity_top_k: Optional[int] = None,
+    oversampling_factor: Optional[int] = 5,
 ) -> List[NodeWithScore]:
     retrieve_service = RetrieveService(session, chat_engine)
-    return retrieve_service._embedding_retrieve(question, top_k=top_k)
+    return retrieve_service._embedding_retrieve(
+        question,
+        top_k=top_k,
+        similarity_top_k=similarity_top_k,
+        oversampling_factor=oversampling_factor,
+    )
 
 
 @router.post("/admin/embedding_retrieve")
@@ -41,4 +55,9 @@ async def embedding_search(
     request: RetrieveRequest,
 ) -> List[NodeWithScore]:
     retrieve_service = RetrieveService(session, request.chat_engine)
-    return retrieve_service._embedding_retrieve(request.query, top_k=request.top_k)
+    return retrieve_service._embedding_retrieve(
+        request.query,
+        top_k=request.top_k,
+        similarity_top_k=request.similarity_top_k,
+        oversampling_factor=request.oversampling_factor,
+    )
19 changes: 14 additions & 5 deletions backend/app/rag/chat_config.py
@@ -176,13 +176,18 @@ def get_fast_dspy_lm(self, session: Session) -> dspy.LM:
         llama_llm = self.get_fast_llama_llm(session)
         return get_dspy_lm_by_llama_llm(llama_llm)
 
-    def get_reranker(self, session: Session) -> Optional[BaseNodePostprocessor]:
+    # FIXME: Reranker top_n should be configured in the retrieval config.
+    def get_reranker(
+        self, session: Session, top_n: int = None
+    ) -> Optional[BaseNodePostprocessor]:
         if not self._db_reranker:
-            return get_default_reranker_model(session)
+            return get_default_reranker_model(session, top_n)
+
+        top_n = self._db_reranker.top_n if top_n is None else top_n
         return get_reranker_model(
             self._db_reranker.provider,
             self._db_reranker.model,
-            self._db_reranker.top_n,
+            top_n,
             self._db_reranker.config,
             self._db_reranker.credentials,
         )
@@ -427,14 +432,18 @@ def get_reranker_model(
     raise ValueError(f"Got unknown reranker provider: {provider}")
 
 
-def get_default_reranker_model(session: Session) -> Optional[BaseNodePostprocessor]:
+# FIXME: Reranker top_n should be configured in the retrieval config.
+def get_default_reranker_model(
+    session: Session, top_n: int = None
+) -> Optional[BaseNodePostprocessor]:
     db_reranker = reranker_model_repo.get_default(session)
     if not db_reranker:
         return None
+    top_n = db_reranker.top_n if top_n is None else top_n
     return get_reranker_model(
         db_reranker.provider,
         db_reranker.model,
-        db_reranker.top_n,
+        top_n,
         db_reranker.config,
         db_reranker.credentials,
     )
86 changes: 59 additions & 27 deletions backend/app/rag/retrieve.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Type
+from typing import List, Optional, Type
 
 from llama_index.core import VectorStoreIndex
 from llama_index.core.schema import NodeWithScore
@@ -42,7 +42,6 @@ def __init__(
 
         self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
         self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()
-        self._reranker = self.chat_engine_config.get_reranker(db_session)
 
         if self.chat_engine_config.knowledge_base:
             # TODO: Support multiple knowledge base retrieve.
@@ -57,7 +56,13 @@ def __init__(
         else:
             self._embed_model = get_default_embed_model(self.db_session)
 
-    def retrieve(self, question: str, top_k: int = 10) -> List[DBDocument]:
+    def retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[DBDocument]:
         """
         Retrieve the related documents based on the user question.
         Args:
@@ -69,13 +74,19 @@ def retrieve(self, question: str, top_k: int = 10) -> List[DBDocument]:
             A list of related documents.
         """
         try:
-            return self._retrieve(question, top_k)
+            return self._retrieve(
+                question, top_k, similarity_top_k, oversampling_factor
+            )
         except Exception as e:
             logger.exception(e)
 
-    def _retrieve(self, question: str, top_k: int) -> List[DBDocument]:
-        # TODO: move to __init__?
-        _llm = self.chat_engine_config.get_llama_llm(self.db_session)
+    def _retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[DBDocument]:
         _fast_llm = self.chat_engine_config.get_fast_llama_llm(self.db_session)
         _fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm(self.db_session)
 
@@ -135,47 +146,68 @@ def _retrieve(self, question: str, top_k: int) -> List[DBDocument]:
 
         # 3. Retrieve the related chunks from the vector store
         # 4. Rerank after the retrieval
-        # 5. Generate a response using the refined question and related chunks
-        text_qa_template = get_prompt_by_jinja2_template(
-            self.chat_engine_config.llm.text_qa_prompt,
-            graph_knowledges=graph_knowledges_context,
-        )
-        refine_template = get_prompt_by_jinja2_template(
-            self.chat_engine_config.llm.refine_prompt,
-            graph_knowledges=graph_knowledges_context,
-        )
 
         # Vector Index
         vector_store = TiDBVectorStore(
-            session=self.db_session, chunk_db_model=self._chunk_model
+            session=self.db_session,
+            chunk_db_model=self._chunk_model,
+            oversampling_factor=oversampling_factor,
         )
         vector_index = VectorStoreIndex.from_vector_store(
             vector_store,
             embed_model=self._embed_model,
         )
 
+        # Node postprocessors
+        metadata_filter = self.chat_engine_config.get_metadata_filter()
+        reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
+        if reranker:
+            node_postprocessors = [metadata_filter, reranker]
+        else:
+            node_postprocessors = [metadata_filter]
+
         # Retriever Engine
         retrieve_engine = vector_index.as_retriever(
-            node_postprocessors=[self._reranker],
-            streaming=True,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            similarity_top_k=top_k,
+            similarity_top_k=similarity_top_k or top_k,
+            filters=metadata_filter.filters,
+            node_postprocessors=node_postprocessors,
         )
 
         node_list: List[NodeWithScore] = retrieve_engine.retrieve(refined_question)
         source_documents = self._get_source_documents(node_list)
 
         return source_documents
 
-    def _embedding_retrieve(self, question: str, top_k: int) -> List[NodeWithScore]:
+    def _embedding_retrieve(
+        self,
+        question: str,
+        top_k: int = 10,
+        similarity_top_k: Optional[int] = None,
+        oversampling_factor: int = 5,
+    ) -> List[NodeWithScore]:
         # Vector Index
         vector_store = TiDBVectorStore(
-            session=self.db_session, chunk_db_model=self._chunk_model
+            session=self.db_session,
+            chunk_db_model=self._chunk_model,
+            oversampling_factor=oversampling_factor,
        )
         vector_index = VectorStoreIndex.from_vector_store(
-            vector_store, embed_model=self._embed_model
+            vector_store,
+            embed_model=self._embed_model,
        )
 
+        # Node postprocessors
+        metadata_filter = self.chat_engine_config.get_metadata_filter()
+        reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
+        if reranker:
+            node_postprocessors = [metadata_filter, reranker]
+        else:
+            node_postprocessors = [metadata_filter]
+
         # Retriever Engine
         retrieve_engine = vector_index.as_retriever(
-            node_postprocessors=[self._reranker],
-            similarity_top_k=top_k,
+            node_postprocessors=node_postprocessors,
+            similarity_top_k=similarity_top_k or top_k,
         )
 
         node_list: List[NodeWithScore] = retrieve_engine.retrieve(question)
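The net effect of the `retrieve.py` changes, summarized as a small sketch (parameter names from the diff; the numbers are only an example):

```python
# How the three knobs interact after this change (a sketch, not app code):
# the inner SQL oversamples, the outer SQL trims to the candidate count,
# and the reranker (when configured) keeps the final top_k.
def effective_counts(top_k=5, similarity_top_k=20, oversampling_factor=5):
    candidates = similarity_top_k or top_k      # outer query LIMIT
    scanned = candidates * oversampling_factor  # inner query LIMIT
    final = top_k                               # reranker top_n
    return scanned, candidates, final

print(effective_counts())  # (100, 20, 5)
```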
17 changes: 14 additions & 3 deletions backend/app/rag/vector_store/tidb_vector_store.py
@@ -55,15 +55,21 @@
         self,
         session: Optional[Session] = None,
         chunk_db_model: SQLModel = Chunk,
+        oversampling_factor: int = 5,
         **kwargs: Any,
     ) -> None:
+        """
+        Args:
+            oversampling_factor (int): The oversampling factor for the similarity search. The higher the factor, the higher the recall rate.
+        """
         super().__init__(**kwargs)
         self._session = session
         self._owns_session = session is None
         if self._session is None:
             self._session = Session(engine)
 
         self._chunk_db_model = chunk_db_model
+        self._oversampling_factor = oversampling_factor
 
     def ensure_table_schema(self) -> None:
         inspector = sqlalchemy.inspect(engine)
@@ -179,7 +185,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
         if query.query_embedding is None:
             raise ValueError("Query embedding must be provided.")
 
-        stmt = select(
+        subquery = select(
             self._chunk_db_model.id,
             self._chunk_db_model.text,
             self._chunk_db_model.meta,
@@ -190,9 +196,14 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
 
         if query.filters:
             for f in query.filters.filters:
-                stmt = stmt.where(self._chunk_db_model.meta[f.key] == f.value)
+                subquery = subquery.where(self._chunk_db_model.meta[f.key] == f.value)
 
-        stmt = stmt.order_by(asc("distance")).limit(query.similarity_top_k)
+        subquery = (
+            subquery.order_by(asc("distance"))
+            .limit(query.similarity_top_k * self._oversampling_factor)
+            .subquery("sub")
+        )
+        stmt = select(subquery).order_by(asc("distance")).limit(query.similarity_top_k)
         results = self._session.exec(stmt)
 
         nodes = []
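If the store is constructed directly, the new knob is passed at init time. A hedged usage sketch; the import path is an assumption based on the file location shown in this commit:

```python
# Hedged sketch; the import path follows the file location in this commit.
from app.rag.vector_store.tidb_vector_store import TiDBVectorStore

# With no session argument the store opens its own; chunk_db_model
# defaults to Chunk (see the constructor in the diff above).
store = TiDBVectorStore(
    oversampling_factor=10,  # inner query fetches 10x the candidate rows
)
```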
