Commit

Reranker implementation (#20)
* add reranker

* add reranker to rag service

* fix mypy errors

* fix eval mypy errors

* increase retrieval_k and reduce max_context_length

* fix lock file

* update lock file

* add retriever and use_reranking argument

* [pre-commit.ci] Add auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix mypy error

* update docs

* generalize prompt, add rerank docstring, and fix api docs

* change query rerank name

---------

Co-authored-by: Amrit Krishnan <amrit110@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 30, 2024
1 parent be22414 commit 7a02a47
Showing 10 changed files with 369 additions and 171 deletions.
8 changes: 6 additions & 2 deletions docs/source/api.rst
@@ -18,13 +18,17 @@ Health Recommendations
     "query": "I need mental health support",
     "latitude": 43.6532,
     "longitude": -79.3832,
-    "radius": 5000
+    "radius": 5000,
+    "rerank": false
   }
   :<json string query: The user's health-related query (required)
   :<json number latitude: Optional latitude for location-based search
   :<json number longitude: Optional longitude for location-based search
-  :<json number radius: Optional search radius in meters
+  :<json number radius: Optional search radius in meters (default: 5000)
+  :<json boolean rerank: Optional flag to enable/disable reranking of the services (default: false)
+  :>json string recommendation: Generated recommendation text
+  :>json array services: List of relevant health services

   **Response Body**
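
Reviewer note: the updated request body can be illustrated with a minimal sketch that builds a payload carrying the new `rerank` flag. The helper name `build_recommend_payload` is hypothetical and not part of this commit, and because the endpoint path is not visible in this hunk, the sketch stops at building the JSON and makes no HTTP call.

```python
import json
from typing import Any, Dict, Optional


def build_recommend_payload(
    query: str,
    latitude: Optional[float] = None,
    longitude: Optional[float] = None,
    radius: Optional[float] = None,
    rerank: bool = False,
) -> Dict[str, Any]:
    """Build a request body using the fields documented in api.rst.

    Hypothetical helper for illustration; only `query` is required,
    and optional location fields are omitted when not supplied.
    """
    payload: Dict[str, Any] = {"query": query, "rerank": rerank}
    optional = {"latitude": latitude, "longitude": longitude, "radius": radius}
    payload.update({key: value for key, value in optional.items() if value is not None})
    return payload


body = build_recommend_payload(
    "I need mental health support",
    latitude=43.6532,
    longitude=-79.3832,
    radius=5000,
    rerank=True,
)
print(json.dumps(body, indent=2))
```

Leaving `rerank` out of the call keeps it at `false`, which matches the documented default.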

2 changes: 1 addition & 1 deletion eval/evaluate_topkacc.py
@@ -7,7 +7,7 @@
 from typing import List, Dict, Any, DefaultDict


-def load_embeddings(path: str) -> torch.Tensor:
+def load_embeddings(path: str) -> Any:
     return torch.load(path)


4 changes: 4 additions & 0 deletions health_rec/api/config.py
@@ -62,3 +62,7 @@ class Config:
     CHROMA_PORT: int = 8000
     COLLECTION_NAME: str = getenv("COLLECTION_NAME", "211_gta")
     RELEVANCY_WEIGHT: float = float(getenv("RELEVANCY_WEIGHT", "0.5"))
+    MAX_CONTEXT_LENGTH: int = 300
+    TOP_K: int = 5
+    RERANKER_MAX_CONTEXT_LENGTH: int = 150
+    RERANKER_MAX_SERVICES: int = 20
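
Reviewer note: one way the two reranker limits might be applied before candidates reach the reranker is sketched below. This is an assumption, not the commit's implementation: the function name `prepare_rerank_candidates` is hypothetical, and the length unit is taken to be characters because the diff does not say whether the limit counts characters, words, or tokens.

```python
from typing import List

# Values copied from the Config additions in this commit.
RERANKER_MAX_CONTEXT_LENGTH = 150
RERANKER_MAX_SERVICES = 20


def prepare_rerank_candidates(
    service_texts: List[str],
    max_services: int = RERANKER_MAX_SERVICES,
    max_context_length: int = RERANKER_MAX_CONTEXT_LENGTH,
) -> List[str]:
    """Cap the number of candidates, then truncate each one.

    Assumption: lengths are measured in characters. The real service
    may measure words or tokens instead.
    """
    return [text[:max_context_length] for text in service_texts[:max_services]]


# 25 retrieved services of 400 characters each are reduced to the configured caps.
candidates = prepare_rerank_candidates(["x" * 400] * 25)
print(len(candidates), len(candidates[0]))
```

Capping the candidate count first avoids truncating texts that would be discarded anyway.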
3 changes: 3 additions & 0 deletions health_rec/api/data.py
@@ -303,12 +303,15 @@ class Query(BaseModel):
         The latitude coordinate of the user.
     radius : Optional[float]
         The radius of the search.
+    rerank : Optional[bool]
+        Whether to use reranking for the recommendations.
     """

     query: str
     latitude: Optional[float] = Field(default=None)
     longitude: Optional[float] = Field(default=None)
     radius: Optional[float] = Field(default=None)
+    rerank: Optional[bool] = Field(default=False)


 class RefineRequest(BaseModel):
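
Reviewer note: because `rerank` is typed `Optional[bool]` with a default of `False`, the field can hold `False` (omitted), `True`, or an explicit `None`. The sketch below uses a stdlib dataclass as a stand-in for the pydantic model to show one way a caller might normalize those three cases; `QuerySketch` and `should_rerank` are hypothetical names, not code from this commit.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuerySketch:
    """Stdlib stand-in for the pydantic Query model (illustration only)."""

    query: str
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    radius: Optional[float] = None
    rerank: Optional[bool] = False


def should_rerank(q: QuerySketch) -> bool:
    """Treat both the default (False) and an explicit None as 'off'."""
    return bool(q.rerank)


print(should_rerank(QuerySketch("clinic near me")))
print(should_rerank(QuerySketch("clinic near me", rerank=None)))
print(should_rerank(QuerySketch("clinic near me", rerank=True)))
```

If an explicit `None` should never be accepted, declaring the field as a plain `bool` would be the stricter alternative.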
2 changes: 1 addition & 1 deletion health_rec/load_data.py
@@ -45,7 +45,7 @@ def __call__(self, texts: Documents) -> Embeddings:
         """
         try:
             response = self.client.embeddings.create(input=texts, model=self.model)
-            return [data.embedding for data in response.data]
+            return [data.embedding for data in response.data]  # type: ignore
         except Exception as e:
             logger.error(f"Error generating embeddings: {e}")
             raise
2 changes: 1 addition & 1 deletion health_rec/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion health_rec/pyproject.toml
@@ -15,7 +15,7 @@ python = "^3.11"
 fastapi = "^0.115.2"
 uvicorn = "^0.30.6"
 openai = "^1.45.1"
-chromadb = "^0.5.5"
+chromadb = "0.5.15"
 python-dotenv = "^1.0.1"

 [tool.poetry.group.test]
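
Reviewer note: the `chromadb` change replaces a caret range with an exact pin. Under Poetry's caret rule, `^0.5.5` allows any version from 0.5.5 up to but not including 0.6.0, whereas `0.5.15` allows only that release. The sketch below illustrates the caret rule for plain three-part versions; it is a simplified illustration with hypothetical helper names, not Poetry's own resolver, and it ignores pre-release and build tags.

```python
from typing import Tuple

Version = Tuple[int, int, int]


def parse(version: str) -> Version:
    """Parse a plain 'X.Y.Z' version string into a tuple of ints."""
    major, minor, patch = (int(part) for part in version.split("."))
    return (major, minor, patch)


def caret_upper_bound(base: Version) -> Version:
    """Exclusive upper bound: bump the leftmost non-zero component."""
    major, minor, patch = base
    if major > 0:
        return (major + 1, 0, 0)
    if minor > 0:
        return (0, minor + 1, 0)
    return (0, 0, patch + 1)


def satisfies_caret(candidate: str, base: str) -> bool:
    """Return True if `candidate` falls inside the caret range of `base`."""
    lower = parse(base)
    return lower <= parse(candidate) < caret_upper_bound(lower)


print(satisfies_caret("0.5.15", "0.5.5"))
print(satisfies_caret("0.6.0", "0.5.5"))
```

The pinned release still falls inside the old range, so the pin narrows what the resolver may pick rather than moving outside it.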