feat: support rerank models provided by vLLM, Xinference, Bedrock (#572)
Close #537.

- Add support for rerank models provided by vLLM, Xinference, and Amazon
Bedrock.
- Because Ollama does not yet officially support rerank models, this
change does not add Ollama support.

---------

Co-authored-by: Mini256 <minianter@foxmail.com>
jrj5423 and Mini256 authored Jan 10, 2025
1 parent 69e6098 commit 39319ed
Showing 8 changed files with 317 additions and 109 deletions.
24 changes: 24 additions & 0 deletions backend/app/rag/chat_config.py
@@ -20,6 +20,8 @@
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.postprocessor.jinaai_rerank import JinaRerank
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
from sqlmodel import Session
from google.oauth2 import service_account
from google.auth.transport.requests import Request
@@ -29,6 +31,7 @@
from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters
from app.rag.node_postprocessor.baisheng_reranker import BaishengRerank
from app.rag.node_postprocessor.local_reranker import LocalRerank
from app.rag.node_postprocessor.vllm_reranker import VLLMRerank
from app.rag.embeddings.local_embedding import LocalEmbedding
from app.repositories import chat_engine_repo, knowledge_base_repo
from app.repositories.embedding_model import embed_model_repo
@@ -425,6 +428,27 @@ def get_reranker_model(
                top_n=top_n,
                **config,
            )
        case RerankerProvider.VLLM:
            return VLLMRerank(
                model=model,
                top_n=top_n,
                **config,
            )
        case RerankerProvider.XINFERENCE:
            return XinferenceRerank(
                model=model,
                top_n=top_n,
                **config,
            )
        case RerankerProvider.BEDROCK:
            return AWSBedrockRerank(
                rerank_model_name=model,
                top_n=top_n,
                aws_access_key_id=credentials["aws_access_key_id"],
                aws_secret_access_key=credentials["aws_secret_access_key"],
                region_name=credentials["aws_region_name"],
                **config,
            )
        case _:
            raise ValueError(f"Got unknown reranker provider: {provider}")

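For context, here is a minimal sketch of how this factory might be invoked for the new vLLM provider. The call site and keyword names are inferred from the match arms above and are assumptions, not part of this diff:

# Hypothetical call site; get_reranker_model's full signature is not shown in this diff.
from app.rag.chat_config import get_reranker_model
from app.types import RerankerProvider

reranker = get_reranker_model(
    provider=RerankerProvider.VLLM,
    model="BAAI/bge-reranker-v2-m3",               # default vLLM reranker model
    top_n=10,
    config={"base_url": "http://localhost:8000"},  # must be reachable from the app server
    credentials="dummy",                           # vLLM requires no real API key
)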
97 changes: 97 additions & 0 deletions backend/app/rag/node_postprocessor/vllm_reranker.py
@@ -0,0 +1,97 @@
from typing import Any, List, Optional
import requests

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
    ReRankEndEvent,
    ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle

dispatcher = get_dispatcher(__name__)


class VLLMRerank(BaseNodePostprocessor):
    base_url: str = Field(default="", description="The base URL of vLLM API.")
    model: str = Field(default="", description="The model to use when calling API.")

    top_n: int = Field(description="Top N nodes to return.")

    _session: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "BAAI/bge-reranker-v2-m3",
        base_url: str = "http://localhost:8000",
    ):
        super().__init__(top_n=top_n, model=model)
        self.base_url = base_url
        self.model = model
        self._session = requests.Session()

    @classmethod
    def class_name(cls) -> str:
        return "VLLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]
            resp = self._session.post(  # type: ignore
                url=f"{self.base_url}/v1/score",
                json={
                    "text_1": query_bundle.query_str,
                    "model": self.model,
                    "text_2": texts,
                },
            )
            resp.raise_for_status()
            resp_json = resp.json()
            if "data" not in resp_json:
                raise RuntimeError(f"Got error from reranker: {resp_json}")

            results = zip(range(len(nodes)), resp_json["data"])
            results = sorted(results, key=lambda x: x[1]["score"], reverse=True)[: self.top_n]

            new_nodes = []
            for result in results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result[0]].node, score=result[1]["score"]
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatcher.event(ReRankEndEvent(nodes=new_nodes))
        return new_nodes
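To illustrate how this postprocessor would be used, here is a minimal, hypothetical sketch. It assumes a vLLM server already serving BAAI/bge-reranker-v2-m3 at http://localhost:8000; the sample nodes and query are invented for illustration:

# Usage sketch (editor's illustration, not part of this commit).
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

reranker = VLLMRerank(
    top_n=2,
    model="BAAI/bge-reranker-v2-m3",
    base_url="http://localhost:8000",
)

nodes = [
    NodeWithScore(node=TextNode(text="TiDB is a distributed SQL database."), score=0.5),
    NodeWithScore(node=TextNode(text="Paris is the capital of France."), score=0.5),
    NodeWithScore(node=TextNode(text="TiDB supports horizontal scaling."), score=0.5),
]

# Each node is scored against the query via POST {base_url}/v1/score,
# then the top_n highest-scoring nodes are returned.
reranked = reranker.postprocess_nodes(nodes, QueryBundle(query_str="What is TiDB?"))
for n in reranked:
    print(n.score, n.node.get_content())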
49 changes: 49 additions & 0 deletions backend/app/rag/reranker_model_option.py
@@ -77,4 +77,53 @@ class RerankerModelOption(BaseModel):
        credentials_type="str",
        default_credentials="dummy",
    ),
    RerankerModelOption(
        provider=RerankerProvider.VLLM,
        provider_display_name="vLLM",
        provider_description="vLLM is a fast and easy-to-use library for LLM inference and serving.",
        default_reranker_model="BAAI/bge-reranker-v2-m3",
        reranker_model_description="Reference: https://docs.vllm.ai/en/latest/models/supported_models.html#sentence-pair-scoring-task-score",
        default_top_n=10,
        default_config={
            "base_url": "http://localhost:8000",
        },
        config_description="base_url is the base URL of the vLLM server; ensure it can be accessed from this server.",
        credentials_display_name="vLLM API Key",
        credentials_description="vLLM doesn't require an API key; set any dummy string here.",
        credentials_type="str",
        default_credentials="dummy",
    ),
    RerankerModelOption(
        provider=RerankerProvider.XINFERENCE,
        provider_display_name="Xinference Reranker",
        provider_description="Xorbits Inference (Xinference) is an open-source platform to streamline the operation and integration of a wide array of AI models.",
        default_reranker_model="bge-reranker-v2-m3",
        reranker_model_description="Reference: https://inference.readthedocs.io/en/latest/models/model_abilities/rerank.html",
        default_top_n=10,
        default_config={
            "base_url": "http://localhost:9997",
        },
        config_description="base_url is the base URL of the Xinference server; ensure it can be accessed from this server.",
        credentials_display_name="Xinference API Key",
        credentials_description="Xinference doesn't require an API key; set any dummy string here.",
        credentials_type="str",
        default_credentials="dummy",
    ),
    RerankerModelOption(
        provider=RerankerProvider.BEDROCK,
        provider_display_name="Bedrock Reranker",
        provider_description="Amazon Bedrock is a fully managed foundation model service.",
        provider_url="https://docs.aws.amazon.com/bedrock/",
        default_reranker_model="amazon.rerank-v1:0",
        reranker_model_description="Find more models at https://docs.aws.amazon.com/bedrock/latest/userguide/foundation-models-reference.html.",
        default_top_n=10,
        credentials_display_name="AWS Bedrock Credentials JSON",
        credentials_description="The JSON object of AWS credentials; refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
        credentials_type="dict",
        default_credentials={
            "aws_access_key_id": "****",
            "aws_secret_access_key": "****",
            "aws_region_name": "us-west-2",
        },
    ),
]
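For illustration, a user configuring the Bedrock option above would supply values shaped like the following; all values are placeholders, and how the app stores and validates them is outside this diff:

# Hypothetical filled-in settings mirroring the Bedrock option's fields above.
bedrock_reranker_settings = {
    "provider": "bedrock",
    "model": "amazon.rerank-v1:0",
    "top_n": 10,
    "credentials": {
        "aws_access_key_id": "AKIA...",       # placeholder
        "aws_secret_access_key": "********",  # placeholder
        "aws_region_name": "us-west-2",
    },
}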
3 changes: 3 additions & 0 deletions backend/app/types.py
@@ -27,6 +27,9 @@ class RerankerProvider(str, enum.Enum):
    COHERE = "cohere"
    BAISHENG = "baisheng"
    LOCAL = "local"
    VLLM = "vllm"
    XINFERENCE = "xinference"
    BEDROCK = "bedrock"


class MimeTypes(str, enum.Enum):
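Because RerankerProvider subclasses str, provider strings stored in configuration resolve directly to the new members; a quick sketch:

# str-valued enum: lookup by value and comparison with plain strings both work.
assert RerankerProvider("vllm") is RerankerProvider.VLLM
assert RerankerProvider.BEDROCK == "bedrock"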
12 changes: 7 additions & 5 deletions backend/pyproject.toml
@@ -46,10 +46,10 @@ dependencies = [
    "llama-index-postprocessor-cohere-rerank>=0.1.7",
    "llama-index-llms-bedrock>=0.1.12",
    "pypdf>=4.3.1",
-   "llama-index-llms-ollama<=0.3.0",
-   "llama-index-embeddings-ollama<=0.3.0",
-   "llama-index-embeddings-jinaai<=0.3.0",
-   "llama-index-embeddings-cohere<=0.3.0",
+   "llama-index-llms-ollama>=0.3.0",
+   "llama-index-embeddings-ollama>=0.3.0",
+   "llama-index-embeddings-jinaai>=0.3.0",
+   "llama-index-embeddings-cohere>=0.2.0",
    "python-docx>=1.1.2",
    "python-pptx>=1.0.2",
    "colorama>=0.4.6",
@@ -58,7 +58,9 @@ dependencies = [
    "retry>=0.9.2",
    "langchain-openai>=0.2.9",
    "ragas>=0.2.6",
-   "llama-index-embeddings-bedrock<=0.3.0",
+   "llama-index-embeddings-bedrock>=0.2.0",
+   "llama-index-postprocessor-xinference-rerank>=0.2.0",
+   "llama-index-postprocessor-bedrock-rerank>=0.3.0",
]
readme = "README.md"
requires-python = ">= 3.8"
