-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b3cb5e8
commit b659e5c
Showing
51 changed files
with
3,131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Lint the repository with Ruff on pushes and pull requests targeting main.
name: Ruff

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      # v2 of these actions targets the retired Node 12 runtime; use current majors.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          # Quoted so YAML reads the version as a string, never as a float.
          python-version: '3.9'
      - run: pip install ruff
      - run: ruff check .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Run the test suite (via `make tests`) on every pull request, whatever its base branch.
name: Python Tests

on:
  pull_request:
    branches:
      - '**'

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      # v2 of these actions targets the retired Node 12 runtime; use current majors.
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ".[dev]"

      - name: Run tests library
        run: |
          make tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,11 @@ __pycache__/ | |
*.py[cod] | ||
*$py.class | ||
|
||
*.duckdb | ||
duckdb_tmp/ | ||
|
||
evaluation_datasets/ | ||
|
||
# C extensions | ||
*.so | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Variables
DIALECT := duckdb

# These rules are commands, not files: without .PHONY, a file or directory
# named `tests`, `fix`, `lint` or `view` would make the rule a silent no-op.
.PHONY: fix lint tests view

# Rules

# Auto-fix SQL files with sqlfluff.
fix:
	sqlfluff fix --dialect $(DIALECT)

# Lint SQL files with sqlfluff.
lint:
	sqlfluff lint --dialect $(DIALECT)

# Run pytest on each module in turn. The test database and its WAL file are
# deleted between groups (presumably so each group starts from an empty
# database -- confirm against the tests themselves).
tests:
	@echo "Removing test.duckdb if it exists..."
	rm -rf test.duckdb
	rm -rf test.duckdb.wal
	pytest ducksearch/tables/create.py
	pytest ducksearch/tables/insert.py
	pytest ducksearch/tables/select.py
	rm -rf test.duckdb
	rm -rf test.duckdb.wal
	pytest ducksearch/evaluation/evaluation.py
	rm -rf test.duckdb
	rm -rf test.duckdb.wal
	pytest ducksearch/upload/upload.py
	rm -rf test.duckdb
	rm -rf test.duckdb.wal
	pytest ducksearch/search/create.py
	pytest ducksearch/search/select.py

# Open the test database in harlequin.
view:
	harlequin test.duckdb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,12 @@ | ||
