Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support rerank models provided by vLLM, Xinference, Bedrock #572

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions backend/app/rag/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.postprocessor.jinaai_rerank import JinaRerank
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
from sqlmodel import Session
from google.oauth2 import service_account
from google.auth.transport.requests import Request
Expand All @@ -30,6 +32,7 @@
from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters
from app.rag.node_postprocessor.baisheng_reranker import BaishengRerank
from app.rag.node_postprocessor.local_reranker import LocalRerank
from app.rag.node_postprocessor.vllm_reranker import VLLMRerank
from app.rag.embeddings.local_embedding import LocalEmbedding
from app.repositories import chat_engine_repo, knowledge_base_repo
from app.repositories.embedding_model import embed_model_repo
Expand Down Expand Up @@ -423,6 +426,27 @@ def get_reranker_model(
top_n=top_n,
**config,
)
case RerankerProvider.VLLM:
return VLLMRerank(
model=model,
top_n=top_n,
**config,
)
case RerankerProvider.XINFERENCE:
return XinferenceRerank(
model=model,
top_n=top_n,
**config,
)
case RerankerProvider.BEDROCK:
return AWSBedrockRerank(
rerank_model_name=model,
top_n=top_n,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=credentials["aws_region_name"],
jrj5423 marked this conversation as resolved.
Show resolved Hide resolved
**config,
)
case _:
raise ValueError(f"Got unknown reranker provider: {provider}")

Expand Down
97 changes: 97 additions & 0 deletions backend/app/rag/node_postprocessor/vllm_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Any, List, Optional
import requests

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle

dispatcher = get_dispatcher(__name__)


class VLLMRerank(BaseNodePostprocessor):
    """Rerank retrieved nodes via a vLLM server's ``/v1/score`` endpoint.

    The query string and every candidate node's text are posted to the vLLM
    scoring API; the ``top_n`` highest-scoring nodes are returned in
    descending score order.
    """

    base_url: str = Field(default="", description="The base URL of vLLM API.")
    model: str = Field(default="", description="The model to use when calling API.")

    top_n: int = Field(description="Top N nodes to return.")

    # Reused HTTP session for all score requests.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "BAAI/bge-reranker-v2-m3",
        base_url: str = "http://localhost:8000",
    ):
        super().__init__(top_n=top_n, model=model)
        self.base_url = base_url
        self.model = model
        self._session = requests.Session()

    @classmethod
    def class_name(cls) -> str:
        return "VLLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Score ``nodes`` against the query and keep the top ``top_n``.

        Raises:
            ValueError: if no query bundle was supplied.
            RuntimeError: if the vLLM response carries no ``data`` payload.
        """
        # Emit the instrumentation start event before any validation, mirroring
        # the other llama-index rerank postprocessors.
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            # Embed-mode content matches what the embedding model saw at
            # index time.
            candidate_texts = [
                item.node.get_content(metadata_mode=MetadataMode.EMBED)
                for item in nodes
            ]
            response = self._session.post(  # type: ignore
                url=f"{self.base_url}/v1/score",
                json={
                    "text_1": query_bundle.query_str,
                    "model": self.model,
                    "text_2": candidate_texts,
                },
            )
            response.raise_for_status()
            payload = response.json()
            if "data" not in payload:
                raise RuntimeError(f"Got error from reranker: {payload}")

            # Pair each score entry with its node index, then keep the
            # highest-scoring top_n pairs.
            ranked = sorted(
                enumerate(payload["data"]),
                key=lambda pair: pair[1]["score"],
                reverse=True,
            )[: self.top_n]

            reranked_nodes = [
                NodeWithScore(node=nodes[idx].node, score=entry["score"])
                for idx, entry in ranked
            ]
            event.on_end(payload={EventPayload.NODES: reranked_nodes})

        dispatcher.event(ReRankEndEvent(nodes=reranked_nodes))
        return reranked_nodes
49 changes: 49 additions & 0 deletions backend/app/rag/reranker_model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,53 @@ class RerankerModelOption(BaseModel):
credentials_type="str",
default_credentials="dummy",
),
RerankerModelOption(
provider=RerankerProvider.VLLM,
provider_display_name="vLLM",
provider_description="vLLM is a fast and easy-to-use library for LLM inference and serving.",
default_reranker_model="BAAI/bge-reranker-v2-m3",
reranker_model_description="Reference: https://docs.vllm.ai/en/latest/models/supported_models.html#sentence-pair-scoring-task-score",
default_top_n=10,
default_config={
"base_url": "http://localhost:8000",
},
config_description="base_url is the base url of the vLLM server, ensure it can be accessed from this server",
credentials_display_name="vLLM API Key",
credentials_description="vLLM doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
RerankerModelOption(
provider=RerankerProvider.XINFERENCE,
provider_display_name="Xinference Reranker",
provider_description="Xorbits Inference (Xinference) is an open-source platform to streamline the operation and integration of a wide array of AI models.",
default_reranker_model="bge-reranker-v2-m3",
reranker_model_description="Reference: https://inference.readthedocs.io/en/latest/models/model_abilities/rerank.html",
default_top_n=10,
default_config={
"base_url": "http://localhost:9997",
},
config_description="base_url is the url of the Xinference server, ensure it can be accessed from this server",
credentials_display_name="Xinference API Key",
credentials_description="Xinference doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
RerankerModelOption(
provider=RerankerProvider.BEDROCK,
provider_display_name="Bedrock Reranker",
provider_description="Amazon Bedrock is a fully managed foundation models service.",
provider_url="https://docs.aws.amazon.com/bedrock/",
default_reranker_model="amazon.rerank-v1:0",
reranker_model_description="Find more models in https://docs.aws.amazon.com/bedrock/latest/userguide/foundation-models-reference.html.",
default_top_n=10,
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON Object of AWS Credentials, refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2",
},
)
]
3 changes: 3 additions & 0 deletions backend/app/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class RerankerProvider(str, enum.Enum):
COHERE = "cohere"
BAISHENG = "baisheng"
LOCAL = "local"
VLLM = "vllm"
XINFERENCE = "xinference"
BEDROCK = "bedrock"


class MimeTypes(str, enum.Enum):
Expand Down
12 changes: 7 additions & 5 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ dependencies = [
"llama-index-postprocessor-cohere-rerank>=0.1.7",
"llama-index-llms-bedrock>=0.1.12",
"pypdf>=4.3.1",
"llama-index-llms-ollama<=0.3.0",
"llama-index-embeddings-ollama<=0.3.0",
"llama-index-embeddings-jinaai<=0.3.0",
"llama-index-embeddings-cohere<=0.3.0",
"llama-index-llms-ollama>=0.3.0",
"llama-index-embeddings-ollama>=0.3.0",
"llama-index-embeddings-jinaai>=0.3.0",
"llama-index-embeddings-cohere>=0.2.0",
"python-docx>=1.1.2",
"python-pptx>=1.0.2",
"colorama>=0.4.6",
Expand All @@ -58,7 +58,9 @@ dependencies = [
"retry>=0.9.2",
"langchain-openai>=0.2.9",
"ragas>=0.2.6",
"llama-index-embeddings-bedrock<=0.3.0",
"llama-index-embeddings-bedrock>=0.2.0",
"llama-index-postprocessor-xinference-rerank>=0.2.0",
"llama-index-postprocessor-bedrock-rerank>=0.3.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand Down
Loading
Loading