Merge branch 'main' into Feature/#962
Showing 14 changed files with 843 additions and 4 deletions.
@@ -1 +1 @@
-0.3.9
+0.3.10rc1

@@ -0,0 +1,218 @@
import logging

from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions

from typing import List, Tuple, Optional

from autorag.utils.util import make_batch
from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Couchbase(BaseVectorStore):
    def __init__(
        self,
        embedding_model: str,
        bucket_name: str,
        scope_name: str,
        collection_name: str,
        index_name: str,
        embedding_batch: int = 100,
        connection_string: str = "",
        username: str = "",
        password: str = "",
        ingest_batch: int = 100,
        text_key: Optional[str] = "text",
        embedding_key: Optional[str] = "embedding",
        scoped_index: bool = True,
    ):
        super().__init__(
            embedding_model=embedding_model,
            similarity_metric="ip",
            embedding_batch=embedding_batch,
        )

        self.index_name = index_name
        self.bucket_name = bucket_name
        self.scope_name = scope_name
        self.collection_name = collection_name
        self.scoped_index = scoped_index
        self.text_key = text_key
        self.embedding_key = embedding_key
        self.ingest_batch = ingest_batch

        auth = PasswordAuthenticator(username, password)
        self.cluster = Cluster(connection_string, ClusterOptions(auth))

        # Wait until the cluster is ready for use.
        self.cluster.wait_until_ready(timedelta(seconds=5))

        # Check if the bucket exists
        if not self._check_bucket_exists():
            raise ValueError(
                f"Bucket {self.bucket_name} does not exist. "
                "Please create the bucket before searching."
            )

        try:
            self.bucket = self.cluster.bucket(self.bucket_name)
            self.scope = self.bucket.scope(self.scope_name)
            self.collection = self.scope.collection(self.collection_name)
        except Exception as e:
            raise ValueError(
                "Error connecting to Couchbase. "
                "Please check the connection and credentials."
            ) from e

        # Check if the index exists. Raises ValueError if it doesn't.
        self._check_index_exists()

        # Reinitialize to ensure a consistent state
        self.bucket = self.cluster.bucket(self.bucket_name)
        self.scope = self.bucket.scope(self.scope_name)
        self.collection = self.scope.collection(self.collection_name)

    async def add(self, ids: List[str], texts: List[str]):
        from couchbase.exceptions import DocumentExistsException

        texts = self.truncated_inputs(texts)
        text_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(texts)

        documents_to_insert = []
        for _id, text, embedding in zip(ids, texts, text_embeddings):
            doc = {
                self.text_key: text,
                self.embedding_key: embedding,
            }
            documents_to_insert.append({_id: doc})

        batch_documents_to_insert = make_batch(documents_to_insert, self.ingest_batch)

        for batch in batch_documents_to_insert:
            insert_batch = {}
            for doc in batch:
                insert_batch.update(doc)
            try:
                self.collection.upsert_multi(insert_batch)
            except DocumentExistsException as e:
                logger.debug(f"Document already exists: {e}")

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        # Fetch vectors by IDs
        fetched_result = self.collection.get_multi(ids)
        fetched_vectors = {
            k: v.value[self.embedding_key]
            for k, v in fetched_result.results.items()
        }
        return list(map(lambda x: fetched_vectors[x], ids))

    async def is_exist(self, ids: List[str]) -> List[bool]:
        existed_result = self.collection.exists_multi(ids)
        existed_ids = {k: v.exists for k, v in existed_result.results.items()}
        return list(map(lambda x: existed_ids[x], ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        import couchbase.search as search
        from couchbase.options import SearchOptions
        from couchbase.vector_search import VectorQuery, VectorSearch

        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        ids, scores = [], []
        for query_embedding in query_embeddings:
            # Create Search Request
            search_req = search.SearchRequest.create(
                VectorSearch.from_vector_query(
                    VectorQuery(
                        self.embedding_key,
                        query_embedding,
                        top_k,
                    )
                )
            )

            # Search
            if self.scoped_index:
                search_iter = self.scope.search(
                    self.index_name,
                    search_req,
                    SearchOptions(limit=top_k),
                )
            else:
                search_iter = self.cluster.search(
                    self.index_name,
                    search_req,
                    SearchOptions(limit=top_k),
                )

            # Parse the search results.
            # search_iter.rows() can only be iterated once.
            id_list, score_list = [], []
            for result in search_iter.rows():
                id_list.append(result.id)
                score_list.append(result.score)

            ids.append(id_list)
            scores.append(score_list)

        return ids, scores

    async def delete(self, ids: List[str]):
        self.collection.remove_multi(ids)

    def _check_bucket_exists(self) -> bool:
        """Check if the bucket exists in the linked Couchbase cluster.

        Returns:
            True if the bucket exists.
        """
        bucket_manager = self.cluster.buckets()
        try:
            bucket_manager.get_bucket(self.bucket_name)
            return True
        except Exception as e:
            logger.debug("Error checking if bucket exists: %s", e)
            return False

    def _check_index_exists(self) -> bool:
        """Check if the Search index exists in the linked Couchbase cluster.

        Returns:
            bool: True if the index exists.
            Raises a ValueError if the index does not exist.
        """
        if self.scoped_index:
            all_indexes = [
                index.name for index in self.scope.search_indexes().get_all_indexes()
            ]
            if self.index_name not in all_indexes:
                raise ValueError(
                    f"Index {self.index_name} does not exist. "
                    "Please create the index before searching."
                )
        else:
            all_indexes = [
                index.name for index in self.cluster.search_indexes().get_all_indexes()
            ]
            if self.index_name not in all_indexes:
                raise ValueError(
                    f"Index {self.index_name} does not exist. "
                    "Please create the index before searching."
                )

        return True
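
For context (not part of the diff), here is a minimal usage sketch of the new Couchbase store. It assumes the module lands at autorag.vectordb.couchbase, that the named bucket, scope, collection, and vector Search index already exist on a reachable cluster (the constructor validates the bucket and index), and that the embedding model key is a valid AutoRAG key; every name and credential below is a placeholder.

import asyncio

# Usage sketch only -- the module path, embedding key, and connection
# details are assumptions, not part of this commit.
from autorag.vectordb.couchbase import Couchbase

vectordb = Couchbase(
    embedding_model="openai_embed_3_large",  # assumed AutoRAG embedding key
    bucket_name="autorag",                   # bucket must already exist
    scope_name="autorag",
    collection_name="autorag",
    index_name="autorag-vector-index",       # vector Search index must already exist
    connection_string="couchbase://localhost",
    username="Administrator",
    password="password",
)

async def main():
    await vectordb.add(ids=["doc-1", "doc-2"], texts=["hello couchbase", "hello vectors"])
    print(await vectordb.is_exist(["doc-1", "doc-2"]))  # [True, True]
    ids, scores = await vectordb.query(queries=["greeting"], top_k=2)
    print(ids, scores)

asyncio.run(main())

Note that similarity_metric is hard-coded to "ip" in __init__, so the backing Search index should be configured for dot-product similarity.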

@@ -0,0 +1,153 @@
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    PointIdsList,
    HasIdCondition,
    Filter,
    SearchRequest,
)

from typing import List, Tuple

from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Qdrant(BaseVectorStore):
    def __init__(
        self,
        embedding_model: str,
        collection_name: str,
        embedding_batch: int = 100,
        similarity_metric: str = "cosine",
        client_type: str = "docker",
        url: str = "http://localhost:6333",
        host: str = "",
        api_key: str = "",
        dimension: int = 1536,
        ingest_batch: int = 64,
        parallel: int = 1,
        max_retries: int = 3,
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        self.collection_name = collection_name
        self.ingest_batch = ingest_batch
        self.parallel = parallel
        self.max_retries = max_retries

        if similarity_metric == "cosine":
            distance = Distance.COSINE
        elif similarity_metric == "ip":
            distance = Distance.DOT
        elif similarity_metric == "l2":
            distance = Distance.EUCLID
        else:
            raise ValueError(
                f"similarity_metric {similarity_metric} is not supported\n"
                "supported similarity metrics are: cosine, ip, l2"
            )

        if client_type == "docker":
            self.client = QdrantClient(
                url=url,
            )
        elif client_type == "cloud":
            self.client = QdrantClient(
                host=host,
                api_key=api_key,
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: docker, cloud"
            )

        if not self.client.collection_exists(collection_name):
            self.client.create_collection(
                collection_name,
                vectors_config=VectorParams(
                    size=dimension,
                    distance=distance,
                ),
            )
        self.collection = self.client.get_collection(collection_name)

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)

        points = list(
            map(lambda x: PointStruct(id=x[0], vector=x[1]), zip(ids, text_embeddings))
        )

        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=self.ingest_batch,
            parallel=self.parallel,
            max_retries=self.max_retries,
            wait=True,
        )

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        # Fetch vectors by IDs
        fetched_results = self.client.retrieve(
            collection_name=self.collection_name,
            ids=ids,
            with_vectors=True,
        )
        return list(map(lambda x: x.vector, fetched_results))

    async def is_exist(self, ids: List[str]) -> List[bool]:
        existed_result = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=Filter(
                must=[
                    HasIdCondition(has_id=ids),
                ],
            ),
        )
        # existed_result is a tuple, so existed_result[0] holds the list of Records.
        existed_ids = list(map(lambda x: x.id, existed_result[0]))
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        search_queries = list(
            map(
                lambda x: SearchRequest(vector=x, limit=top_k, with_vector=True),
                query_embeddings,
            )
        )

        search_result = self.client.search_batch(
            collection_name=self.collection_name, requests=search_queries
        )

        # Extract IDs and distances
        ids = [[str(hit.id) for hit in result] for result in search_result]
        scores = [[hit.score for hit in result] for result in search_result]

        return ids, scores

    async def delete(self, ids: List[str]):
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(points=ids),
        )

    def delete_collection(self):
        # Delete the collection
        self.client.delete_collection(self.collection_name)
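
Likewise, a minimal usage sketch for the new Qdrant store (not part of the diff). It assumes the module lands at autorag.vectordb.qdrant and that a Qdrant server is reachable at the default Docker URL; the embedding model key is illustrative, dimension must match the embedding model's output size, and Qdrant point IDs must be UUIDs or unsigned integers.

import asyncio
import uuid

# Usage sketch only -- the module path and embedding key are assumptions.
from autorag.vectordb.qdrant import Qdrant

vectordb = Qdrant(
    embedding_model="openai_embed_3_large",  # assumed AutoRAG embedding key
    collection_name="autorag",
    client_type="docker",                    # expects a local Qdrant at the default port
    url="http://localhost:6333",
    dimension=3072,                          # must match the embedding model's output size
)

async def main():
    # Qdrant point IDs must be UUIDs or unsigned integers, hence UUID strings here.
    ids = [str(uuid.uuid4()) for _ in range(2)]
    await vectordb.add(ids=ids, texts=["hello qdrant", "hello vectors"])
    print(await vectordb.is_exist(ids))  # [True, True]
    id_lists, score_lists = await vectordb.query(queries=["greeting"], top_k=2)
    print(id_lists, score_lists)

asyncio.run(main())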