# ducksearch | ||
Search with DuckDB | ||
|
||
|
||
``` | ||
@misc{DuckSearch, | ||
title={DuckSearch, efficient search with DuckDB}, | ||
author={Sourty, Raphaël}, | ||
url={https://github.com/lightonai/ducksearch}, | ||
year={2024} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__all__ = ["decorators", "evaluation", "hf", "search", "tables", "upload", "utils"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Version number as a tuple of its components.
VERSION = (1, 0, 0)

# Dotted string form of VERSION, e.g. "1.0.0".
__version__ = ".".join(str(component) for component in VERSION)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .execute_with_duckdb import connect_to_duckdb, execute_with_duckdb | ||
|
||
__all__ = ["execute_with_duckdb", "connect_to_duckdb"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import pathlib | ||
from functools import wraps | ||
|
||
import duckdb | ||
import pandas as pd | ||
|
||
|
||
def connect_to_duckdb(
    database: str,
    read_only: bool = False,
    config: dict | None = None,
):
    """Open a DuckDB connection to ``database``.

    Parameters
    ----------
    database
        Database identifier forwarded to ``duckdb.connect``.
    read_only
        Forwarded to ``duckdb.connect`` as ``read_only``.
    config
        Optional configuration dictionary. It is only forwarded when truthy:
        ``None`` or an empty dict leaves the ``config`` argument out of the
        call entirely.

    Returns
    -------
    The connection object returned by ``duckdb.connect``.
    """
    connection_options = {"database": database, "read_only": read_only}

    if config:
        connection_options["config"] = config

    return duckdb.connect(**connection_options)
|
||
|
||
def execute_with_duckdb( | ||
relative_path: str | list[str], | ||
df_name: str | None = None, | ||
read_only: bool = False, | ||
fields: list[str] | None = None, | ||
fetch_df: bool = False, | ||
**kwargs, | ||
): | ||
"""Decorator to execute a SQL query using DuckDB.""" | ||
|
||
def decorator(func): | ||
@wraps(func) | ||
def wrapper( | ||
*args, | ||
database: str, | ||
config: dict | None = None, | ||
df: pd.DataFrame = None, | ||
relative_path: str | list[str] = relative_path, | ||
**kwargs, | ||
): | ||
"""Connect to the database and execute the query.""" | ||
# Ensure a DataFrame and table name consistency | ||
if df is not None: | ||
assert df_name is not None, "Table name must be provided." | ||
assert ( | ||
read_only is False | ||
), "Read-only mode is not supported for writing." | ||
|
||
if df_name is not None: | ||
assert df is not None, "DataFrame must be provided." | ||
assert ( | ||
read_only is False | ||
), "Read-only mode is not supported for writing." | ||
|
||
# Open the DuckDB connection | ||
conn = connect_to_duckdb( | ||
database=database, read_only=read_only, config=config | ||
) | ||
|
||
if isinstance(relative_path, str): | ||
relative_path = [relative_path] | ||
|
||
try: | ||
if df is not None and df_name is not None: | ||
conn.register(df_name, df) | ||
|
||
# Execute the SQL query | ||
for path in relative_path: | ||
# Get the directory of the current file (the file where this decorator is used) | ||
path = pathlib.Path(__file__).parent.parent.joinpath(path) | ||
|
||
# Load the SQL query from the file specified by the path | ||
with open(file=path, mode="r") as sql_file: | ||
query = sql_file.read() | ||
|
||
if kwargs: | ||
query = query.format(**kwargs) | ||
|
||
if fetch_df: | ||
data = conn.execute(query).fetchdf() | ||
data.columns = data.columns.str.lower() | ||
data = data.to_dict(orient="records") | ||
|
||
else: | ||
data = conn.execute(query).fetchall() | ||
if fields is not None: | ||
data = [dict(zip(fields, row)) for row in data] | ||
except duckdb.duckdb.IOException: | ||
message = "\n--------\nDuckDB exception, too many files open.\nGet current ulimit: ulimit -n\nIncrease ulimit with `ulimit -n 4096` or more.\n--------\n" | ||
raise duckdb.duckdb.IOException(message) | ||
except Exception as error: | ||
raise ValueError( | ||
"\n{}:\n{}\n{}:\n{}".format( | ||
type(error).__name__, path, error, query | ||
) | ||
) | ||
finally: | ||
conn.close() | ||
|
||
if fetch_df: | ||
return data | ||
|
||
if data: | ||
return data | ||
|
||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .evaluation import evaluate, load_beir | ||
|
||
__all__ = ["evaluate", "load_beir"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import collections | ||
from typing import Dict | ||
|
||
__all__ = ["evaluate", "load_beir"] | ||
|
||
|
||
def load_beir(dataset_name: str, split: str = "test") -> tuple[list, list, dict]:
    """Download a BEIR dataset and reshape it for evaluation.

    Parameters
    ----------
    dataset_name
        Dataset name: scifact.
    split
        Dataset split: test.

    Returns
    -------
    tuple
        ``(documents, queries, qrels)`` where ``documents`` is a list of
        ``{"id", "title", "text"}`` dicts, ``queries`` is the list of values of
        the loaded query mapping for every judged query, and ``qrels`` maps
        each judged document id to ``{query: 1}``.

    Examples
    --------
    >>> documents, queries, qrels = load_beir("scifact", split="test")
    >>> len(documents)
    5183
    >>> len(queries)
    300
    """
    from beir import util
    from beir.datasets.data_loader import GenericDataLoader

    # The archive is downloaded and unpacked under ./evaluation_datasets/.
    data_path = util.download_and_unzip(
        url=f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip",
        out_dir="./evaluation_datasets/",
    )

    documents, queries, qrels = GenericDataLoader(data_folder=data_path).load(
        split=split
    )

    documents = [
        {
            "id": document_id,
            "title": document["title"],
            "text": document["text"],
        }
        for document_id, document in documents.items()
    ]

    # Keep only the judged queries that are present in the query mapping, so
    # the returned queries and qrels always describe the same set of queries
    # (and a judgement without a matching query cannot raise a KeyError).
    judged_query_ids = [query_id for query_id in qrels if query_id in queries]

    _queries = [queries[query_id] for query_id in judged_query_ids]

    # Inverted layout: document id -> {query: 1}. Every judged document is
    # stored with the value 1, whatever value the loaded qrels carry.
    _qrels = collections.defaultdict(dict)
    for query_id in judged_query_ids:
        query = queries[query_id]
        for document_id in qrels[query_id]:
            _qrels[document_id][query] = 1

    return (
        documents,
        _queries,
        _qrels,
    )
|
||
|
||
def evaluate( | ||
scores: list[list[dict]], | ||
qrels: dict, | ||
queries: list[str], | ||
metrics: list = [], | ||
) -> Dict[str, float]: | ||
"""Evaluate candidates matchs. | ||
Parameters | ||
---------- | ||
matchs | ||
Matchs. | ||
qrels | ||
Qrels. | ||
queries | ||
index of queries of qrels. | ||
k | ||
Number of documents to retrieve. | ||
metrics | ||
Metrics to compute. | ||
Examples | ||
-------- | ||
>>> from ducksearch import evaluation, upload, search | ||
>>> documents, queries, qrels = evaluation.load_beir("scifact", split="test") | ||
>>> upload.documents( | ||
... database="test.duckdb", | ||
... key="id", | ||
... fields=["title", "text"], | ||
... documents=documents, | ||
... ) | ||
| Table | Size | | ||
|-----------|------| | ||
| documents | 5183 | | ||
>>> upload.indexes(database="test.duckdb") | ||
| Table | Size | | ||
|----------------|------| | ||
| documents | 5183 | | ||
| bm25_documents | 5183 | | ||
>>> scores = search.documents( | ||
... database="test.duckdb", | ||
... queries=queries, | ||
... top_k=10, | ||
... ) | ||
>>> evaluation_scores = evaluation.evaluate( | ||
... scores=scores, | ||
... qrels=qrels, | ||
... queries=queries, | ||
... metrics=["ndcg@10", "hits@1", "hits@2", "hits@3", "hits@4", "hits@5", "hits@10"], | ||
... ) | ||
>>> assert evaluation_scores["ndcg@10"] > 0.64 | ||
>>> assert evaluation_scores["hits@1"] > 0.51 | ||
>>> assert evaluation_scores["hits@10"] > 0.86 | ||
""" | ||
from ranx import Qrels, Run, evaluate | ||
|
||
_qrels = collections.defaultdict(dict) | ||
for document_id, document_queries in qrels.items(): | ||
for query, score in document_queries.items(): | ||
_qrels[query][document_id] = score | ||
|
||
qrels = Qrels( | ||
qrels=_qrels, | ||
) | ||
|
||
run_dict = { | ||
query: { | ||
match["id"]: 1 - (rank / len(query_matchs)) | ||
for rank, match in enumerate(iterable=query_matchs) | ||
} | ||
for query, query_matchs in zip(queries, scores) | ||
} | ||
|
||
run = Run(run=run_dict) | ||
|
||
if not metrics: | ||
metrics = ["ndcg@10"] + [f"hits@{k}" for k in [1, 2, 3, 4, 5, 10]] | ||
|
||
return evaluate( | ||
qrels=qrels, | ||
run=run, | ||
metrics=metrics, | ||
make_comparable=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .insert import insert_documents | ||
|
||
__all__ = ["insert_documents"] |
Oops, something went wrong.