diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt
new file mode 100644
index 0000000..048c989
--- /dev/null
+++ b/CONTRIBUTORS.txt
@@ -0,0 +1,7 @@
+Individual Contributors to the BEIR Repository (BEIR contributors) include:
+1. Nandan Thakur
+2. Nils Reimers
+3. Iryna Gurevych
+4. Jimmy Lin
+5. Andreas Rücklé
+6. Abhishek Srivastava
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..cb993f3
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020-2023 Nandan Thakur
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/NOTICE.txt b/NOTICE.txt
new file mode 100644
index 0000000..55b539f
--- /dev/null
+++ b/NOTICE.txt
@@ -0,0 +1,11 @@
+-------------------------------------------------------------------------------
+Copyright since 2022
+University of Waterloo
+-------------------------------------------------------------------------------
+
+-------------------------------------------------------------------------------
+Copyright since 2020
+Ubiquitous Knowledge Processing (UKP) Lab, Technische Universität Darmstadt
+-------------------------------------------------------------------------------
+
+For individual contributors, please refer to the CONTRIBUTORS file.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..8975917
--- /dev/null
+++ b/README.md
@@ -0,0 +1,246 @@
+Paper | Installation | Quick Example | Datasets | Wiki | Hugging Face
+
+## :beers: What is it?
+
+**BEIR** is a **heterogeneous benchmark** containing diverse IR tasks. It also provides a **common and easy framework** for evaluation of your NLP-based retrieval models within the benchmark.
+
+For **an overview**, check out our **new wiki** page: [https://github.com/beir-cellar/beir/wiki](https://github.com/beir-cellar/beir/wiki).
+
+For **models and datasets**, check out our **Hugging Face (HF)** page: [https://huggingface.co/BeIR](https://huggingface.co/BeIR).
+
+For the **leaderboard**, check out our **Eval AI** page: [https://eval.ai/web/challenges/challenge-page/1897](https://eval.ai/web/challenges/challenge-page/1897).
+
+For more information, check out our publications:
+
+- [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://openreview.net/forum?id=wCu6T5xFjeJ) (NeurIPS 2021, Datasets and Benchmarks Track)
+
+## :beers: Installation
+
+Install via pip:
+
+```bash
+pip install beir
+```
+
+If you want to build from source, use:
+
+```bash
+$ git clone https://github.com/beir-cellar/beir.git
+$ cd beir
+$ pip install -e .
+```
+
+Tested with Python versions 3.6 and 3.7.
+
+## :beers: Features
+
+- Preprocess your own IR dataset or use one of the 17 already-preprocessed benchmark datasets
+- Covers a wide range of settings and diverse benchmarks, useful for both academia and industry
+- Includes well-known retrieval architectures (lexical, dense, sparse and reranking-based)
+- Add and evaluate your own model in an easy framework using different state-of-the-art evaluation metrics
+
+## :beers: Quick Example
+
+For other example codes, please refer to our **[Examples and Tutorials](https://github.com/beir-cellar/beir/wiki/Examples-and-tutorials)** Wiki page.
+
+```python
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where scifact has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Load the SBERT model and retrieve using dot-product similarity
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-tas-b"), batch_size=16)
+retriever = EvaluateRetrieval(model, score_function="dot") # or "cos_sim" for cosine similarity
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+```
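+
+The ``results`` from the dense retriever can optionally be re-ranked with a cross-encoder. Below is a minimal sketch that reuses ``corpus``, ``queries``, ``qrels`` and ``retriever`` from the example above; the ``Rerank`` wrapper ships with this repository, but its exact constructor and ``rerank()`` signature are assumed here and may differ, and the model name is only an example.
+
+```python
+from beir.reranking.models import CrossEncoder
+from beir.reranking import Rerank
+
+#### Re-rank the top-100 retrieved documents with a cross-encoder (assumed Rerank API)
+cross_encoder_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+reranker = Rerank(cross_encoder_model, batch_size=128)
+rerank_results = reranker.rerank(corpus, queries, results, top_k=100)
+
+#### Evaluate the re-ranked results with the same metrics as before
+ndcg, _map, recall, precision = retriever.evaluate(qrels, rerank_results, retriever.k_values)
+```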
+
+## :beers: Available Datasets
+
+To generate the md5 hash of a downloaded dataset, run ``md5sum filename.zip`` in a terminal.
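+
+If you prefer to verify a downloaded archive from Python instead, here is a small sketch using the standard-library ``hashlib`` module; the file path and expected hash below are placeholders taken from the SciFact row of the table:
+
+```python
+import hashlib
+
+def md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
+    # Stream the file in chunks so that large zip archives do not need to fit in memory
+    md5 = hashlib.md5()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+assert md5_of_file("datasets/scifact.zip") == "5f7d1de60b170fc8027bb7898e2efca1"
+```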
+
+You can view all datasets available **[here](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/)** or on **[HuggingFace](https://huggingface.co/BeIR)**.
+
+
+| Dataset | Website | BEIR-Name | Public? | Type | Queries | Corpus | Rel D/Q | Download | md5 |
+| -------- | ----- | --------- | ------- | ---- | ------- | ------ | ------- | :------: | :---: |
+| MSMARCO | [Homepage](https://microsoft.github.io/msmarco/) | ``msmarco`` | ✅ | ``train``, ``dev``, ``test`` | 6,980 | 8.84M | 1.1 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/msmarco.zip) | ``444067daf65d982533ea17ebd59501e4`` |
+| TREC-COVID | [Homepage](https://ir.nist.gov/covidSubmit/index.html) | ``trec-covid`` | ✅ | ``test`` | 50 | 171K | 493.5 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip) | ``ce62140cb23feb9becf6270d0d1fe6d1`` |
+| NFCorpus | [Homepage](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/) | ``nfcorpus`` | ✅ | ``train``, ``dev``, ``test`` | 323 | 3.6K | 38.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip) | ``a89dba18a62ef92f7d323ec890a0d38d`` |
+| BioASQ | [Homepage](http://bioasq.org) | ``bioasq`` | ❌ | ``train``, ``test`` | 500 | 14.91M | 8.05 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#2-bioasq) |
+| NQ | [Homepage](https://ai.google.com/research/NaturalQuestions) | ``nq`` | ✅ | ``train``, ``test`` | 3,452 | 2.68M | 1.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip) | ``d4d3d2e48787a744b6f6e691ff534307`` |
+| HotpotQA | [Homepage](https://hotpotqa.github.io) | ``hotpotqa`` | ✅ | ``train``, ``dev``, ``test`` | 7,405 | 5.23M | 2.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/hotpotqa.zip) | ``f412724f78b0d91183a0e86805e16114`` |
+| FiQA-2018 | [Homepage](https://sites.google.com/view/fiqa/) | ``fiqa`` | ✅ | ``train``, ``dev``, ``test`` | 648 | 57K | 2.6 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip) | ``17918ed23cd04fb15047f73e6c3bd9d9`` |
+| Signal-1M(RT) | [Homepage](https://research.signal-ai.com/datasets/signal1m-tweetir.html) | ``signal1m`` | ❌ | ``test`` | 97 | 2.86M | 19.6 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#4-signal-1m) |
+| TREC-NEWS | [Homepage](https://trec.nist.gov/data/news2019.html) | ``trec-news`` | ❌ | ``test`` | 57 | 595K | 19.6 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#1-trec-news) |
+| Robust04 | [Homepage](https://trec.nist.gov/data/robust/04.guidelines.html) | ``robust04`` | ❌ | ``test`` | 249 | 528K | 69.9 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#3-robust04) |
+| ArguAna | [Homepage](http://argumentation.bplaced.net/arguana/data) | ``arguana`` | ✅ | ``test`` | 1,406 | 8.67K | 1.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/arguana.zip) | ``8ad3e3c2a5867cdced806d6503f29b99`` |
+| Touche-2020 | [Homepage](https://webis.de/events/touche-20/shared-task-1.html) | ``webis-touche2020`` | ✅ | ``test`` | 49 | 382K | 19.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/webis-touche2020.zip) | ``46f650ba5a527fc69e0a6521c5a23563`` |
+| CQADupstack | [Homepage](http://nlp.cis.unimelb.edu.au/resources/cqadupstack/) | ``cqadupstack`` | ✅ | ``test`` | 13,145 | 457K | 1.4 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/cqadupstack.zip) | ``4e41456d7df8ee7760a7f866133bda78`` |
+| Quora | [Homepage](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) | ``quora`` | ✅ | ``dev``, ``test`` | 10,000 | 523K | 1.6 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip) | ``18fb154900ba42a600f84b839c173167`` |
+| DBPedia | [Homepage](https://github.com/iai-group/DBpedia-Entity/) | ``dbpedia-entity`` | ✅ | ``dev``, ``test`` | 400 | 4.63M | 38.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/dbpedia-entity.zip) | ``c2a39eb420a3164af735795df012ac2c`` |
+| SCIDOCS | [Homepage](https://allenai.org/data/scidocs) | ``scidocs`` | ✅ | ``test`` | 1,000 | 25K | 4.9 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scidocs.zip) | ``38121350fc3a4d2f48850f6aff52e4a9`` |
+| FEVER | [Homepage](http://fever.ai) | ``fever`` | ✅ | ``train``, ``dev``, ``test`` | 6,666 | 5.42M | 1.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip) | ``5a818580227bfb4b35bb6fa46d9b6c03`` |
+| Climate-FEVER | [Homepage](http://climatefever.ai) | ``climate-fever`` | ✅ | ``test`` | 1,535 | 5.42M | 3.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/climate-fever.zip) | ``8b66f0a9126c521bae2bde127b4dc99d`` |
+| SciFact | [Homepage](https://github.com/allenai/scifact) | ``scifact`` | ✅ | ``train``, ``test`` | 300 | 5K | 1.1 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip) | ``5f7d1de60b170fc8027bb7898e2efca1`` |
+
+
+## :beers: Additional Information
+
+We also provide a variety of additional information on our **[Wiki](https://github.com/beir-cellar/beir/wiki)** page.
+Please refer to these pages for the following:
+
+
+### Quick Start
+
+- [Installing BEIR](https://github.com/beir-cellar/beir/wiki/Installing-beir)
+- [Examples and Tutorials](https://github.com/beir-cellar/beir/wiki/Examples-and-tutorials)
+
+### Datasets
+
+- [Datasets Available](https://github.com/beir-cellar/beir/wiki/Datasets-available)
+- [Multilingual Datasets](https://github.com/beir-cellar/beir/wiki/Multilingual-datasets)
+- [Load your Custom Dataset](https://github.com/beir-cellar/beir/wiki/Load-your-custom-dataset)
+
+### Models
+- [Models Available](https://github.com/beir-cellar/beir/wiki/Models-available)
+- [Evaluate your Custom Model](https://github.com/beir-cellar/beir/wiki/Evaluate-your-custom-model)
+
+### Metrics
+
+- [Metrics Available](https://github.com/beir-cellar/beir/wiki/Metrics-available)
+
+### Miscellaneous
+
+- [BEIR Leaderboard](https://github.com/beir-cellar/beir/wiki/Leaderboard)
+- [Course Material on IR](https://github.com/beir-cellar/beir/wiki/Course-material-on-ir)
+
+## :beers: Disclaimer
+
+Similar to TensorFlow [datasets](https://github.com/tensorflow/datasets) or Hugging Face's [datasets](https://github.com/huggingface/datasets) library, we have only downloaded and prepared these public datasets. We distribute them in a specific format, but we do not vouch for their quality or fairness, nor claim that you have a license to use them. It remains your responsibility to determine whether you have permission to use each dataset under its license, and to cite the rightful owner of the dataset.
+
+If you're a dataset owner and wish to update any part of it, or do not want your dataset to be included in this library, feel free to post an issue here or make a pull request!
+
+If you're a dataset owner and wish to include your dataset or model in this library, feel free to post an issue here or make a pull request!
+
+## :beers: Citing & Authors
+
+If you find this repository helpful, feel free to cite our publication [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663):
+
+```
+@inproceedings{
+ thakur2021beir,
+ title={{BEIR}: A Heterogeneous Benchmark for Zero-shot Evaluation of Information Retrieval Models},
+ author={Nandan Thakur and Nils Reimers and Andreas R{\"u}ckl{\'e} and Abhishek Srivastava and Iryna Gurevych},
+ booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2)},
+ year={2021},
+ url={https://openreview.net/forum?id=wCu6T5xFjeJ}
+}
+```
+
+The main contributors of this repository are:
+- [Nandan Thakur](https://github.com/Nthakur20), Personal Website: [nandan-thakur.com](https://nandan-thakur.com)
+
+Contact person: Nandan Thakur, [nandant@gmail.com](mailto:nandant@gmail.com)
+
+Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions.
+
+> This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.
+
+## :beers: Collaboration
+
+The BEIR Benchmark has been made possible due to a collaborative effort of the following universities and organizations:
+- [UKP Lab, Technical University of Darmstadt](http://www.ukp.tu-darmstadt.de/)
+- [University of Waterloo](https://uwaterloo.ca/)
+- [HuggingFace](https://huggingface.co/)
+
+## :beers: Contributors
+
+Thanks go to all the wonderful collaborators for their contributions towards the BEIR benchmark.
diff --git a/beir/__init__.py b/beir/__init__.py
new file mode 100644
index 0000000..6658535
--- /dev/null
+++ b/beir/__init__.py
@@ -0,0 +1 @@
+from .logging import LoggingHandler
\ No newline at end of file
diff --git a/beir/datasets/__init__.py b/beir/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/beir/datasets/data_loader.py b/beir/datasets/data_loader.py
new file mode 100644
index 0000000..cb057ed
--- /dev/null
+++ b/beir/datasets/data_loader.py
@@ -0,0 +1,126 @@
+from typing import Dict, Tuple
+from tqdm.autonotebook import tqdm
+import json
+import os
+import logging
+import csv
+
+logger = logging.getLogger(__name__)
+
+class GenericDataLoader:
+
+ def __init__(self, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
+ qrels_folder: str = "qrels", qrels_file: str = ""):
+ self.corpus = {}
+ self.queries = {}
+ self.qrels = {}
+
+ if prefix:
+ query_file = prefix + "-" + query_file
+ qrels_folder = prefix + "-" + qrels_folder
+
+ self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
+ self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
+ self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
+ self.qrels_file = qrels_file
+
+ @staticmethod
+ def check(fIn: str, ext: str):
+ if not os.path.exists(fIn):
+            raise ValueError("File {} not present! Please provide an accurate file path.".format(fIn))
+
+        if not fIn.endswith(ext):
+            raise ValueError("File {} must have the extension {}".format(fIn, ext))
+
+ def load_custom(self) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
+
+ self.check(fIn=self.corpus_file, ext="jsonl")
+ self.check(fIn=self.query_file, ext="jsonl")
+ self.check(fIn=self.qrels_file, ext="tsv")
+
+ if not len(self.corpus):
+ logger.info("Loading Corpus...")
+ self._load_corpus()
+ logger.info("Loaded %d Documents.", len(self.corpus))
+ logger.info("Doc Example: %s", list(self.corpus.values())[0])
+
+ if not len(self.queries):
+ logger.info("Loading Queries...")
+ self._load_queries()
+
+ if os.path.exists(self.qrels_file):
+ self._load_qrels()
+ self.queries = {qid: self.queries[qid] for qid in self.qrels}
+ logger.info("Loaded %d Queries.", len(self.queries))
+ logger.info("Query Example: %s", list(self.queries.values())[0])
+
+ return self.corpus, self.queries, self.qrels
+
+ def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
+
+ self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
+ self.check(fIn=self.corpus_file, ext="jsonl")
+ self.check(fIn=self.query_file, ext="jsonl")
+ self.check(fIn=self.qrels_file, ext="tsv")
+
+ if not len(self.corpus):
+ logger.info("Loading Corpus...")
+ self._load_corpus()
+ logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
+ logger.info("Doc Example: %s", list(self.corpus.values())[0])
+
+ if not len(self.queries):
+ logger.info("Loading Queries...")
+ self._load_queries()
+
+ if os.path.exists(self.qrels_file):
+ self._load_qrels()
+ self.queries = {qid: self.queries[qid] for qid in self.qrels}
+ logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
+ logger.info("Query Example: %s", list(self.queries.values())[0])
+
+ return self.corpus, self.queries, self.qrels
+
+ def load_corpus(self) -> Dict[str, Dict[str, str]]:
+
+ self.check(fIn=self.corpus_file, ext="jsonl")
+
+ if not len(self.corpus):
+ logger.info("Loading Corpus...")
+ self._load_corpus()
+ logger.info("Loaded %d Documents.", len(self.corpus))
+ logger.info("Doc Example: %s", list(self.corpus.values())[0])
+
+ return self.corpus
+
+ def _load_corpus(self):
+
+ num_lines = sum(1 for i in open(self.corpus_file, 'rb'))
+ with open(self.corpus_file, encoding='utf8') as fIn:
+ for line in tqdm(fIn, total=num_lines):
+ line = json.loads(line)
+ self.corpus[line.get("_id")] = {
+ "text": line.get("text"),
+ "title": line.get("title"),
+ }
+
+ def _load_queries(self):
+
+ with open(self.query_file, encoding='utf8') as fIn:
+ for line in fIn:
+ line = json.loads(line)
+ self.queries[line.get("_id")] = line.get("text")
+
+ def _load_qrels(self):
+
+ reader = csv.reader(open(self.qrels_file, encoding="utf-8"),
+ delimiter="\t", quoting=csv.QUOTE_MINIMAL)
+ next(reader)
+
+ for id, row in enumerate(reader):
+ query_id, corpus_id, score = row[0], row[1], int(row[2])
+
+ if query_id not in self.qrels:
+ self.qrels[query_id] = {corpus_id: score}
+ else:
+ self.qrels[query_id][corpus_id] = score
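+
+
+# Example usage (a sketch, kept as a comment): loading a custom dataset with explicit file
+# paths, as accepted by the constructor and checked by load_custom() above. The paths are
+# hypothetical placeholders; corpus/queries must be JSON Lines files and qrels a TSV file.
+#
+#   from beir.datasets.data_loader import GenericDataLoader
+#
+#   corpus, queries, qrels = GenericDataLoader(
+#       corpus_file="my-dataset/corpus.jsonl",
+#       query_file="my-dataset/queries.jsonl",
+#       qrels_file="my-dataset/qrels/test.tsv",
+#   ).load_custom()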
\ No newline at end of file
diff --git a/beir/datasets/data_loader_hf.py b/beir/datasets/data_loader_hf.py
new file mode 100644
index 0000000..33b651f
--- /dev/null
+++ b/beir/datasets/data_loader_hf.py
@@ -0,0 +1,118 @@
+from collections import defaultdict
+from typing import Dict, Tuple
+import os
+import logging
+from datasets import load_dataset, Value, Features
+
+logger = logging.getLogger(__name__)
+
+
+class HFDataLoader:
+
+ def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
+ qrels_folder: str = "qrels", qrels_file: str = "", streaming: bool = False, keep_in_memory: bool = False):
+ self.corpus = {}
+ self.queries = {}
+ self.qrels = {}
+ self.hf_repo = hf_repo
+ if hf_repo:
+            logger.warning("A huggingface repository is provided. This will override the data_folder, prefix and *_file arguments.")
+ self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo + "-qrels"
+ else:
+ # data folder would contain these files:
+ # (1) fiqa/corpus.jsonl (format: jsonlines)
+ # (2) fiqa/queries.jsonl (format: jsonlines)
+ # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
+ if prefix:
+ query_file = prefix + "-" + query_file
+ qrels_folder = prefix + "-" + qrels_folder
+
+ self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
+ self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
+ self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
+ self.qrels_file = qrels_file
+ self.streaming = streaming
+ self.keep_in_memory = keep_in_memory
+
+ @staticmethod
+ def check(fIn: str, ext: str):
+ if not os.path.exists(fIn):
+            raise ValueError("File {} not present! Please provide an accurate file path.".format(fIn))
+
+        if not fIn.endswith(ext):
+            raise ValueError("File {} must have the extension {}".format(fIn, ext))
+
+ def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
+
+ if not self.hf_repo:
+ self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
+ self.check(fIn=self.corpus_file, ext="jsonl")
+ self.check(fIn=self.query_file, ext="jsonl")
+ self.check(fIn=self.qrels_file, ext="tsv")
+
+ if not len(self.corpus):
+ logger.info("Loading Corpus...")
+ self._load_corpus()
+ logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
+ logger.info("Doc Example: %s", self.corpus[0])
+
+ if not len(self.queries):
+ logger.info("Loading Queries...")
+ self._load_queries()
+
+ self._load_qrels(split)
+ # filter queries with no qrels
+ qrels_dict = defaultdict(dict)
+
+ def qrels_dict_init(row):
+ qrels_dict[row['query-id']][row['corpus-id']] = int(row['score'])
+ self.qrels.map(qrels_dict_init)
+ self.qrels = qrels_dict
+ self.queries = self.queries.filter(lambda x: x['id'] in self.qrels)
+ logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
+ logger.info("Query Example: %s", self.queries[0])
+
+ return self.corpus, self.queries, self.qrels
+
+ def load_corpus(self) -> Dict[str, Dict[str, str]]:
+ if not self.hf_repo:
+ self.check(fIn=self.corpus_file, ext="jsonl")
+
+ if not len(self.corpus):
+ logger.info("Loading Corpus...")
+ self._load_corpus()
+            logger.info("Loaded %d Documents.", len(self.corpus))
+ logger.info("Doc Example: %s", self.corpus[0])
+
+ return self.corpus
+
+ def _load_corpus(self):
+ if self.hf_repo:
+ corpus_ds = load_dataset(self.hf_repo, 'corpus', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
+ else:
+ corpus_ds = load_dataset('json', data_files=self.corpus_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
+ corpus_ds = next(iter(corpus_ds.values())) # get first split
+ corpus_ds = corpus_ds.cast_column('_id', Value('string'))
+ corpus_ds = corpus_ds.rename_column('_id', 'id')
+ corpus_ds = corpus_ds.remove_columns([col for col in corpus_ds.column_names if col not in ['id', 'text', 'title']])
+ self.corpus = corpus_ds
+
+ def _load_queries(self):
+ if self.hf_repo:
+ queries_ds = load_dataset(self.hf_repo, 'queries', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
+ else:
+ queries_ds = load_dataset('json', data_files=self.query_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
+ queries_ds = next(iter(queries_ds.values())) # get first split
+ queries_ds = queries_ds.cast_column('_id', Value('string'))
+ queries_ds = queries_ds.rename_column('_id', 'id')
+ queries_ds = queries_ds.remove_columns([col for col in queries_ds.column_names if col not in ['id', 'text']])
+ self.queries = queries_ds
+
+ def _load_qrels(self, split):
+ if self.hf_repo:
+ qrels_ds = load_dataset(self.hf_repo_qrels, keep_in_memory=self.keep_in_memory, streaming=self.streaming)[split]
+ else:
+ qrels_ds = load_dataset('csv', data_files=self.qrels_file, delimiter='\t', keep_in_memory=self.keep_in_memory)
+ features = Features({'query-id': Value('string'), 'corpus-id': Value('string'), 'score': Value('float')})
+ qrels_ds = qrels_ds.cast(features)
+ self.qrels = qrels_ds
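+
+
+# Example usage (a sketch, kept as a comment): loading a BEIR dataset directly from the
+# Hugging Face Hub. "BeIR/scifact" is used only for illustration; the matching qrels
+# repository ("BeIR/scifact-qrels") is inferred by the constructor above.
+#
+#   from beir.datasets.data_loader_hf import HFDataLoader
+#
+#   corpus, queries, qrels = HFDataLoader(hf_repo="BeIR/scifact",
+#                                         keep_in_memory=False).load(split="test")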
\ No newline at end of file
diff --git a/beir/generation/__init__.py b/beir/generation/__init__.py
new file mode 100644
index 0000000..7e3c202
--- /dev/null
+++ b/beir/generation/__init__.py
@@ -0,0 +1 @@
+from .generate import QueryGenerator, PassageExpansion
\ No newline at end of file
diff --git a/beir/generation/generate.py b/beir/generation/generate.py
new file mode 100644
index 0000000..98206ff
--- /dev/null
+++ b/beir/generation/generate.py
@@ -0,0 +1,185 @@
+from tqdm.autonotebook import trange
+from ..util import write_to_json, write_to_tsv
+from typing import Dict
+import logging, os
+
+logger = logging.getLogger(__name__)
+
+class PassageExpansion:
+ def __init__(self, model, **kwargs):
+ self.model = model
+ self.corpus_exp = {}
+
+ @staticmethod
+ def save(output_dir: str, corpus: Dict[str, str], prefix: str):
+ os.makedirs(output_dir, exist_ok=True)
+
+ corpus_file = os.path.join(output_dir, prefix + "-corpus.jsonl")
+
+ logger.info("Saving expanded passages to {}".format(corpus_file))
+ write_to_json(output_file=corpus_file, data=corpus)
+
+ def expand(self,
+ corpus: Dict[str, Dict[str, str]],
+ output_dir: str,
+ top_k: int = 200,
+ max_length: int = 350,
+ prefix: str = "gen",
+ batch_size: int = 32,
+ sep: str = " "):
+
+ logger.info("Starting to expand Passages with {} tokens chosen...".format(top_k))
+ logger.info("Params: top_k = {}".format(top_k))
+ logger.info("Params: passage max_length = {}".format(max_length))
+ logger.info("Params: batch size = {}".format(batch_size))
+
+ corpus_ids = list(corpus.keys())
+ corpus_list = [corpus[doc_id] for doc_id in corpus_ids]
+
+ for start_idx in trange(0, len(corpus_list), batch_size, desc='pas'):
+ expansions = self.model.generate(
+ corpus=corpus_list[start_idx:start_idx + batch_size],
+ max_length=max_length,
+ top_k=top_k)
+
+ for idx in range(len(expansions)):
+ doc_id = corpus_ids[start_idx + idx]
+ self.corpus_exp[doc_id] = {
+ "title": corpus[doc_id]["title"],
+ "text": corpus[doc_id]["text"] + sep + expansions[idx],
+ }
+
+        # Finally, save all the expanded passages
+ logger.info("Saving {} Expanded Passages...".format(len(self.corpus_exp)))
+ self.save(output_dir, self.corpus_exp, prefix)
+
+
+class QueryGenerator:
+ def __init__(self, model, **kwargs):
+ self.model = model
+ self.qrels = {}
+ self.queries = {}
+
+ @staticmethod
+ def save(output_dir: str, queries: Dict[str, str], qrels: Dict[str, Dict[str, int]], prefix: str):
+
+ os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(os.path.join(output_dir, prefix + "-qrels"), exist_ok=True)
+
+ query_file = os.path.join(output_dir, prefix + "-queries.jsonl")
+ qrels_file = os.path.join(output_dir, prefix + "-qrels", "train.tsv")
+
+ logger.info("Saving Generated Queries to {}".format(query_file))
+ write_to_json(output_file=query_file, data=queries)
+
+ logger.info("Saving Generated Qrels to {}".format(qrels_file))
+ write_to_tsv(output_file=qrels_file, data=qrels)
+
+ def generate(self,
+ corpus: Dict[str, Dict[str, str]],
+ output_dir: str,
+                 top_p: float = 0.95,
+ top_k: int = 25,
+ max_length: int = 64,
+ ques_per_passage: int = 1,
+ prefix: str = "gen",
+ batch_size: int = 32,
+ save: bool = True,
+ save_after: int = 100000):
+
+ logger.info("Starting to Generate {} Questions Per Passage using top-p (nucleus) sampling...".format(ques_per_passage))
+ logger.info("Params: top_p = {}".format(top_p))
+ logger.info("Params: top_k = {}".format(top_k))
+ logger.info("Params: max_length = {}".format(max_length))
+ logger.info("Params: ques_per_passage = {}".format(ques_per_passage))
+ logger.info("Params: batch size = {}".format(batch_size))
+
+ count = 0
+ corpus_ids = list(corpus.keys())
+ corpus = [corpus[doc_id] for doc_id in corpus_ids]
+
+ for start_idx in trange(0, len(corpus), batch_size, desc='pas'):
+
+ size = len(corpus[start_idx:start_idx + batch_size])
+ queries = self.model.generate(
+ corpus=corpus[start_idx:start_idx + batch_size],
+ ques_per_passage=ques_per_passage,
+ max_length=max_length,
+ top_p=top_p,
+ top_k=top_k
+ )
+
+ assert len(queries) == size * ques_per_passage
+
+ for idx in range(size):
+ # Saving generated questions after every "save_after" corpus ids
+ if (len(self.queries) % save_after == 0 and len(self.queries) >= save_after):
+ logger.info("Saving {} Generated Queries...".format(len(self.queries)))
+ self.save(output_dir, self.queries, self.qrels, prefix)
+
+ corpus_id = corpus_ids[start_idx + idx]
+ start_id = idx * ques_per_passage
+ end_id = start_id + ques_per_passage
+ query_set = set([q.strip() for q in queries[start_id:end_id]])
+
+ for query in query_set:
+ count += 1
+ query_id = "genQ" + str(count)
+ self.queries[query_id] = query
+ self.qrels[query_id] = {corpus_id: 1}
+
+        # Finally, save all the generated queries
+ logger.info("Saving {} Generated Queries...".format(len(self.queries)))
+ self.save(output_dir, self.queries, self.qrels, prefix)
+
+ def generate_multi_process(self,
+ corpus: Dict[str, Dict[str, str]],
+ pool: Dict[str, object],
+ output_dir: str,
+                               top_p: float = 0.95,
+ top_k: int = 25,
+ max_length: int = 64,
+ ques_per_passage: int = 1,
+ prefix: str = "gen",
+ batch_size: int = 32,
+ chunk_size: int = None):
+
+ logger.info("Starting to Generate {} Questions Per Passage using top-p (nucleus) sampling...".format(ques_per_passage))
+ logger.info("Params: top_p = {}".format(top_p))
+ logger.info("Params: top_k = {}".format(top_k))
+ logger.info("Params: max_length = {}".format(max_length))
+ logger.info("Params: ques_per_passage = {}".format(ques_per_passage))
+ logger.info("Params: batch size = {}".format(batch_size))
+
+ count = 0
+ corpus_ids = list(corpus.keys())
+ corpus = [corpus[doc_id] for doc_id in corpus_ids]
+
+ queries = self.model.generate_multi_process(
+ corpus=corpus,
+ pool=pool,
+ ques_per_passage=ques_per_passage,
+ max_length=max_length,
+ top_p=top_p,
+ top_k=top_k,
+ chunk_size=chunk_size,
+ batch_size=batch_size,
+ )
+
+ assert len(queries) == len(corpus) * ques_per_passage
+
+ for idx in range(len(corpus)):
+ corpus_id = corpus_ids[idx]
+ start_id = idx * ques_per_passage
+ end_id = start_id + ques_per_passage
+ query_set = set([q.strip() for q in queries[start_id:end_id]])
+
+ for query in query_set:
+ count += 1
+ query_id = "genQ" + str(count)
+ self.queries[query_id] = query
+ self.qrels[query_id] = {corpus_id: 1}
+
+        # Finally, save all the generated queries
+ logger.info("Saving {} Generated Queries...".format(len(self.queries)))
+ self.save(output_dir, self.queries, self.qrels, prefix)
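+
+
+# Example usage (a sketch, kept as a comment): generating synthetic queries for a corpus
+# with a pre-trained T5 query generator. The checkpoint name and output directory are
+# illustrative only.
+#
+#   from beir.generation import QueryGenerator
+#   from beir.generation.models import QGenModel
+#
+#   generator = QueryGenerator(model=QGenModel("BeIR/query-gen-msmarco-t5-base-v1"))
+#   generator.generate(corpus, output_dir="scifact-gen-queries", ques_per_passage=3,
+#                      prefix="gen-3", batch_size=64)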
\ No newline at end of file
diff --git a/beir/generation/models/__init__.py b/beir/generation/models/__init__.py
new file mode 100644
index 0000000..6367259
--- /dev/null
+++ b/beir/generation/models/__init__.py
@@ -0,0 +1,2 @@
+from .auto_model import QGenModel
+from .tilde import TILDE
\ No newline at end of file
diff --git a/beir/generation/models/auto_model.py b/beir/generation/models/auto_model.py
new file mode 100644
index 0000000..a6ae650
--- /dev/null
+++ b/beir/generation/models/auto_model.py
@@ -0,0 +1,161 @@
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from tqdm.autonotebook import trange
+import torch, logging, math, queue
+import torch.multiprocessing as mp
+from typing import List, Dict
+
+logger = logging.getLogger(__name__)
+
+
+class QGenModel:
+ def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bool = True, device: str = None, **kwargs):
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+ self.gen_prefix = gen_prefix
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+ logger.info("Use pytorch device: {}".format(self.device))
+ self.model = self.model.to(self.device)
+
+ def generate(self, corpus: List[Dict[str, str]], ques_per_passage: int, top_k: int, max_length: int, top_p: float = None, temperature: float = None) -> List[str]:
+
+ texts = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]
+ encodings = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+
+ # Top-p nucleus sampling
+ # https://huggingface.co/blog/how-to-generate
+ with torch.no_grad():
+ if not temperature:
+ outs = self.model.generate(
+ input_ids=encodings['input_ids'].to(self.device),
+ do_sample=True,
+ max_length=max_length, # 64
+ top_k=top_k, # 25
+ top_p=top_p, # 0.95
+ num_return_sequences=ques_per_passage # 1
+ )
+ else:
+ outs = self.model.generate(
+ input_ids=encodings['input_ids'].to(self.device),
+ do_sample=True,
+ max_length=max_length, # 64
+ top_k=top_k, # 25
+ temperature=temperature,
+ num_return_sequences=ques_per_passage # 1
+ )
+
+ return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
+
+ def start_multi_process_pool(self, target_devices: List[str] = None):
+        """
+        Starts a multi-process pool to run generation with several independent processes.
+        This method is recommended if you want to generate on multiple GPUs. It is advised
+        to start only one process per GPU. This method works together with generate_multi_process.
+        :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used
+        :return: Returns a dict with the target processes, an input queue and an output queue.
+        """
+ if target_devices is None:
+ if torch.cuda.is_available():
+ target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
+ else:
+                logger.info("CUDA is not available. Starting 4 CPU workers")
+ target_devices = ['cpu']*4
+
+ logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
+
+ ctx = mp.get_context('spawn')
+ input_queue = ctx.Queue()
+ output_queue = ctx.Queue()
+ processes = []
+
+ for cuda_id in target_devices:
+ p = ctx.Process(target=QGenModel._generate_multi_process_worker, args=(cuda_id, self.model, self.tokenizer, input_queue, output_queue), daemon=True)
+ p.start()
+ processes.append(p)
+
+ return {'input': input_queue, 'output': output_queue, 'processes': processes}
+
+ @staticmethod
+ def stop_multi_process_pool(pool):
+ """
+ Stops all processes started with start_multi_process_pool
+ """
+ for p in pool['processes']:
+ p.terminate()
+
+ for p in pool['processes']:
+ p.join()
+ p.close()
+
+ pool['input'].close()
+ pool['output'].close()
+
+ @staticmethod
+ def _generate_multi_process_worker(target_device: str, model, tokenizer, input_queue, results_queue):
+ """
+ Internal working process to generate questions in multi-process setup
+ """
+ while True:
+ try:
+ id, batch_size, texts, ques_per_passage, top_p, top_k, max_length = input_queue.get()
+ model = model.to(target_device)
+ generated_texts = []
+
+ for start_idx in trange(0, len(texts), batch_size, desc='{}'.format(target_device)):
+ texts_batch = texts[start_idx:start_idx + batch_size]
+ encodings = tokenizer(texts_batch, padding=True, truncation=True, return_tensors="pt")
+ with torch.no_grad():
+ outs = model.generate(
+ input_ids=encodings['input_ids'].to(target_device),
+ do_sample=True,
+ max_length=max_length, # 64
+ top_k=top_k, # 25
+ top_p=top_p, # 0.95
+ num_return_sequences=ques_per_passage # 1
+ )
+ generated_texts += tokenizer.batch_decode(outs, skip_special_tokens=True)
+
+ results_queue.put([id, generated_texts])
+ except queue.Empty:
+ break
+
+    def generate_multi_process(self, corpus: List[Dict[str, str]], ques_per_passage: int, top_p: float, top_k: int, max_length: int,
+ pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
+        """
+        This method runs question generation on multiple GPUs. The corpus is chunked into smaller packages
+        and sent to individual processes, which generate questions on the different GPUs. This method is only
+        suitable for generating questions for large corpora.
+        :param corpus: List of documents (each a dict with "title" and "text")
+        :param pool: A pool of workers started with QGenModel.start_multi_process_pool
+        :param batch_size: Generate questions with this batch size
+        :param chunk_size: Documents are chunked and sent to the individual processes. If None, a sensible size is determined.
+        :return: List of generated questions, ques_per_passage per input document
+        """
+
+ texts = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]
+
+ if chunk_size is None:
+ chunk_size = min(math.ceil(len(texts) / len(pool["processes"]) / 10), 5000)
+
+ logger.info("Chunk data into packages of size {}".format(chunk_size))
+
+ input_queue = pool['input']
+ last_chunk_id = 0
+ chunk = []
+
+ for doc_text in texts:
+ chunk.append(doc_text)
+ if len(chunk) >= chunk_size:
+ input_queue.put([last_chunk_id, batch_size, chunk, ques_per_passage, top_p, top_k, max_length])
+ last_chunk_id += 1
+ chunk = []
+
+ if len(chunk) > 0:
+ input_queue.put([last_chunk_id, batch_size, chunk, ques_per_passage, top_p, top_k, max_length])
+ last_chunk_id += 1
+
+ output_queue = pool['output']
+
+ results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
+ queries = [result[1] for result in results_list]
+
+ return [item for sublist in queries for item in sublist]
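+
+
+# Example usage (a sketch, kept as a comment): question generation across multiple GPUs.
+# The pool is created once, handed to QueryGenerator.generate_multi_process(...), and
+# stopped afterwards. Checkpoint name and output directory are illustrative only.
+#
+#   from beir.generation import QueryGenerator
+#   from beir.generation.models import QGenModel
+#
+#   model = QGenModel("BeIR/query-gen-msmarco-t5-base-v1")
+#   pool = model.start_multi_process_pool()  # one process per available GPU
+#   generator = QueryGenerator(model=model)
+#   generator.generate_multi_process(corpus, pool, output_dir="gen-queries",
+#                                    ques_per_passage=3, batch_size=64)
+#   QGenModel.stop_multi_process_pool(pool)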
\ No newline at end of file
diff --git a/beir/generation/models/tilde.py b/beir/generation/models/tilde.py
new file mode 100644
index 0000000..7150d25
--- /dev/null
+++ b/beir/generation/models/tilde.py
@@ -0,0 +1,77 @@
+from transformers import BertLMHeadModel, BertTokenizer, DataCollatorWithPadding
+from tqdm.autonotebook import trange
+import torch, logging, math, queue
+import torch.multiprocessing as mp
+from typing import List, Dict
+from nltk.corpus import stopwords
+import numpy as np
+import re
+
+logger = logging.getLogger(__name__)
+
+class TILDE:
+ def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bool = True, device: str = None, **kwargs):
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast=use_fast)
+ self.model = BertLMHeadModel.from_pretrained(model_path)
+ self.gen_prefix = gen_prefix
+ _, self.bad_ids = self._clean_vocab(self.tokenizer)
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+ logger.info("Use pytorch device: {}".format(self.device))
+ self.model = self.model.to(self.device)
+
+ def _clean_vocab(self, tokenizer, do_stopwords=True):
+ if do_stopwords:
+ stop_words = set(stopwords.words('english'))
+ # keep some common words in ms marco questions
+ # stop_words.difference_update(["where", "how", "what", "when", "which", "why", "who"])
+ stop_words.add("definition")
+
+ vocab = tokenizer.get_vocab()
+ tokens = vocab.keys()
+
+ good_ids = []
+ bad_ids = []
+
+ for stop_word in stop_words:
+ ids = tokenizer(stop_word, add_special_tokens=False)["input_ids"]
+ if len(ids) == 1:
+ bad_ids.append(ids[0])
+
+ for token in tokens:
+ token_id = vocab[token]
+ if token_id in bad_ids:
+ continue
+
+ if token[0] == '#' and len(token) > 1:
+ good_ids.append(token_id)
+ else:
+ if not re.match("^[A-Za-z0-9_-]*$", token):
+ bad_ids.append(token_id)
+ else:
+ good_ids.append(token_id)
+ bad_ids.append(2015) # add ##s to stopwords
+ return good_ids, bad_ids
+
+ def generate(self, corpus: List[Dict[str, str]], top_k: int, max_length: int) -> List[str]:
+
+ expansions = []
+ texts_batch = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]
+ encode_texts = np.array(self.tokenizer.batch_encode_plus(
+ texts_batch,
+ max_length=max_length,
+ truncation='only_first',
+ return_attention_mask=False,
+ padding='max_length')['input_ids'])
+
+ encode_texts[:,0] = 1
+ encoded_texts_gpu = torch.tensor(encode_texts).to(self.device)
+
+ with torch.no_grad():
+ logits = self.model(encoded_texts_gpu, return_dict=True).logits[:, 0]
+ batch_selected = torch.topk(logits, top_k).indices.cpu().numpy()
+
+ for idx, selected in enumerate(batch_selected):
+ expand_term_ids = np.setdiff1d(np.setdiff1d(selected, encode_texts[idx], assume_unique=True), self.bad_ids, assume_unique=True)
+ expansions.append(self.tokenizer.decode(expand_term_ids))
+
+ return expansions
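+
+
+# Example usage (a sketch, kept as a comment): expanding every passage of a corpus with
+# TILDE via the PassageExpansion helper from beir.generation. The "ielab/TILDE" checkpoint
+# is illustrative; NLTK stopwords must be downloaded once (nltk.download("stopwords"))
+# before instantiating the model.
+#
+#   from beir.generation import PassageExpansion
+#   from beir.generation.models import TILDE
+#
+#   expander = PassageExpansion(model=TILDE("ielab/TILDE"))
+#   expander.expand(corpus, output_dir="scifact-tilde", top_k=200, max_length=350)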
\ No newline at end of file
diff --git a/beir/logging.py b/beir/logging.py
new file mode 100644
index 0000000..59d53be
--- /dev/null
+++ b/beir/logging.py
@@ -0,0 +1,16 @@
+import logging
+import tqdm
+
+class LoggingHandler(logging.Handler):
+ def __init__(self, level=logging.NOTSET):
+ super().__init__(level)
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ tqdm.tqdm.write(msg)
+ self.flush()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ self.handleError(record)
\ No newline at end of file
diff --git a/beir/losses/__init__.py b/beir/losses/__init__.py
new file mode 100644
index 0000000..47ec4b8
--- /dev/null
+++ b/beir/losses/__init__.py
@@ -0,0 +1,2 @@
+from .bpr_loss import BPRLoss
+from .margin_mse_loss import MarginMSELoss
\ No newline at end of file
diff --git a/beir/losses/bpr_loss.py b/beir/losses/bpr_loss.py
new file mode 100644
index 0000000..22c8049
--- /dev/null
+++ b/beir/losses/bpr_loss.py
@@ -0,0 +1,74 @@
+import math
+import torch
+from typing import Iterable, Dict
+from sentence_transformers import SentenceTransformer, util
+
+class BPRLoss(torch.nn.Module):
+ """
+ This loss expects as input a batch consisting of sentence triplets (a_1, p_1, n_1), (a_2, p_2, n_2)..., (a_n, p_n, n_n)
+ where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
+    You can also provide one or multiple hard negatives (n_1, n_2, ..) per anchor-positive pair by structuring the data like this.
+
+ We define the loss function as defined in ACL2021: Efficient Passage Retrieval with Hashing for Open-domain Question Answering.
+ For more information: https://arxiv.org/abs/2106.00882
+
+ Parts of the code has been reused from the source code of BPR (Binary Passage Retriever): https://github.com/studio-ousia/bpr.
+
+ We combine two losses for training a binary code based retriever model =>
+ 1. Margin Ranking Loss: https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html
+ 2. Cross Entropy Loss (or Multiple Negatives Ranking Loss): https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+
+ """
+ def __init__(self, model: SentenceTransformer, scale: float = 1.0, similarity_fct = util.dot_score, binary_ranking_loss_margin: float = 2.0, hashnet_gamma: float = 0.1):
+ """
+ :param model: SentenceTransformer model
+ :param scale: Output of similarity function is multiplied by scale value
+ :param similarity_fct: similarity function between sentence embeddings. By default, dot_score. Can also be set to cosine similarity.
+ :param binary_ranking_loss_margin: margin used for binary loss. By default original authors found enhanced performance = 2.0, (Appendix D, https://arxiv.org/abs/2106.00882).
+ :param hashnet_gamma: hashnet gamma function used for scaling tanh function. By default original authors found enhanced performance = 0.1, (Appendix B, https://arxiv.org/abs/2106.00882).
+ """
+ super(BPRLoss, self).__init__()
+ self.global_step = 0
+ self.model = model
+ self.scale = scale
+ self.similarity_fct = similarity_fct
+ self.hashnet_gamma = hashnet_gamma
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+ self.margin_ranking_loss = torch.nn.MarginRankingLoss(margin=binary_ranking_loss_margin)
+
+ def convert_to_binary(self, input_repr: torch.Tensor) -> torch.Tensor:
+ """
+        The paper uses the tanh function as an approximation of the sign function, because the sign function is incompatible with backpropagation.
+ """
+ scale = math.pow((1.0 + self.global_step * self.hashnet_gamma), 0.5)
+ return torch.tanh(input_repr * scale)
+
+ def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
+
+ reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
+ embeddings_a = reps[0]
+ embeddings_b = torch.cat([self.convert_to_binary(rep) for rep in reps[1:]])
+
+ # Dense Loss (or Multiple Negatives Ranking Loss)
+ # Used to learn the encoder model
+ scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device) # Example a[i] should match with b[i]
+ dense_loss = self.cross_entropy_loss(scores, labels)
+
+ # Binary Loss (or Margin Ranking Loss)
+        # Used to learn the binary code-based model
+ binary_query_repr = self.convert_to_binary(embeddings_a)
+ binary_query_scores = torch.matmul(binary_query_repr, embeddings_b.transpose(0, 1))
+ pos_mask = binary_query_scores.new_zeros(binary_query_scores.size(), dtype=torch.bool)
+ for n, label in enumerate(labels):
+ pos_mask[n, label] = True
+ pos_bin_scores = torch.masked_select(binary_query_scores, pos_mask)
+ pos_bin_scores = pos_bin_scores.repeat_interleave(embeddings_b.size(0) - 1)
+ neg_bin_scores = torch.masked_select(binary_query_scores, torch.logical_not(pos_mask))
+ bin_labels = pos_bin_scores.new_ones(pos_bin_scores.size(), dtype=torch.int64)
+ binary_loss = self.margin_ranking_loss(
+ pos_bin_scores, neg_bin_scores, bin_labels)
+
+ self.global_step += 1
+
+ return dense_loss + binary_loss
diff --git a/beir/losses/margin_mse_loss.py b/beir/losses/margin_mse_loss.py
new file mode 100644
index 0000000..99f5159
--- /dev/null
+++ b/beir/losses/margin_mse_loss.py
@@ -0,0 +1,37 @@
+from .. import util
+import torch
+from torch import nn, Tensor
+from typing import Union, Tuple, List, Iterable, Dict
+from torch.nn import functional as F
+
+
+class MarginMSELoss(nn.Module):
+ """
+ Computes the Margin MSE loss between the query, positive passage and negative passage. This loss
+    is used to train dense models in a cross-architecture knowledge distillation setup.
+
+    Margin MSE Loss is defined as in (Eq.11) of Sebastian Hofstätter et al. in https://arxiv.org/abs/2010.02666:
+ Loss(𝑄, 𝑃+, 𝑃−) = MSE(𝑀𝑠(𝑄, 𝑃+) − 𝑀𝑠(𝑄, 𝑃−), 𝑀𝑡(𝑄, 𝑃+) − 𝑀𝑡(𝑄, 𝑃−))
+ where 𝑄: Query, 𝑃+: Relevant passage, 𝑃−: Non-relevant passage, 𝑀𝑠: Student model, 𝑀𝑡: Teacher model
+
+ Remember: Pass the difference in scores of the passages as labels.
+ """
+ def __init__(self, model, scale: float = 1.0, similarity_fct = 'dot'):
+ super(MarginMSELoss, self).__init__()
+ self.model = model
+ self.scale = scale
+ self.similarity_fct = similarity_fct
+ self.loss_fct = nn.MSELoss()
+
+ def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
+ # sentence_features: query, positive passage, negative passage
+ reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
+ embeddings_query = reps[0]
+ embeddings_pos = reps[1]
+ embeddings_neg = reps[2]
+
+ scores_pos = (embeddings_query * embeddings_pos).sum(dim=-1) * self.scale
+ scores_neg = (embeddings_query * embeddings_neg).sum(dim=-1) * self.scale
+ margin_pred = scores_pos - scores_neg
+
+ return self.loss_fct(margin_pred, labels)
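+
+
+# Example usage (a sketch, kept as a comment): building one sentence-transformers training
+# example for this loss. The teacher scores are hypothetical; the label is the teacher
+# margin CE(Q, P+) - CE(Q, P-), onto which the student margin is regressed.
+#
+#   from sentence_transformers import InputExample
+#
+#   teacher_pos, teacher_neg = 8.1, 2.3  # e.g. cross-encoder (teacher) scores
+#   example = InputExample(texts=[query, positive_passage, negative_passage],
+#                          label=teacher_pos - teacher_neg)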
diff --git a/beir/reranking/__init__.py b/beir/reranking/__init__.py
new file mode 100644
index 0000000..722bf2c
--- /dev/null
+++ b/beir/reranking/__init__.py
@@ -0,0 +1 @@
+from .rerank import Rerank
\ No newline at end of file
diff --git a/beir/reranking/models/__init__.py b/beir/reranking/models/__init__.py
new file mode 100644
index 0000000..15e480c
--- /dev/null
+++ b/beir/reranking/models/__init__.py
@@ -0,0 +1,2 @@
+from .cross_encoder import CrossEncoder
+from .mono_t5 import MonoT5
\ No newline at end of file
diff --git a/beir/reranking/models/cross_encoder.py b/beir/reranking/models/cross_encoder.py
new file mode 100644
index 0000000..e1754f3
--- /dev/null
+++ b/beir/reranking/models/cross_encoder.py
@@ -0,0 +1,13 @@
+from sentence_transformers.cross_encoder import CrossEncoder as CE
+import numpy as np
+from typing import List, Dict, Tuple
+
+class CrossEncoder:
+ def __init__(self, model_path: str, **kwargs):
+ self.model = CE(model_path, **kwargs)
+
+ def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
+ return self.model.predict(
+ sentences=sentences,
+ batch_size=batch_size,
+ show_progress_bar=show_progress_bar)
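+
+# Illustrative usage sketch (the model name below is a placeholder assumption, not required by BEIR):
+#   cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+#   scores = cross_encoder.predict([("what is python", "Python is a programming language.")], batch_size=16)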
diff --git a/beir/reranking/models/mono_t5.py b/beir/reranking/models/mono_t5.py
new file mode 100644
index 0000000..1c94c1c
--- /dev/null
+++ b/beir/reranking/models/mono_t5.py
@@ -0,0 +1,162 @@
+# The majority of this code has been copied from the PyGaggle MonoT5 implementation:
+# https://github.com/castorini/pygaggle/blob/master/pygaggle/rerank/transformer.py
+
+from transformers import (AutoTokenizer,
+ AutoModelForSeq2SeqLM,
+ PreTrainedModel,
+ PreTrainedTokenizer,
+ T5ForConditionalGeneration)
+from typing import List, Union, Tuple, Mapping, Optional
+from dataclasses import dataclass
+from tqdm.autonotebook import trange
+import torch
+
+
+TokenizerReturnType = Mapping[str, Union[torch.Tensor, List[int],
+ List[List[int]],
+ List[List[str]]]]
+
+@dataclass
+class QueryDocumentBatch:
+ query: str
+ documents: List[str]
+ output: Optional[TokenizerReturnType] = None
+
+ def __len__(self):
+ return len(self.documents)
+
+class QueryDocumentBatchTokenizer:
+ def __init__(self,
+ tokenizer: PreTrainedTokenizer,
+ pattern: str = '{query} {document}',
+ **tokenizer_kwargs):
+ self.tokenizer = tokenizer
+ self.tokenizer_kwargs = tokenizer_kwargs
+ self.pattern = pattern
+
+ def encode(self, strings: List[str]):
+ assert self.tokenizer and self.tokenizer_kwargs is not None, \
+ 'mixin used improperly'
+ ret = self.tokenizer.batch_encode_plus(strings,
+ **self.tokenizer_kwargs)
+ ret['tokens'] = list(map(self.tokenizer.tokenize, strings))
+ return ret
+
+ def traverse_query_document(
+ self, batch_input: Tuple[str, List[str]], batch_size: int):
+ query, doc_texts = batch_input[0], batch_input[1]
+ for batch_idx in range(0, len(doc_texts), batch_size):
+ docs = doc_texts[batch_idx:batch_idx + batch_size]
+ outputs = self.encode([self.pattern.format(
+ query=query,
+ document=doc) for doc in docs])
+ yield QueryDocumentBatch(query, docs, outputs)
+
+class T5BatchTokenizer(QueryDocumentBatchTokenizer):
+ def __init__(self, *args, **kwargs):
+ kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
+ if 'return_attention_mask' not in kwargs:
+ kwargs['return_attention_mask'] = True
+ if 'padding' not in kwargs:
+ kwargs['padding'] = 'longest'
+ if 'truncation' not in kwargs:
+ kwargs['truncation'] = True
+ if 'return_tensors' not in kwargs:
+ kwargs['return_tensors'] = 'pt'
+ if 'max_length' not in kwargs:
+ kwargs['max_length'] = 512
+ super().__init__(*args, **kwargs)
+
+
+@torch.no_grad()
+def greedy_decode(model: PreTrainedModel,
+ input_ids: torch.Tensor,
+ length: int,
+ attention_mask: torch.Tensor = None,
+ return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ decode_ids = torch.full((input_ids.size(0), 1),
+ model.config.decoder_start_token_id,
+ dtype=torch.long).to(input_ids.device)
+ encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
+ next_token_logits = None
+ for _ in range(length):
+ model_inputs = model.prepare_inputs_for_generation(
+ decode_ids,
+ encoder_outputs=encoder_outputs,
+ past=None,
+ attention_mask=attention_mask,
+ use_cache=True)
+ outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
+ next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size)
+ decode_ids = torch.cat([decode_ids,
+ next_token_logits.max(1)[1].unsqueeze(-1)],
+ dim=-1)
+ if return_last_logits:
+ return decode_ids, next_token_logits
+ return decode_ids
+
+
+class MonoT5:
+ def __init__(self,
+ model_path: str,
+ tokenizer: QueryDocumentBatchTokenizer = None,
+ use_amp = True,
+ token_false = None,
+ token_true = None):
+ self.model = self.get_model(model_path)
+ self.tokenizer = tokenizer or self.get_tokenizer(model_path)
+ self.token_false_id, self.token_true_id = self.get_prediction_tokens(
+ model_path, self.tokenizer, token_false, token_true)
+ self.model_path = model_path
+ self.device = next(self.model.parameters(), None).device
+ self.use_amp = use_amp
+
+ @staticmethod
+ def get_model(model_path: str, *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
+ device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device(device)
+ return AutoModelForSeq2SeqLM.from_pretrained(model_path, *args, **kwargs).to(device).eval()
+
+ @staticmethod
+ def get_tokenizer(model_path: str, *args, **kwargs) -> T5BatchTokenizer:
+ return T5BatchTokenizer(
+ AutoTokenizer.from_pretrained(model_path, use_fast=False, *args, **kwargs)
+ )
+
+ @staticmethod
+ def get_prediction_tokens(model_path: str, tokenizer, token_false, token_true):
+ if token_false and token_true:
+ token_false_id = tokenizer.tokenizer.get_vocab()[token_false]
+ token_true_id = tokenizer.tokenizer.get_vocab()[token_true]
+ return token_false_id, token_true_id
+ # Fallback (assumption): the castorini monoT5 checkpoints use '▁false'/'▁true' as relevance tokens;
+ # without a fallback the constructor would fail when token_false/token_true are left as None.
+ return tokenizer.tokenizer.get_vocab()['▁false'], tokenizer.tokenizer.get_vocab()['▁true']
+
+ def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, **kwargs) -> List[float]:
+
+ sentence_dict, queries, scores = {}, [], []
+
+ # The T5 model expects batches of a single query paired with its top-k documents
+ for (query, doc_text) in sentences:
+ if query not in sentence_dict:
+ sentence_dict[query] = []
+ queries.append(query) # Preserves order of queries
+ sentence_dict[query].append(doc_text)
+
+ for start_idx in trange(0, len(queries), 1): # Take one query at a time
+ batch_input = (queries[start_idx], sentence_dict[queries[start_idx]]) # (single query, top-k docs)
+ for batch in self.tokenizer.traverse_query_document(batch_input, batch_size):
+ with torch.cuda.amp.autocast(enabled=self.use_amp):
+ input_ids = batch.output['input_ids'].to(self.device)
+ attn_mask = batch.output['attention_mask'].to(self.device)
+ _, batch_scores = greedy_decode(self.model,
+ input_ids,
+ length=1,
+ attention_mask=attn_mask,
+ return_last_logits=True)
+
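+ # Keep only the logits of the 'false' and 'true' tokens; after log-softmax, column 1 is the log-probability of 'true' (the relevance score)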
+ batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
+ batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+ batch_log_probs = batch_scores[:, 1].tolist()
+ scores.extend(batch_log_probs)
+
+ assert len(scores) == len(sentences) # Sanity check, should be equal
+ return scores
\ No newline at end of file
diff --git a/beir/reranking/rerank.py b/beir/reranking/rerank.py
new file mode 100644
index 0000000..9ff9aaa
--- /dev/null
+++ b/beir/reranking/rerank.py
@@ -0,0 +1,45 @@
+import logging
+from typing import Dict, List
+
+logger = logging.getLogger(__name__)
+
+#Parent class for any reranking model
+class Rerank:
+
+ def __init__(self, model, batch_size: int = 128, **kwargs):
+ self.cross_encoder = model
+ self.batch_size = batch_size
+ self.rerank_results = {}
+
+ def rerank(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ results: Dict[str, Dict[str, float]],
+ top_k: int) -> Dict[str, Dict[str, float]]:
+
+ sentence_pairs, pair_ids = [], []
+
+ for query_id in results:
+ if len(results[query_id]) > top_k:
+ for (doc_id, _) in sorted(results[query_id].items(), key=lambda item: item[1], reverse=True)[:top_k]:
+ pair_ids.append([query_id, doc_id])
+ corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip()
+ sentence_pairs.append([queries[query_id], corpus_text])
+
+ else:
+ for doc_id in results[query_id]:
+ pair_ids.append([query_id, doc_id])
+ corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip()
+ sentence_pairs.append([queries[query_id], corpus_text])
+
+ #### Starting to Rerank using cross-attention
+ logging.info("Starting To Rerank Top-{}....".format(top_k))
+ rerank_scores = [float(score) for score in self.cross_encoder.predict(sentence_pairs, batch_size=self.batch_size)]
+
+ #### Reranking results
+ self.rerank_results = {query_id: {} for query_id in results}
+ for pair, score in zip(pair_ids, rerank_scores):
+ query_id, doc_id = pair[0], pair[1]
+ self.rerank_results[query_id][doc_id] = score
+
+ return self.rerank_results
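+
+# Illustrative usage sketch (assumes BEIR-style corpus/queries/results dictionaries and a placeholder cross-encoder name):
+#   reranker = Rerank(CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2"), batch_size=128)
+#   rerank_results = reranker.rerank(corpus, queries, results, top_k=100)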
diff --git a/beir/retrieval/__init__.py b/beir/retrieval/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/beir/retrieval/custom_metrics.py b/beir/retrieval/custom_metrics.py
new file mode 100644
index 0000000..06134c7
--- /dev/null
+++ b/beir/retrieval/custom_metrics.py
@@ -0,0 +1,117 @@
+import logging
+from typing import List, Dict, Union, Tuple
+
+def mrr(qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int]) -> Tuple[Dict[str, float]]:
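+ # MRR@k: mean (over queries) of 1 / rank of the first relevant document within the top-k results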
+
+ MRR = {}
+
+ for k in k_values:
+ MRR[f"MRR@{k}"] = 0.0
+
+ k_max, top_hits = max(k_values), {}
+ logging.info("\n")
+
+ for query_id, doc_scores in results.items():
+ top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+
+ for query_id in top_hits:
+ query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
+ for k in k_values:
+ for rank, hit in enumerate(top_hits[query_id][0:k]):
+ if hit[0] in query_relevant_docs:
+ MRR[f"MRR@{k}"] += 1.0 / (rank + 1)
+ break
+
+ for k in k_values:
+ MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"]/len(qrels), 5)
+ logging.info("MRR@{}: {:.4f}".format(k, MRR[f"MRR@{k}"]))
+
+ return MRR
+
+def recall_cap(qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int]) -> Tuple[Dict[str, float]]:
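+ # Capped Recall@k: recall with the denominator capped at min(#relevant, k), so queries with more than k relevant documents are not penalised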
+
+ capped_recall = {}
+
+ for k in k_values:
+ capped_recall[f"R_cap@{k}"] = 0.0
+
+ k_max = max(k_values)
+ logging.info("\n")
+
+ for query_id, doc_scores in results.items():
+ top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+ query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+ for k in k_values:
+ retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
+ denominator = min(len(query_relevant_docs), k)
+ capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
+
+ for k in k_values:
+ capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
+ logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
+
+ return capped_recall
+
+
+def hole(qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int]) -> Tuple[Dict[str, float]]:
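+ # Hole@k: fraction of the top-k retrieved documents that have no relevance annotation in the qrels (i.e. unjudged documents)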
+
+ Hole = {}
+
+ for k in k_values:
+ Hole[f"Hole@{k}"] = 0.0
+
+ annotated_corpus = set()
+ for _, docs in qrels.items():
+ for doc_id, score in docs.items():
+ annotated_corpus.add(doc_id)
+
+ k_max = max(k_values)
+ logging.info("\n")
+
+ for _, scores in results.items():
+ top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+ for k in k_values:
+ hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus]
+ Hole[f"Hole@{k}"] += len(hole_docs) / k
+
+ for k in k_values:
+ Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(qrels), 5)
+ logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"]))
+
+ return Hole
+
+def top_k_accuracy(
+ qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int]) -> Tuple[Dict[str, float]]:
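+ # Accuracy@k: fraction of queries with at least one relevant document among the top-k results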
+
+ top_k_acc = {}
+
+ for k in k_values:
+ top_k_acc[f"Accuracy@{k}"] = 0.0
+
+ k_max, top_hits = max(k_values), {}
+ logging.info("\n")
+
+ for query_id, doc_scores in results.items():
+ top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]]
+
+ for query_id in top_hits:
+ query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
+ for k in k_values:
+ for relevant_doc_id in query_relevant_docs:
+ if relevant_doc_id in top_hits[query_id][0:k]:
+ top_k_acc[f"Accuracy@{k}"] += 1.0
+ break
+
+ for k in k_values:
+ top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5)
+ logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))
+
+ return top_k_acc
\ No newline at end of file
diff --git a/beir/retrieval/evaluation.py b/beir/retrieval/evaluation.py
new file mode 100644
index 0000000..4db752a
--- /dev/null
+++ b/beir/retrieval/evaluation.py
@@ -0,0 +1,111 @@
+import pytrec_eval
+import logging
+from typing import Type, List, Dict, Union, Tuple
+from .search.dense import DenseRetrievalExactSearch as DRES
+from .search.dense import DenseRetrievalFaissSearch as DRFS
+from .search.lexical import BM25Search as BM25
+from .search.sparse import SparseSearch as SS
+from .custom_metrics import mrr, recall_cap, hole, top_k_accuracy
+
+logger = logging.getLogger(__name__)
+
+class EvaluateRetrieval:
+
+ def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS]] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
+ self.k_values = k_values
+ self.top_k = max(k_values)
+ self.retriever = retriever
+ self.score_function = score_function
+
+ def retrieve(self, corpus: Dict[str, Dict[str, str]], queries: Dict, query_negations:Dict=None, **kwargs) -> Dict[str, Dict[str, float]]:
+ if not self.retriever:
+ raise ValueError("Model/Technique has not been provided!")
+ return self.retriever.search(corpus, queries, self.top_k, self.score_function, query_negations=query_negations, **kwargs)
+
+ def rerank(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ results: Dict[str, Dict[str, float]],
+ top_k: int) -> Dict[str, Dict[str, float]]:
+
+ new_corpus = {}
+
+ for query_id in results:
+ if len(results[query_id]) > top_k:
+ for (doc_id, _) in sorted(results[query_id].items(), key=lambda item: item[1], reverse=True)[:top_k]:
+ new_corpus[doc_id] = corpus[doc_id]
+ else:
+ for doc_id in results[query_id]:
+ new_corpus[doc_id] = corpus[doc_id]
+
+ return self.retriever.search(new_corpus, queries, top_k, self.score_function)
+
+ @staticmethod
+ def evaluate(qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int],
+ ignore_identical_ids: bool=True) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
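+ # Computes NDCG@k, MAP@k, Recall@k and P@k with pytrec_eval and macro-averages the per-query scores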
+
+ if ignore_identical_ids:
+ logging.info('For evaluation, we ignore identical query and document ids (default); explicitly set ``ignore_identical_ids=False`` to keep them.')
+ popped = []
+ for qid, rels in results.items():
+ for pid in list(rels):
+ if qid == pid:
+ results[qid].pop(pid)
+ popped.append(pid)
+
+ ndcg = {}
+ _map = {}
+ recall = {}
+ precision = {}
+
+ for k in k_values:
+ ndcg[f"NDCG@{k}"] = 0.0
+ _map[f"MAP@{k}"] = 0.0
+ recall[f"Recall@{k}"] = 0.0
+ precision[f"P@{k}"] = 0.0
+
+ map_string = "map_cut." + ",".join([str(k) for k in k_values])
+ ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
+ recall_string = "recall." + ",".join([str(k) for k in k_values])
+ precision_string = "P." + ",".join([str(k) for k in k_values])
+ evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string})
+ scores = evaluator.evaluate(results)
+
+ for query_id in scores.keys():
+ for k in k_values:
+ ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
+ _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
+ recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
+ precision[f"P@{k}"] += scores[query_id]["P_"+ str(k)]
+
+ for k in k_values:
+ ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"]/len(scores), 5)
+ _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"]/len(scores), 5)
+ recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"]/len(scores), 5)
+ precision[f"P@{k}"] = round(precision[f"P@{k}"]/len(scores), 5)
+
+ for eval in [ndcg, _map, recall, precision]:
+ logging.info("\n")
+ for k in eval.keys():
+ logging.info("{}: {:.4f}".format(k, eval[k]))
+
+ return ndcg, _map, recall, precision
+
+ @staticmethod
+ def evaluate_custom(qrels: Dict[str, Dict[str, int]],
+ results: Dict[str, Dict[str, float]],
+ k_values: List[int], metric: str) -> Tuple[Dict[str, float]]:
+
+ if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]:
+ return mrr(qrels, results, k_values)
+
+ elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]:
+ return recall_cap(qrels, results, k_values)
+
+ elif metric.lower() in ["hole", "hole@k"]:
+ return hole(qrels, results, k_values)
+
+ elif metric.lower() in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
+ return top_k_accuracy(qrels, results, k_values)
\ No newline at end of file
diff --git a/beir/retrieval/models/__init__.py b/beir/retrieval/models/__init__.py
new file mode 100644
index 0000000..089b73f
--- /dev/null
+++ b/beir/retrieval/models/__init__.py
@@ -0,0 +1,8 @@
+from .sentence_bert import SentenceBERT
+from .use_qa import UseQA
+from .sparta import SPARTA
+from .dpr import DPR
+from .bpr import BinarySentenceBERT
+from .unicoil import UniCOIL
+from .splade import SPLADE
+from .tldr import TLDR
diff --git a/beir/retrieval/models/bpr.py b/beir/retrieval/models/bpr.py
new file mode 100644
index 0000000..7150563
--- /dev/null
+++ b/beir/retrieval/models/bpr.py
@@ -0,0 +1,31 @@
+from sentence_transformers import SentenceTransformer
+from torch import Tensor
+from typing import List, Dict, Union, Tuple
+import numpy as np
+
+class BinarySentenceBERT:
+ def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", threshold: Union[float, Tensor] = 0, **kwargs):
+ self.sep = sep
+ self.threshold = threshold
+
+ if isinstance(model_path, str):
+ self.q_model = SentenceTransformer(model_path)
+ self.doc_model = self.q_model
+
+ elif isinstance(model_path, tuple):
+ self.q_model = SentenceTransformer(model_path[0])
+ self.doc_model = SentenceTransformer(model_path[1])
+
+ def _convert_embedding_to_binary_code(self, embeddings: List[Tensor]) -> List[Tensor]:
+ return embeddings.new_ones(embeddings.size()).masked_fill_(embeddings < self.threshold, -1.0)
+
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+ return self.q_model.encode(queries, batch_size=batch_size, **kwargs)
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() for doc in corpus]
+ embs = self.doc_model.encode(sentences, batch_size=batch_size, convert_to_tensor=True, **kwargs)
+ embs = self._convert_embedding_to_binary_code(embs).cpu().numpy()
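+ # Map the {-1, +1} binary codes to {0, 1} and pack 8 values per byte (np.packbits) to store compact binary embeddings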
+ embs = np.where(embs == -1, 0, embs).astype(bool)  # np.bool is deprecated/removed in recent NumPy versions
+ embs = np.packbits(embs).reshape(embs.shape[0], -1)
+ return np.vstack(embs)
\ No newline at end of file
diff --git a/beir/retrieval/models/dpr.py b/beir/retrieval/models/dpr.py
new file mode 100644
index 0000000..fa98000
--- /dev/null
+++ b/beir/retrieval/models/dpr.py
@@ -0,0 +1,42 @@
+from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
+from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast
+from typing import Union, List, Dict, Tuple
+from tqdm.autonotebook import trange
+import torch
+
+class DPR:
+ def __init__(self, model_path: Union[str, Tuple] = None, **kwargs):
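+ # Note: this implementation assumes a CUDA-capable GPU is available (models and batches are moved with .cuda())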
+ # Query tokenizer and model
+ self.q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(model_path[0])
+ self.q_model = DPRQuestionEncoder.from_pretrained(model_path[0])
+ self.q_model.cuda()
+ self.q_model.eval()
+
+ # Context tokenizer and model
+ self.ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(model_path[1])
+ self.ctx_model = DPRContextEncoder.from_pretrained(model_path[1])
+ self.ctx_model.cuda()
+ self.ctx_model.eval()
+
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> torch.Tensor:
+ query_embeddings = []
+ with torch.no_grad():
+ for start_idx in trange(0, len(queries), batch_size):
+ encoded = self.q_tokenizer(queries[start_idx:start_idx+batch_size], truncation=True, padding=True, return_tensors='pt')
+ model_out = self.q_model(encoded['input_ids'].cuda(), attention_mask=encoded['attention_mask'].cuda())
+ query_embeddings += model_out.pooler_output
+
+ return torch.stack(query_embeddings)
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> torch.Tensor:
+
+ corpus_embeddings = []
+ with torch.no_grad():
+ for start_idx in trange(0, len(corpus), batch_size):
+ titles = [row['title'] for row in corpus[start_idx:start_idx+batch_size]]
+ texts = [row['text'] for row in corpus[start_idx:start_idx+batch_size]]
+ encoded = self.ctx_tokenizer(titles, texts, truncation='longest_first', padding=True, return_tensors='pt')
+ model_out = self.ctx_model(encoded['input_ids'].cuda(), attention_mask=encoded['attention_mask'].cuda())
+ corpus_embeddings += model_out.pooler_output.detach()
+
+ return torch.stack(corpus_embeddings)
\ No newline at end of file
diff --git a/beir/retrieval/models/sentence_bert.py b/beir/retrieval/models/sentence_bert.py
new file mode 100644
index 0000000..d3b73c0
--- /dev/null
+++ b/beir/retrieval/models/sentence_bert.py
@@ -0,0 +1,67 @@
+from sentence_transformers import SentenceTransformer
+from torch import Tensor
+import torch.multiprocessing as mp
+from typing import List, Dict, Union, Tuple
+import numpy as np
+import logging
+from datasets import Dataset
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceBERT:
+ def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
+ self.sep = sep
+
+ if isinstance(model_path, str):
+ self.q_model = SentenceTransformer(model_path)
+ self.doc_model = self.q_model
+
+ elif isinstance(model_path, tuple):
+ self.q_model = SentenceTransformer(model_path[0])
+ self.doc_model = SentenceTransformer(model_path[1])
+
+ def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
+ logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
+
+ ctx = mp.get_context('spawn')
+ input_queue = ctx.Queue()
+ output_queue = ctx.Queue()
+ processes = []
+
+ for process_id, device_name in enumerate(target_devices):
+ p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(process_id, device_name, self.doc_model, input_queue, output_queue), daemon=True)
+ p.start()
+ processes.append(p)
+
+ return {'input': input_queue, 'output': output_queue, 'processes': processes}
+
+ def stop_multi_process_pool(self, pool: Dict[str, object]):
+ output_queue = pool['output']
+ [output_queue.get() for _ in range(len(pool['processes']))]
+ return self.doc_model.stop_multi_process_pool(pool)
+
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+ return self.q_model.encode(queries, batch_size=batch_size, **kwargs)
+
+ def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+ if type(corpus) is dict:
+ sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
+ else:
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+ return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)
+
+ ## Encoding corpus in parallel
+ def encode_corpus_parallel(self, corpus: Union[List[Dict[str, str]], Dataset], pool: Dict[str, str], batch_size: int = 8, chunk_id: int = None, **kwargs):
+ if type(corpus) is dict:
+ sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
+ else:
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+
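+ # Once every worker already has a chunk in flight, wait for one result before queueing the next chunk (simple backpressure)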
+ if chunk_id is not None and chunk_id >= len(pool['processes']):
+ output_queue = pool['output']
+ output_queue.get()
+
+ input_queue = pool['input']
+ input_queue.put([chunk_id, batch_size, sentences])
diff --git a/beir/retrieval/models/sparta.py b/beir/retrieval/models/sparta.py
new file mode 100644
index 0000000..972362e
--- /dev/null
+++ b/beir/retrieval/models/sparta.py
@@ -0,0 +1,77 @@
+from typing import List, Dict, Union, Tuple
+from tqdm.autonotebook import trange
+from transformers import AutoTokenizer, AutoModel
+from scipy.sparse import csr_matrix
+import torch
+import numpy as np
+
+class SPARTA:
+ def __init__(self, model_path: str = None, sep: str = " ", sparse_vector_dim: int = 2000, max_length: int = 500, **kwargs):
+ self.sep = sep
+ self.max_length = max_length
+ self.sparse_vector_dim = sparse_vector_dim
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+ self.model = AutoModel.from_pretrained(model_path)
+ self.initialization()
+ self.bert_input_embeddings = self._bert_input_embeddings()
+
+ def initialization(self):
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.model.to(self.device)
+ self.model.eval()
+
+ def _bert_input_embeddings(self):
+ bert_input_embs = self.model.embeddings.word_embeddings(
+ torch.tensor(list(range(0, len(self.tokenizer))), device=self.device))
+
+ # Set Special tokens [CLS] [MASK] etc. to zero
+ for special_id in self.tokenizer.all_special_ids:
+ bert_input_embs[special_id] = 0 * bert_input_embs[special_id]
+
+ return bert_input_embs
+
+ def _compute_sparse_embeddings(self, documents):
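+ # Score every vocabulary (input) embedding against each contextualised token embedding, max-pool over tokens,
+ # apply log(1 + relu(.)) saturation, and keep the top `sparse_vector_dim` non-zero terms per document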
+ sparse_embeddings = []
+ with torch.no_grad():
+ tokens = self.tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length).to(self.device)
+ document_embs = self.model(**tokens).last_hidden_state
+ for document_emb in document_embs:
+ scores = torch.matmul(self.bert_input_embeddings, document_emb.transpose(0, 1))
+ max_scores = torch.max(scores, dim=-1).values
+ scores = torch.log(torch.relu(max_scores) + 1)
+ top_results = torch.topk(scores, k=self.sparse_vector_dim)
+ tids = top_results[1].cpu().detach().tolist()
+ scores = top_results[0].cpu().detach().tolist()
+ passage_emb = []
+
+ for tid, score in zip(tids, scores):
+ if score > 0:
+ passage_emb.append((tid, score))
+ else:
+ break
+ sparse_embeddings.append(passage_emb)
+
+ return sparse_embeddings
+
+ def encode_query(self, query: str, **kwargs):
+ return self.tokenizer(query, add_special_tokens=False)['input_ids']
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs):
+
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() for doc in corpus]
+ sparse_idx = 0
+ num_elements = len(sentences) * self.sparse_vector_dim
+ col = np.zeros(num_elements, dtype=np.int64)
+ row = np.zeros(num_elements, dtype=np.int64)
+ values = np.zeros(num_elements, dtype=np.float64)  # np.int/np.float aliases are deprecated/removed in recent NumPy versions
+
+ for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
+ doc_embs = self._compute_sparse_embeddings(sentences[start_idx: start_idx + batch_size])
+ for doc_id, emb in enumerate(doc_embs):
+ for tid, score in emb:
+ col[sparse_idx] = start_idx+doc_id
+ row[sparse_idx] = tid
+ values[sparse_idx] = score
+ sparse_idx += 1
+
+ return csr_matrix((values, (row, col)), shape=(len(self.bert_input_embeddings), len(sentences)), dtype=np.float64)
\ No newline at end of file
diff --git a/beir/retrieval/models/splade.py b/beir/retrieval/models/splade.py
new file mode 100644
index 0000000..88a8962
--- /dev/null
+++ b/beir/retrieval/models/splade.py
@@ -0,0 +1,146 @@
+import logging
+from typing import List, Dict, Union
+import numpy as np
+import torch
+from numpy import ndarray
+from torch import Tensor
+from tqdm.autonotebook import trange
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+from sentence_transformers.util import batch_to_device
+
+logger = logging.getLogger(__name__)
+
+
+class SPLADE:
+ def __init__(self, model_path: str = None, sep: str = " ", max_length: int = 256, **kwargs):
+ self.max_length = max_length
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+ self.model = SpladeNaver(model_path)
+ self.model.eval()
+
+ # Write your own encoding query function (Returns: Query embeddings as numpy array)
+ def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
+ return self.model.encode_sentence_bert(self.tokenizer, queries, is_q=True, maxlen=self.max_length)
+
+ # Write your own encoding corpus function (Returns: Document embeddings as numpy array) out_features
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> np.ndarray:
+ sentences = [(doc["title"] + ' ' + doc["text"]).strip() for doc in corpus]
+ return self.model.encode_sentence_bert(self.tokenizer, sentences, maxlen=self.max_length)
+
+
+# Chunks of this code have been taken from: https://github.com/naver/splade/blob/main/beir_evaluation/models.py
+# For more details, please refer to SPLADE by Thibault Formal, Benjamin Piwowarski and Stéphane Clinchant (https://arxiv.org/abs/2107.05720)
+class SpladeNaver(torch.nn.Module):
+ def __init__(self, model_path):
+ super().__init__()
+ self.transformer = AutoModelForMaskedLM.from_pretrained(model_path)
+
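+ # SPLADE representation: max over sequence positions of log(1 + relu(MLM logits)), masked by the attention mask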
+ def forward(self, **kwargs):
+ out = self.transformer(**kwargs)["logits"] # output (logits) of MLM head, shape (bs, pad_len, voc_size)
+ return torch.max(torch.log(1 + torch.relu(out)) * kwargs["attention_mask"].unsqueeze(-1), dim=1).values
+
+ def _text_length(self, text: Union[List[int], List[List[int]]]):
+ """helper function to get the length for the input text. Text can be either
+ a list of ints (which means a single text as input), or a tuple of list of ints
+ (representing several text inputs to the model).
+ """
+
+ if isinstance(text, dict): # {key: value} case
+ return len(next(iter(text.values())))
+ elif not hasattr(text, '__len__'): # Object has no len() method
+ return 1
+ elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
+ return len(text)
+ else:
+ return sum([len(t) for t in text]) # Sum of length of individual strings
+
+ def encode_sentence_bert(self, tokenizer, sentences: Union[str, List[str], List[int]],
+ batch_size: int = 32,
+ show_progress_bar: bool = None,
+ output_value: str = 'sentence_embedding',
+ convert_to_numpy: bool = True,
+ convert_to_tensor: bool = False,
+ device: str = None,
+ normalize_embeddings: bool = False,
+ maxlen: int = 512,
+ is_q: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
+ """
+ Computes sentence embeddings
+ :param sentences: the sentences to embed
+ :param batch_size: the batch size used for the computation
+ :param show_progress_bar: Output a progress bar when encoding sentences
+ :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings.
+ :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
+ :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+ :param device: Which torch.device to use for the computation
+ :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+ :return:
+ By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
+ """
+ self.eval()
+ if show_progress_bar is None:
+ show_progress_bar = True
+
+ if convert_to_tensor:
+ convert_to_numpy = False
+
+ if output_value == 'token_embeddings':
+ convert_to_tensor = False
+ convert_to_numpy = False
+
+ input_was_string = False
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+ # Cast an individual sentence to a list with length 1
+ sentences = [sentences]
+ input_was_string = True
+
+ if device is None:
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ self.to(device)
+
+ all_embeddings = []
+ length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+ sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+ for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
+ sentences_batch = sentences_sorted[start_index:start_index + batch_size]
+ features = tokenizer(sentences_batch,
+ add_special_tokens=True,
+ padding="longest", # pad to max sequence length in batch
+ truncation="only_first", # truncates to self.max_length
+ max_length=maxlen,
+ return_attention_mask=True,
+ return_tensors="pt")
+ features = batch_to_device(features, device)
+
+ with torch.no_grad():
+ out_features = self.forward(**features)
+ if output_value == 'token_embeddings':
+ embeddings = []
+ for token_emb, attention in zip(out_features[output_value], out_features['attention_mask']):
+ last_mask_id = len(attention) - 1
+ while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+ last_mask_id -= 1
+ embeddings.append(token_emb[0:last_mask_id + 1])
+ else: # Sentence embeddings
+ # embeddings = out_features[output_value]
+ embeddings = out_features
+ embeddings = embeddings.detach()
+ if normalize_embeddings:
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+ # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
+ if convert_to_numpy:
+ embeddings = embeddings.cpu()
+ all_embeddings.extend(embeddings)
+ all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
+ if convert_to_tensor:
+ all_embeddings = torch.stack(all_embeddings)
+ elif convert_to_numpy:
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+ if input_was_string:
+ all_embeddings = all_embeddings[0]
+ return all_embeddings
\ No newline at end of file
diff --git a/beir/retrieval/models/tldr.py b/beir/retrieval/models/tldr.py
new file mode 100644
index 0000000..a7ff1b0
--- /dev/null
+++ b/beir/retrieval/models/tldr.py
@@ -0,0 +1,56 @@
+from sentence_transformers import SentenceTransformer
+import torch
+from torch import Tensor
+from typing import List, Dict, Union, Tuple
+import numpy as np
+import importlib.util
+
+if importlib.util.find_spec("tldr") is not None:
+ from tldr import TLDR as NaverTLDR
+
+class TLDR:
+ def __init__(self, encoder_model: SentenceTransformer, model_path: Union[str, Tuple] = None, sep: str = " ", n_components: int = 128, n_neighbors: int = 5,
+ encoder: str = "linear", projector: str = "mlp-2-2048", verbose: int = 2, knn_approximation: str = None, output_folder: str = "data/", **kwargs):
+ self.encoder_model = encoder_model
+ self.sep = sep
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.output_folder = output_folder
+
+ if model_path: self.load(model_path)
+
+ else:
+ self.model = NaverTLDR(
+ n_components=n_components,
+ n_neighbors=n_neighbors,
+ encoder=encoder,
+ projector=projector,
+ device=self.device,
+ verbose=verbose,
+ knn_approximation=knn_approximation,
+ )
+
+ def fit(self, corpus: List[Dict[str, str]], batch_size: int = 8, epochs: int = 100, warmup_epochs: int = 10,
+ train_batch_size: int = 1024, print_every: int = 100, **kwargs):
+
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+ self.model.fit(self.encoder_model.encode(sentences, batch_size=batch_size, **kwargs),
+ epochs=epochs,
+ warmup_epochs=warmup_epochs,
+ batch_size=batch_size,
+ output_folder=self.output_folder,
+ print_every=print_every)
+
+ def save(self, model_path: str, knn_path: str = None):
+ self.model.save(model_path)
+ if knn_path: self.model.save_knn(knn_path)
+
+ def load(self, model_path: str):
+ self.model = NaverTLDR()
+ self.model.load(model_path, init=True)
+
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+ return self.model.transform(self.encoder_model.encode(queries, batch_size=batch_size, **kwargs), l2_norm=True)
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+ return self.model.transform(self.encoder_model.encode(sentences, batch_size=batch_size, **kwargs), l2_norm=True)
\ No newline at end of file
diff --git a/beir/retrieval/models/unicoil.py b/beir/retrieval/models/unicoil.py
new file mode 100644
index 0000000..f353140
--- /dev/null
+++ b/beir/retrieval/models/unicoil.py
@@ -0,0 +1,168 @@
+from typing import Optional, List, Dict, Union, Tuple
+from transformers import BertConfig, BertModel, BertTokenizer, PreTrainedModel
+import numpy as np
+import torch
+from tqdm.autonotebook import trange
+from scipy.sparse import csr_matrix
+
+class UniCOIL:
+ def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", query_max_length: int = 128,
+ doc_max_length: int = 500, **kwargs):
+ self.sep = sep
+ self.model = UniCoilEncoder.from_pretrained(model_path)
+ self.tokenizer = BertTokenizer.from_pretrained(model_path)
+ self.bert_input_emb = len(self.tokenizer.get_vocab())
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.query_max_length = query_max_length
+ self.doc_max_length = doc_max_length
+ self.model.to(self.device)
+ self.model.eval()
+
+ def encode_query(self, query: str, batch_size: int = 16, **kwargs):
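+ # Build a dense |vocab|-sized vector and scatter the predicted token weights at the query's token ids (np.put)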
+ embedding = np.zeros(self.bert_input_emb, dtype=np.float64)
+ input_ids = self.tokenizer(query, max_length=self.query_max_length, padding='longest',
+ truncation=True, add_special_tokens=True,
+ return_tensors='pt').to(self.device)["input_ids"]
+
+ with torch.no_grad():
+ batch_weights = self.model(input_ids).cpu().detach().numpy()
+ batch_token_ids = input_ids.cpu().detach().numpy()
+ np.put(embedding, batch_token_ids, batch_weights.flatten())
+
+ return embedding
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs):
+ sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+ return self.encode(sentences, batch_size=batch_size, max_length=self.doc_max_length)
+
+ def encode(
+ self,
+ sentences: Union[str, List[str], List[int]],
+ batch_size: int = 32,
+ max_length: int = 512) -> np.ndarray:
+
+ passage_embs = []
+ non_zero_tokens = 0
+
+ for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
+ documents = sentences[start_idx: start_idx + batch_size]
+ input_ids = self.tokenizer(documents, max_length=max_length, padding='longest',
+ truncation=True, add_special_tokens=True,
+ return_tensors='pt').to(self.device)["input_ids"]
+
+ with torch.no_grad():
+ batch_weights = self.model(input_ids).cpu().detach().numpy()
+ batch_token_ids = input_ids.cpu().detach().numpy()
+
+ for idx in range(len(batch_token_ids)):
+ token_ids_and_embs = list(zip(batch_token_ids[idx], batch_weights[idx].flatten()))
+ non_zero_tokens += len(token_ids_and_embs)
+ passage_embs.append(token_ids_and_embs)
+
+ col = np.zeros(non_zero_tokens, dtype=np.int64)
+ row = np.zeros(non_zero_tokens, dtype=np.int64)
+ values = np.zeros(non_zero_tokens, dtype=np.float64)
+ sparse_idx = 0
+
+ for pid, emb in enumerate(passage_embs):
+ for tid, score in emb:
+ col[sparse_idx] = pid
+ row[sparse_idx] = tid
+ values[sparse_idx] = score
+ sparse_idx += 1
+
+ return csr_matrix((values, (col, row)), shape=(len(sentences), self.bert_input_emb), dtype=np.float64)
+
+# class UniCOIL:
+# def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
+# self.sep = sep
+# self.model = UniCoilEncoder.from_pretrained(model_path)
+# self.tokenizer = BertTokenizer.from_pretrained(model_path)
+# self.sparse_vector_dim = len(self.tokenizer.get_vocab())
+# self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# self.model.to(self.device)
+# self.model.eval()
+
+# def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs):
+# max_length = 128 # hardcode for now
+# return self.encode(queries, batch_size=batch_size, max_length=max_length)
+
+# def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs):
+# max_length = 500
+# sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+# return self.encode(sentences, batch_size=batch_size, max_length=max_length)
+
+# def encode(
+# self,
+# sentences: Union[str, List[str], List[int]],
+# batch_size: int = 32,
+# max_length: int = 512) -> np.ndarray:
+
+# embeddings = np.zeros((len(sentences), self.sparse_vector_dim), dtype=np.float)
+
+# for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
+# documents = sentences[start_idx: start_idx + batch_size]
+# input_ids = self.tokenizer(documents, max_length=max_length, padding='longest',
+# truncation=True, add_special_tokens=True,
+# return_tensors='pt').to(self.device)["input_ids"]
+
+# with torch.no_grad():
+# batch_weights = self.model(input_ids).cpu().detach().numpy()
+# batch_token_ids = input_ids.cpu().detach().numpy()
+
+# for idx in range(len(batch_token_ids)):
+# np.put(embeddings[start_idx + idx], batch_token_ids[idx], batch_weights[idx].flatten())
+
+# return embeddings
+# # return csr_matrix((values, (row, col)), shape=(len(sentences), self.sparse_vector_dim), dtype=np.float).toarray()
+
+
+# Chunks of this code have been taken from: https://github.com/castorini/pyserini/blob/master/pyserini/encode/_unicoil.py
+# For more details, please refer to uniCOIL by Jimmy Lin and Xueguang Ma (https://arxiv.org/abs/2106.14807)
+class UniCoilEncoder(PreTrainedModel):
+ config_class = BertConfig
+ base_model_prefix = 'coil_encoder'
+ load_tf_weights = None
+
+ def __init__(self, config: BertConfig):
+ super().__init__(config)
+ self.config = config
+ self.bert = BertModel(config)
+ self.tok_proj = torch.nn.Linear(config.hidden_size, 1)
+ self.init_weights()
+
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, torch.nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, torch.nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+ def init_weights(self):
+ self.bert.init_weights()
+ self.tok_proj.apply(self._init_weights)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ input_shape = input_ids.size()
+ device = input_ids.device
+ if attention_mask is None:
+ attention_mask = (
+ torch.ones(input_shape, device=device)
+ if input_ids is None
+ else (input_ids != self.bert.config.pad_token_id)
+ )
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+ sequence_output = outputs.last_hidden_state
+ tok_weights = self.tok_proj(sequence_output)
+ tok_weights = torch.relu(tok_weights)
+ return tok_weights
\ No newline at end of file
diff --git a/beir/retrieval/models/use_qa.py b/beir/retrieval/models/use_qa.py
new file mode 100644
index 0000000..7660b72
--- /dev/null
+++ b/beir/retrieval/models/use_qa.py
@@ -0,0 +1,52 @@
+import numpy as np
+import importlib.util
+from typing import List, Dict
+from tqdm.autonotebook import trange
+
+if importlib.util.find_spec("tensorflow") is not None:
+ import tensorflow as tf
+ import tensorflow_hub as hub
+ import tensorflow_text
+
+class UseQA:
+ def __init__(self, hub_url=None, **kwargs):
+ self.initialisation()
+ self.model = hub.load(hub_url)
+
+ @staticmethod
+ def initialisation():
+ # limiting tensorflow gpu-memory if used
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ if gpus:
+ try:
+ for gpu in gpus:
+ tf.config.experimental.set_memory_growth(gpu, True)
+ except RuntimeError as e:
+ print(e)
+
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
+ output = []
+ for start_idx in trange(0, len(queries), batch_size, desc='que'):
+ embeddings_q = self.model.signatures['question_encoder'](
+ tf.constant(queries[start_idx:start_idx+batch_size]))
+ for emb in embeddings_q["outputs"]:
+ output.append(emb)
+
+ return np.asarray(output)
+
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
+ output = []
+ for start_idx in trange(0, len(corpus), batch_size, desc='pas'):
+ titles = [row.get('title', '') for row in corpus[start_idx:start_idx+batch_size]]
+ texts = [row.get('text', '') if row.get('text', '') is not None else "" for row in corpus[start_idx:start_idx+batch_size]]
+
+ if all(title == "" for title in titles): # Check is title is not present in the dataset
+ titles = texts # title becomes the context as well
+
+ embeddings_c = self.model.signatures['response_encoder'](
+ input=tf.constant(titles),
+ context=tf.constant(texts))
+ for emb in embeddings_c["outputs"]:
+ output.append(emb)
+
+ return np.asarray(output)
\ No newline at end of file
diff --git a/beir/retrieval/search/__init__.py b/beir/retrieval/search/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/beir/retrieval/search/dense/__init__.py b/beir/retrieval/search/dense/__init__.py
new file mode 100644
index 0000000..03dca00
--- /dev/null
+++ b/beir/retrieval/search/dense/__init__.py
@@ -0,0 +1,3 @@
+from .exact_search import DenseRetrievalExactSearch
+from .exact_search_multi_gpu import DenseRetrievalParallelExactSearch
+from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, HNSWSQFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch
\ No newline at end of file
diff --git a/beir/retrieval/search/dense/exact_search.py b/beir/retrieval/search/dense/exact_search.py
new file mode 100644
index 0000000..f750dd0
--- /dev/null
+++ b/beir/retrieval/search/dense/exact_search.py
@@ -0,0 +1,120 @@
+from .util import cos_sim, dot_score
+import logging
+import sys
+import torch
+from typing import Dict, List
+
+logger = logging.getLogger(__name__)
+
+#Parent class for any dense model
+class DenseRetrievalExactSearch:
+
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, **kwargs):
+ # The model is a class that provides encode_corpus() and encode_queries()
+ self.model = model
+ self.batch_size = batch_size
+ self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
+ self.score_function_desc = {'cos_sim': "Cosine Similarity", 'dot': "Dot Product"}
+ self.corpus_chunk_size = corpus_chunk_size
+ self.show_progress_bar = True #TODO: implement no progress bar if false
+ self.convert_to_tensor = True
+ self.results = {}
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict,
+ top_k: List[int],
+ score_function: str,
+ return_sorted: bool = False,
+ query_negations: List=None,
+ **kwargs) -> Dict[str, Dict[str, float]]:
+ #Create embeddings for all queries using model.encode_queries()
+ #Runs semantic search against the corpus embeddings
+ #Returns a ranked list with the corpus ids
+ if score_function not in self.score_functions:
+ raise ValueError("score function: {} must be either (cos_sim) for cosine similarity or (dot) for dot product".format(score_function))
+
+ logger.info("Encoding Queries...")
+ query_ids = list(queries.keys())
+ self.results = {qid: {} for qid in query_ids}
+ queries = [queries[qid] for qid in query_ids]
+ if query_negations is not None:
+ query_negations = [query_negations[qid] if qid in query_negations else None for qid in query_ids]
+
+ query_embeddings = []
+ for idx in range(len(queries)):
+ curr_query = queries[idx]
+ if type(curr_query) is str:
+ curr_query_embedding = self.model.encode_queries(
+ curr_query, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)
+ elif type(curr_query) is list:
+ curr_query_embedding = []
+ for k in range(len(curr_query)):
+ qe = self.model.encode_queries(
+ curr_query[k], batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)
+ curr_query_embedding.append(qe)
+ query_embeddings.append(curr_query_embedding)
+
+ logger.info("Sorting Corpus by document length (Longest first)...")
+
+ corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
+ corpus = [corpus[cid] for cid in corpus_ids]
+
+ logger.info("Encoding Corpus in batches... Warning: This might take a while!")
+ logger.info("Scoring Function: {} ({})".format(self.score_function_desc[score_function], score_function))
+
+ itr = range(0, len(corpus), self.corpus_chunk_size)
+
+ all_cos_scores = []
+ for batch_num, corpus_start_idx in enumerate(itr):
+ logger.info("Encoding Batch {}/{}...".format(batch_num+1, len(itr)))
+ corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))
+
+ #Encode chunk of corpus
+ sub_corpus_embeddings = self.model.encode_corpus(
+ corpus[corpus_start_idx:corpus_end_idx],
+ batch_size=self.batch_size,
+ show_progress_bar=self.show_progress_bar,
+ convert_to_tensor = self.convert_to_tensor
+ )
+
+ # Compute similarities using either cosine-similarity or dot product
+ cos_scores = []
+ for query_itr in range(len(query_embeddings)):
+ curr_query_embedding = query_embeddings[query_itr]
+ if type(curr_query_embedding) is list:
+ curr_cos_scores_ls = self.score_functions[score_function](torch.stack(curr_query_embedding), sub_corpus_embeddings)
+ if query_negations is not None and query_negations[query_itr] is not None:
+ curr_query_negations = torch.tensor(query_negations[query_itr])
+ curr_cos_scores_ls[curr_query_negations == 1] = - curr_cos_scores_ls[curr_query_negations == 1]
+
+ curr_cos_scores_ls[torch.isnan(curr_cos_scores_ls)] = -1
+
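+ # Combine the scores of the individual sub-queries by element-wise multiplication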
+ curr_cos_scores = 1
+ for idx in range(len(curr_cos_scores_ls)):
+ curr_cos_scores *= curr_cos_scores_ls[idx]
+
+
+ else:
+ curr_cos_scores = self.score_functions[score_function](curr_query_embedding.unsqueeze(0), sub_corpus_embeddings)
+ curr_cos_scores[torch.isnan(curr_cos_scores)] = -1
+ curr_cos_scores = curr_cos_scores.squeeze(0)
+
+ cos_scores.append(curr_cos_scores)
+ cos_scores = torch.stack(cos_scores)
+ all_cos_scores.append(cos_scores)
+
+ all_cos_scores_tensor = torch.cat(all_cos_scores, dim=-1)
+ #Get top-k values
+ cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(all_cos_scores_tensor, min(top_k+1, len(all_cos_scores_tensor[0])), dim=1, largest=True, sorted=return_sorted)
+ cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
+ cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
+
+ for query_itr in range(len(query_embeddings)):
+ query_id = query_ids[query_itr]
+ for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]):
+ corpus_id = corpus_ids[sub_corpus_id]
+ if corpus_id != query_id:
+ self.results[query_id][corpus_id] = score
+
+ return self.results
diff --git a/beir/retrieval/search/dense/exact_search_multi_gpu.py b/beir/retrieval/search/dense/exact_search_multi_gpu.py
new file mode 100644
index 0000000..a9b03c7
--- /dev/null
+++ b/beir/retrieval/search/dense/exact_search_multi_gpu.py
@@ -0,0 +1,205 @@
+from .util import cos_sim, dot_score
+from sentence_transformers import SentenceTransformer
+from torch.utils.data import DataLoader
+from datasets import Features, Value, Sequence
+from datasets.utils.filelock import FileLock
+from datasets import Array2D, Dataset
+from tqdm.autonotebook import tqdm
+from datetime import datetime
+from typing import Dict, List, Tuple
+
+import logging
+import torch
+import math
+import queue
+import os
+import time
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+import importlib.util
+
+### The HuggingFace Evaluate library (pip install evaluate) is only available with Python >= 3.7.
+### Hence, to avoid import issues with Python 3.6, DummyMetric is only defined if the ``evaluate`` library is found.
+if importlib.util.find_spec("evaluate") is not None:
+ from evaluate.module import EvaluationModule, EvaluationModuleInfo
+
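+ # DummyMetric reuses the evaluate library's Arrow-based caching to gather the per-process top-k score chunks back into the main process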
+ class DummyMetric(EvaluationModule):
+ len_queries = None
+
+ def _info(self):
+ return EvaluationModuleInfo(
+ description="dummy metric to handle storing middle results",
+ citation="",
+ features=Features(
+ {"cos_scores_top_k_values": Array2D((None, self.len_queries), "float32"), "cos_scores_top_k_idx": Array2D((None, self.len_queries), "int32"), "batch_index": Value("int32")},
+ ),
+ )
+
+ def _compute(self, cos_scores_top_k_values, cos_scores_top_k_idx, batch_index):
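+ # Drop the dummy warmup batches (marked with batch_index == -1) before concatenating the real results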
+ for i in range(len(batch_index) - 1, -1, -1):
+ if batch_index[i] == -1:
+ del cos_scores_top_k_values[i]
+ del cos_scores_top_k_idx[i]
+ batch_index = [e for e in batch_index if e != -1]
+ batch_index = np.repeat(batch_index, len(cos_scores_top_k_values[0]))
+ cos_scores_top_k_values = np.concatenate(cos_scores_top_k_values, axis=0)
+ cos_scores_top_k_idx = np.concatenate(cos_scores_top_k_idx, axis=0)
+ return cos_scores_top_k_values, cos_scores_top_k_idx, batch_index[:len(cos_scores_top_k_values)]
+
+ def warmup(self):
+ """
+ Add dummy batch to acquire filelocks for all processes and avoid getting errors
+ """
+ self.add_batch(cos_scores_top_k_values=torch.ones((1, 1, self.len_queries), dtype=torch.float32), cos_scores_top_k_idx=torch.ones((1, 1, self.len_queries), dtype=torch.int32), batch_index=-torch.ones(1, dtype=torch.int32))
+
+#Parent class for any dense model
+class DenseRetrievalParallelExactSearch:
+
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = None, target_devices: List[str] = None, **kwargs):
+ # The model is a class that provides encode_corpus() and encode_queries()
+ self.model = model
+ self.batch_size = batch_size
+ if target_devices is None:
+ if torch.cuda.is_available():
+ target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
+ else:
+ logger.info("CUDA is not available. Start 4 CPU worker")
+ target_devices = ['cpu']*1 # 4
+ self.target_devices = target_devices # PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used, or 4 CPU processes
+ self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
+ self.score_function_desc = {'cos_sim': "Cosine Similarity", 'dot': "Dot Product"}
+ self.corpus_chunk_size = corpus_chunk_size
+ self.show_progress_bar = True #TODO: implement no progress bar if false
+ self.convert_to_tensor = True
+ self.results = {}
+
+ self.query_embeddings = {}
+ self.top_k = None
+ self.score_function = None
+ self.sort_corpus = True
+ self.experiment_id = "exact_search_multi_gpu" # f"test_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
+
+ def search(self,
+ corpus: Dataset,
+ queries: Dataset,
+ top_k: List[int],
+ score_function: str,
+ **kwargs) -> Dict[str, Dict[str, float]]:
+ #Create embeddings for all queries using model.encode_queries()
+ #Runs semantic search against the corpus embeddings
+ #Returns a ranked list with the corpus ids
+ if score_function not in self.score_functions:
+ raise ValueError("score function: {} must be either (cos_sim) for cosine similarity or (dot) for dot product".format(score_function))
+ logger.info("Scoring Function: {} ({})".format(self.score_function_desc[score_function], score_function))
+
+ if importlib.util.find_spec("evaluate") is None:
+ raise ImportError("evaluate library not available. Please do ``pip install evaluate`` library with Python>=3.7 (not available with Python 3.6) to use distributed and multigpu evaluation.")
+
+ self.corpus_chunk_size = min(math.ceil(len(corpus) / len(self.target_devices) / 10), 5000) if self.corpus_chunk_size is None else self.corpus_chunk_size
+ self.corpus_chunk_size = min(self.corpus_chunk_size, len(corpus)-1) # to avoid getting error in metric.compute()
+
+ if self.sort_corpus:
+ logger.info("Sorting Corpus by document length (Longest first)...")
+ corpus = corpus.map(lambda x: {'len': len(x.get("title", "") + x.get("text", ""))}, num_proc=4)
+ corpus = corpus.sort('len', reverse=True)
+
+ # Initiate dataloader
+ queries_dl = DataLoader(queries, batch_size=self.corpus_chunk_size)
+ corpus_dl = DataLoader(corpus, batch_size=self.corpus_chunk_size)
+
+ # Encode queries
+ logger.info("Encoding Queries in batches...")
+ query_embeddings = []
+ for step, queries_batch in enumerate(queries_dl):
+ with torch.no_grad():
+ q_embeds = self.model.encode_queries(
+ queries_batch['text'], batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)
+ query_embeddings.append(q_embeds)
+ query_embeddings = torch.cat(query_embeddings, dim=0)
+
+ # copy the query embeddings to all target devices
+ self.query_embeddings = query_embeddings
+ self.top_k = top_k
+ self.score_function = score_function
+
+ # Start the multi-process pool on all target devices
+ SentenceTransformer._encode_multi_process_worker = self._encode_multi_process_worker
+ pool = self.model.start_multi_process_pool(self.target_devices)
+
+ logger.info("Encoding Corpus in batches... Warning: This might take a while!")
+ start_time = time.time()
+ for chunk_id, corpus_batch in tqdm(enumerate(corpus_dl), total=len(corpus) // self.corpus_chunk_size):
+ with torch.no_grad():
+ self.model.encode_corpus_parallel(
+ corpus_batch, pool=pool, batch_size=self.batch_size, chunk_id=chunk_id)
+
+ # Stop the processes in the pool and free memory
+ self.model.stop_multi_process_pool(pool)
+
+ end_time = time.time()
+ logger.info("Encoded all batches in {:.2f} seconds".format(end_time - start_time))
+
+ # Gather all results
+ DummyMetric.len_queries = len(queries)
+ metric = DummyMetric(experiment_id=self.experiment_id, num_process=len(self.target_devices), process_id=0)
+ metric.filelock = FileLock(os.path.join(metric.data_dir, f"{metric.experiment_id}-{metric.num_process}-{metric.process_id}.arrow.lock"))
+ metric.cache_file_name = os.path.join(metric.data_dir, f"{metric.experiment_id}-{metric.num_process}-{metric.process_id}.arrow")
+
+ cos_scores_top_k_values, cos_scores_top_k_idx, chunk_ids = metric.compute()
+ cos_scores_top_k_idx = (cos_scores_top_k_idx.T + chunk_ids * self.corpus_chunk_size).T
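+ # Each worker's top-k indices are local to its corpus chunk; adding chunk_id * corpus_chunk_size
+ # converts them to global corpus positions (e.g. corpus_chunk_size=5000, chunk_id=2, local idx 10 -> 10010).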
+
+ # sort similar docs for each query by cosine similarity and keep only top_k
+ sorted_idx = np.argsort(cos_scores_top_k_values, axis=0)[::-1]
+ sorted_idx = sorted_idx[:self.top_k+1]
+ cos_scores_top_k_values = np.take_along_axis(cos_scores_top_k_values, sorted_idx, axis=0)
+ cos_scores_top_k_idx = np.take_along_axis(cos_scores_top_k_idx, sorted_idx, axis=0)
+
+ logger.info("Formatting results...")
+ # Load corpus ids in memory
+ query_ids = queries['id']
+ corpus_ids = corpus['id']
+ self.results = {qid: {} for qid in query_ids}
+ for query_itr in tqdm(range(len(query_embeddings))):
+ query_id = query_ids[query_itr]
+ for i in range(len(cos_scores_top_k_values)):
+ sub_corpus_id = cos_scores_top_k_idx[i][query_itr]
+ score = cos_scores_top_k_values[i][query_itr].item() # convert np.float to float
+ corpus_id = corpus_ids[sub_corpus_id]
+ if corpus_id != query_id:
+ self.results[query_id][corpus_id] = score
+ return self.results
+
+ def _encode_multi_process_worker(self, process_id, device, model, input_queue, results_queue):
+ """
+ (taken from UKPLab/sentence-transformers/sentence_transformers/SentenceTransformer.py)
+ Internal working process to encode sentences in multi-process setup.
+ Note: Added distributed similarity computing and finding top k similar docs.
+ """
+ DummyMetric.len_queries = len(self.query_embeddings)
+ metric = DummyMetric(experiment_id=self.experiment_id, num_process=len(self.target_devices), process_id=process_id)
+ metric.warmup()
+ with torch.no_grad():
+ while True:
+ try:
+ id, batch_size, sentences = input_queue.get()
+ corpus_embeds = model.encode(
+ sentences, device=device, show_progress_bar=False, convert_to_tensor=True, batch_size=batch_size
+ ).detach()
+
+ cos_scores = self.score_functions[self.score_function](self.query_embeddings.to(corpus_embeds.device), corpus_embeds).detach()
+ cos_scores[torch.isnan(cos_scores)] = -1
+
+ #Get top-k values
+ cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(self.top_k+1, len(cos_scores[1])), dim=1, largest=True, sorted=False)
+ cos_scores_top_k_values = cos_scores_top_k_values.T.unsqueeze(0).detach()
+ cos_scores_top_k_idx = cos_scores_top_k_idx.T.unsqueeze(0).detach()
+
+ # Store results in an Apache Arrow table
+ metric.add_batch(cos_scores_top_k_values=cos_scores_top_k_values, cos_scores_top_k_idx=cos_scores_top_k_idx, batch_index=[id]*len(cos_scores_top_k_values))
+
+ # Alarm that process finished processing a batch
+ results_queue.put(None)
+ except queue.Empty:
+ break
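+
+# Usage sketch (illustrative only; "model" is assumed to expose encode_queries() and
+# encode_corpus_parallel(), e.g. a SentenceBERT-style wrapper, and corpus_ds / queries_ds
+# are HuggingFace Datasets with "id" and "text" columns):
+#
+#   retriever = DenseRetrievalParallelExactSearch(model, batch_size=128)
+#   results = retriever.search(corpus_ds, queries_ds, top_k=100, score_function="cos_sim")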
diff --git a/beir/retrieval/search/dense/faiss_index.py b/beir/retrieval/search/dense/faiss_index.py
new file mode 100644
index 0000000..56b8a04
--- /dev/null
+++ b/beir/retrieval/search/dense/faiss_index.py
@@ -0,0 +1,174 @@
+from .util import normalize
+from typing import List, Optional, Tuple, Union
+from tqdm.autonotebook import trange
+import numpy as np
+
+import faiss
+import logging
+import time
+
+logger = logging.getLogger(__name__)
+
+
+class FaissIndex:
+ def __init__(self, index: faiss.Index, passage_ids: List[int] = None):
+ self.index = index
+ self._passage_ids = None
+ if passage_ids is not None:
+ self._passage_ids = np.array(passage_ids, dtype=np.int64)
+
+ def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ start_time = time.time()
+ scores_arr, ids_arr = self.index.search(query_embeddings, k)
+ if self._passage_ids is not None:
+ ids_arr = self._passage_ids[ids_arr.reshape(-1)].reshape(query_embeddings.shape[0], -1)
+ logger.info("Total search time: %.3f", time.time() - start_time)
+ return scores_arr, ids_arr
+
+ def save(self, fname: str):
+ faiss.write_index(self.index, fname)
+
+ @classmethod
+ def build(
+ cls,
+ passage_ids: List[int],
+ passage_embeddings: np.ndarray,
+ index: Optional[faiss.Index] = None,
+ buffer_size: int = 50000,
+ ):
+ if index is None:
+ index = faiss.IndexFlatIP(passage_embeddings.shape[1])
+ for start in trange(0, len(passage_ids), buffer_size):
+ index.add(passage_embeddings[start : start + buffer_size])
+
+ return cls(index, passage_ids)
+
+ def to_gpu(self):
+ if faiss.get_num_gpus() == 1:
+ res = faiss.StandardGpuResources()
+ self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
+ else:
+ cloner_options = faiss.GpuMultipleClonerOptions()
+ cloner_options.shard = True
+ self.index = faiss.index_cpu_to_all_gpus(self.index, co=cloner_options)
+
+ return self.index
+
+
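+# Usage sketch (illustrative; ids and embeddings below are placeholders):
+#   embs = np.random.rand(3, 768).astype(np.float32)          # (num_passages, dim)
+#   index = FaissIndex.build(passage_ids=[0, 1, 2], passage_embeddings=embs)
+#   scores, ids = index.search(query_embeddings, k=2)
+#   index.save("my-index.flat.faiss")
+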
+class FaissHNSWIndex(FaissIndex):
+ def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ query_embeddings = np.hstack((query_embeddings, np.zeros((query_embeddings.shape[0], 1), dtype=np.float32)))
+ return super().search(query_embeddings, k)
+
+ def save(self, output_path: str):
+ super().save(output_path)
+
+ @classmethod
+ def build(
+ cls,
+ passage_ids: List[int],
+ passage_embeddings: np.ndarray,
+ index: Optional[faiss.Index] = None,
+ buffer_size: int = 50000,
+ ):
+ sq_norms = (passage_embeddings ** 2).sum(1)
+ max_sq_norm = float(sq_norms.max())
+ aux_dims = np.sqrt(max_sq_norm - sq_norms)
+ passage_embeddings = np.hstack((passage_embeddings, aux_dims.reshape(-1, 1)))
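+ # The appended column reduces maximum inner product search to an L2 problem that HNSW handles:
+ # passages are stored as [x, sqrt(max_norm^2 - ||x||^2)] and queries are padded with a zero
+ # (see FaissHNSWIndex.search above), so L2 nearest neighbours match dot-product ranking.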
+ return super().build(passage_ids, passage_embeddings, index, buffer_size)
+
+class FaissTrainIndex(FaissIndex):
+ def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ return super().search(query_embeddings, k)
+
+ def save(self, output_path: str):
+ super().save(output_path)
+
+ @classmethod
+ def build(
+ cls,
+ passage_ids: List[int],
+ passage_embeddings: np.ndarray,
+ index: Optional[faiss.Index] = None,
+ buffer_size: int = 50000,
+ ):
+ index.train(passage_embeddings)
+ return super().build(passage_ids, passage_embeddings, index, buffer_size)
+
+class FaissBinaryIndex(FaissIndex):
+ def __init__(self, index: faiss.Index, passage_ids: List[int] = None, passage_embeddings: np.ndarray = None):
+ self.index = index
+ self._passage_ids = None
+ if passage_ids is not None:
+ self._passage_ids = np.array(passage_ids, dtype=np.int64)
+
+ self._passage_embeddings = None
+ if passage_embeddings is not None:
+ self._passage_embeddings = passage_embeddings
+
+ def search(self, query_embeddings: np.ndarray, k: int, binary_k: int = 1000, rerank: bool = True,
+ score_function: str = "dot", threshold: Union[int, np.ndarray] = 0, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ start_time = time.time()
+ num_queries = query_embeddings.shape[0]
+ bin_query_embeddings = np.packbits(np.where(query_embeddings > threshold, 1, 0)).reshape(num_queries, -1)
+
+ if not rerank:
+ scores_arr, ids_arr = self.index.search(bin_query_embeddings, k)
+ if self._passage_ids is not None:
+ ids_arr = self._passage_ids[ids_arr.reshape(-1)].reshape(num_queries, -1)
+ return scores_arr, ids_arr
+
+ if self._passage_ids is not None:
+ _, ids_arr = self.index.search(bin_query_embeddings, binary_k)
+ logger.info("Initial search time: %.3f", time.time() - start_time)
+ passage_embeddings = np.unpackbits(self._passage_embeddings[ids_arr.reshape(-1)])
+ passage_embeddings = passage_embeddings.reshape(num_queries, binary_k, -1).astype(np.float32)
+ else:
+ raw_index = self.index.index
+ _, ids_arr = raw_index.search(bin_query_embeddings, binary_k)
+ logger.info("Initial search time: %.3f", time.time() - start_time)
+ passage_embeddings = np.vstack(
+ [np.unpackbits(raw_index.reconstruct(int(id_))) for id_ in ids_arr.reshape(-1)]
+ )
+ passage_embeddings = passage_embeddings.reshape(
+ query_embeddings.shape[0], binary_k, query_embeddings.shape[1]
+ )
+ passage_embeddings = passage_embeddings.astype(np.float32)
+
+ passage_embeddings = passage_embeddings * 2 - 1
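+ # Unpacked bits are in {0, 1}; mapping them to {-1, +1} lets the dot product below approximate
+ # similarities of the original (pre-binarisation) embeddings for re-ranking.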
+
+ if score_function == "cos_sim":
+ passage_embeddings, query_embeddings = normalize(passage_embeddings), normalize(query_embeddings)
+
+ scores_arr = np.einsum("ijk,ik->ij", passage_embeddings, query_embeddings)
+ sorted_indices = np.argsort(-scores_arr, axis=1)
+
+ ids_arr = ids_arr[np.arange(num_queries)[:, None], sorted_indices]
+ if self._passage_ids is not None:
+ ids_arr = self._passage_ids[ids_arr.reshape(-1)].reshape(num_queries, -1)
+ else:
+ ids_arr = np.array([self.index.id_map.at(int(id_)) for id_ in ids_arr.reshape(-1)], dtype=np.int64)
+ ids_arr = ids_arr.reshape(num_queries, -1)
+
+ scores_arr = scores_arr[np.arange(num_queries)[:, None], sorted_indices]
+ logger.info("Total search time: %.3f", time.time() - start_time)
+
+ return scores_arr[:, :k], ids_arr[:, :k]
+
+ def save(self, fname: str):
+ faiss.write_index_binary(self.index, fname)
+
+ @classmethod
+ def build(
+ cls,
+ passage_ids: List[int],
+ passage_embeddings: np.ndarray,
+ index: Optional[faiss.Index] = None,
+ buffer_size: int = 50000,
+ ):
+ if index is None:
+ index = faiss.IndexBinaryFlat(passage_embeddings.shape[1] * 8)
+ for start in trange(0, len(passage_ids), buffer_size):
+ index.add(passage_embeddings[start : start + buffer_size])
+
+ return cls(index, passage_ids, passage_embeddings)
\ No newline at end of file
diff --git a/beir/retrieval/search/dense/faiss_search.py b/beir/retrieval/search/dense/faiss_search.py
new file mode 100644
index 0000000..6ad7bbf
--- /dev/null
+++ b/beir/retrieval/search/dense/faiss_search.py
@@ -0,0 +1,389 @@
+from .util import cos_sim, dot_score, normalize, save_dict_to_tsv, load_tsv_to_dict
+from .faiss_index import FaissBinaryIndex, FaissTrainIndex, FaissHNSWIndex, FaissIndex
+import logging
+import sys
+import torch
+import faiss
+import numpy as np
+import os
+from typing import Dict, List
+from tqdm.autonotebook import tqdm
+
+logger = logging.getLogger(__name__)
+
+# Parent class for any Faiss-based dense search
+class DenseRetrievalFaissSearch:
+
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, **kwargs):
+ self.model = model
+ self.batch_size = batch_size
+ self.corpus_chunk_size = corpus_chunk_size
+ self.score_functions = ['cos_sim','dot']
+ self.mapping_tsv_keys = ["beir-docid", "faiss-docid"]
+ self.faiss_index = None
+ self.dim_size = 0
+ self.results = {}
+ self.mapping = {}
+ self.rev_mapping = {}
+
+ def _create_mapping_ids(self, corpus_ids):
+ if not all(isinstance(doc_id, int) for doc_id in corpus_ids):
+ for idx in range(len(corpus_ids)):
+ self.mapping[corpus_ids[idx]] = idx
+ self.rev_mapping[idx] = corpus_ids[idx]
+
+ def _load(self, input_dir: str, prefix: str, ext: str):
+
+ # Load ID mappings from file
+ input_mappings_path = os.path.join(input_dir, "{}.{}.tsv".format(prefix, ext))
+ logger.info("Loading Faiss ID-mappings from path: {}".format(input_mappings_path))
+ self.mapping = load_tsv_to_dict(input_mappings_path, header=True)
+ self.rev_mapping = {v: k for k, v in self.mapping.items()}
+ passage_ids = sorted(list(self.rev_mapping))
+
+ # Load Faiss Index from disk
+ input_faiss_path = os.path.join(input_dir, "{}.{}.faiss".format(prefix, ext))
+ logger.info("Loading Faiss Index from path: {}".format(input_faiss_path))
+
+ return input_faiss_path, passage_ids
+
+ def save(self, output_dir: str, prefix: str, ext: str):
+
+ # Save BEIR -> Faiss ids mappings
+ save_mappings_path = os.path.join(output_dir, "{}.{}.tsv".format(prefix, ext))
+ logger.info("Saving Faiss ID-mappings to path: {}".format(save_mappings_path))
+ save_dict_to_tsv(self.mapping, save_mappings_path, keys=self.mapping_tsv_keys)
+
+ # Save Faiss Index to disk
+ save_faiss_path = os.path.join(output_dir, "{}.{}.faiss".format(prefix, ext))
+ logger.info("Saving Faiss Index to path: {}".format(save_faiss_path))
+ self.faiss_index.save(save_faiss_path)
+ logger.info("Index size: {:.2f}MB".format(os.path.getsize(save_faiss_path)*0.000001))
+
+ def _index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None):
+
+ logger.info("Sorting Corpus by document length (Longest first)...")
+ corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
+ self._create_mapping_ids(corpus_ids)
+ corpus = [corpus[cid] for cid in corpus_ids]
+ normalize_embeddings = True if score_function == "cos_sim" else False
+
+ logger.info("Encoding Corpus in batches... Warning: This might take a while!")
+
+ itr = range(0, len(corpus), self.corpus_chunk_size)
+
+ for batch_num, corpus_start_idx in enumerate(itr):
+ logger.info("Encoding Batch {}/{}...".format(batch_num+1, len(itr)))
+ corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))
+
+ #Encode chunk of corpus
+ sub_corpus_embeddings = self.model.encode_corpus(
+ corpus[corpus_start_idx:corpus_end_idx],
+ batch_size=self.batch_size,
+ show_progress_bar=True,
+ normalize_embeddings=normalize_embeddings)
+
+ if not batch_num:
+ corpus_embeddings = sub_corpus_embeddings
+ else:
+ corpus_embeddings = np.vstack([corpus_embeddings, sub_corpus_embeddings])
+
+ #Index chunk of corpus into faiss index
+ logger.info("Indexing Passages into Faiss...")
+
+ faiss_ids = [self.mapping.get(corpus_id) for corpus_id in corpus_ids]
+ self.dim_size = corpus_embeddings.shape[1]
+
+ del sub_corpus_embeddings
+
+ return faiss_ids, corpus_embeddings
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ assert score_function in self.score_functions
+
+ if not self.faiss_index: self.index(corpus, score_function)
+
+ logger.info("Encoding Queries...")
+ query_ids = list(queries.keys())
+ queries = [queries[qid] for qid in queries]
+ query_embeddings = self.model.encode_queries(
+ queries, show_progress_bar=True, batch_size=self.batch_size)
+
+ faiss_scores, faiss_doc_ids = self.faiss_index.search(query_embeddings, top_k, **kwargs)
+
+ for idx in range(len(query_ids)):
+ scores = [float(score) for score in faiss_scores[idx]]
+ if len(self.rev_mapping) != 0:
+ doc_ids = [self.rev_mapping[doc_id] for doc_id in faiss_doc_ids[idx]]
+ else:
+ doc_ids = [str(doc_id) for doc_id in faiss_doc_ids[idx]]
+ self.results[query_ids[idx]] = dict(zip(doc_ids, scores))
+
+ return self.results
+
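+# Usage sketch for the subclasses below (illustrative; "model" is assumed to provide
+# encode_corpus() and encode_queries()):
+#   retriever = FlatIPFaissSearch(model, batch_size=128)
+#   results = retriever.search(corpus, queries, top_k=100, score_function="cos_sim")
+#   retriever.save(output_dir="faiss-index", prefix="my-index", ext="flat")
+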
+
+class BinaryFaissSearch(DenseRetrievalFaissSearch):
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "bin"):
+ passage_embeddings = []
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index_binary(input_faiss_path)
+ logger.info("Reconstructing passage_embeddings back in Memory from Index...")
+ for idx in tqdm(range(0, len(passage_ids)), total=len(passage_ids)):
+ passage_embeddings.append(base_index.reconstruct(idx))
+ passage_embeddings = np.vstack(passage_embeddings)
+ self.faiss_index = FaissBinaryIndex(base_index, passage_ids, passage_embeddings)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function)
+ logger.info("Using Binary Hashing in Flat Mode!")
+ logger.info("Output Dimension: {}".format(self.dim_size))
+ base_index = faiss.IndexBinaryFlat(self.dim_size * 8)
+ self.faiss_index = FaissBinaryIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "bin"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "binary_faiss_index"
+
+
+class PQFaissSearch(DenseRetrievalFaissSearch):
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, num_of_centroids: int = 96,
+ code_size: int = 8, similarity_metric=faiss.METRIC_INNER_PRODUCT, use_rotation: bool = False, **kwargs):
+ super(PQFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
+ self.num_of_centroids = num_of_centroids
+ self.code_size = code_size
+ self.similarity_metric = similarity_metric
+ self.use_rotation = use_rotation
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "pq"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissTrainIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+
+ logger.info("Using Product Quantization (PQ) in Flat mode!")
+ logger.info("Parameters Used: num_of_centroids: {} ".format(self.num_of_centroids))
+ logger.info("Parameters Used: code_size: {}".format(self.code_size))
+
+ base_index = faiss.IndexPQ(self.dim_size, self.num_of_centroids, self.code_size, self.similarity_metric)
+
+ if self.use_rotation:
+ logger.info("Rotating data before encoding it with a product quantizer...")
+ logger.info("Creating OPQ Matrix...")
+ opq_matrix = faiss.OPQMatrix(self.dim_size, self.code_size)
+ base_index = faiss.IndexPreTransform(opq_matrix, base_index)
+
+ self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "pq"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "pq_faiss_index"
+
+
+class HNSWFaissSearch(DenseRetrievalFaissSearch):
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, hnsw_store_n: int = 512,
+ hnsw_ef_search: int = 128, hnsw_ef_construction: int = 200, similarity_metric=faiss.METRIC_INNER_PRODUCT, **kwargs):
+ super(HNSWFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
+ self.hnsw_store_n = hnsw_store_n
+ self.hnsw_ef_search = hnsw_ef_search
+ self.hnsw_ef_construction = hnsw_ef_construction
+ self.similarity_metric = similarity_metric
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "hnsw"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissHNSWIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+
+ logger.info("Using Approximate Nearest Neighbours (HNSW) in Flat Mode!")
+ logger.info("Parameters Required: hnsw_store_n: {}".format(self.hnsw_store_n))
+ logger.info("Parameters Required: hnsw_ef_search: {}".format(self.hnsw_ef_search))
+ logger.info("Parameters Required: hnsw_ef_construction: {}".format(self.hnsw_ef_construction))
+
+ base_index = faiss.IndexHNSWFlat(self.dim_size + 1, self.hnsw_store_n, self.similarity_metric)
+ base_index.hnsw.efSearch = self.hnsw_ef_search
+ base_index.hnsw.efConstruction = self.hnsw_ef_construction
+ self.faiss_index = FaissHNSWIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "hnsw"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "hnsw_faiss_index"
+
+class HNSWSQFaissSearch(DenseRetrievalFaissSearch):
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, hnsw_store_n: int = 128,
+ hnsw_ef_search: int = 128, hnsw_ef_construction: int = 200, similarity_metric=faiss.METRIC_INNER_PRODUCT,
+ quantizer_type: str = "QT_8bit", **kwargs):
+ super(HNSWSQFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
+ self.hnsw_store_n = hnsw_store_n
+ self.hnsw_ef_search = hnsw_ef_search
+ self.hnsw_ef_construction = hnsw_ef_construction
+ self.similarity_metric = similarity_metric
+ self.qname = quantizer_type
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "hnsw-sq"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissTrainIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+
+ logger.info("Using Approximate Nearest Neighbours (HNSW) in SQ Mode!")
+ logger.info("Parameters Required: hnsw_store_n: {}".format(self.hnsw_store_n))
+ logger.info("Parameters Required: hnsw_ef_search: {}".format(self.hnsw_ef_search))
+ logger.info("Parameters Required: hnsw_ef_construction: {}".format(self.hnsw_ef_construction))
+ logger.info("Parameters Required: quantizer_type: {}".format(self.qname))
+
+ qtype = getattr(faiss.ScalarQuantizer, self.qname)
+ base_index = faiss.IndexHNSWSQ(self.dim_size + 1, qtype, self.hnsw_store_n)
+ base_index.hnsw.efSearch = self.hnsw_ef_search
+ base_index.hnsw.efConstruction = self.hnsw_ef_construction
+ self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "hnsw-sq"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "hnswsq_faiss_index"
+
+class FlatIPFaissSearch(DenseRetrievalFaissSearch):
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "flat"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+ base_index = faiss.IndexFlatIP(self.dim_size)
+ self.faiss_index = FaissIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "flat"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "flat_faiss_index"
+
+class PCAFaissSearch(DenseRetrievalFaissSearch):
+ def __init__(self, model, base_index: faiss.Index, output_dimension: int, batch_size: int = 128,
+ corpus_chunk_size: int = 50000, **kwargs):
+ super(PCAFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
+ self.base_index = base_index
+ self.output_dim = output_dimension
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "pca"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissTrainIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+ logger.info("Creating PCA Matrix...")
+ logger.info("Input Dimension: {}, Output Dimension: {}".format(self.dim_size, self.output_dim))
+ pca_matrix = faiss.PCAMatrix(self.dim_size, self.output_dim, 0, True)
+ final_index = faiss.IndexPreTransform(pca_matrix, self.base_index)
+ self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, final_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "pca"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "pca_faiss_index"
+
+class SQFaissSearch(DenseRetrievalFaissSearch):
+ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000,
+ similarity_metric=faiss.METRIC_INNER_PRODUCT, quantizer_type: str = "QT_fp16", **kwargs):
+ super(SQFaissSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs)
+ self.similarity_metric = similarity_metric
+ self.qname = quantizer_type
+
+ def load(self, input_dir: str, prefix: str = "my-index", ext: str = "sq"):
+ input_faiss_path, passage_ids = super()._load(input_dir, prefix, ext)
+ base_index = faiss.read_index(input_faiss_path)
+ self.faiss_index = FaissTrainIndex(base_index, passage_ids)
+
+ def index(self, corpus: Dict[str, Dict[str, str]], score_function: str = None, **kwargs):
+ faiss_ids, corpus_embeddings = super()._index(corpus, score_function, **kwargs)
+
+ logger.info("Using Scalar Quantizer in Flat Mode!")
+ logger.info("Parameters Used: quantizer_type: {}".format(self.qname))
+
+ qtype = getattr(faiss.ScalarQuantizer, self.qname)
+ base_index = faiss.IndexScalarQuantizer(self.dim_size, qtype, self.similarity_metric)
+ self.faiss_index = FaissTrainIndex.build(faiss_ids, corpus_embeddings, base_index)
+
+ def save(self, output_dir: str, prefix: str = "my-index", ext: str = "sq"):
+ super().save(output_dir, prefix, ext)
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ return super().search(corpus, queries, top_k, score_function, **kwargs)
+
+ def get_index_name(self):
+ return "sq_faiss_index"
\ No newline at end of file
diff --git a/beir/retrieval/search/dense/util.py b/beir/retrieval/search/dense/util.py
new file mode 100644
index 0000000..1f6d03b
--- /dev/null
+++ b/beir/retrieval/search/dense/util.py
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+import csv
+
+def cos_sim(a: torch.Tensor, b: torch.Tensor):
+ """
+ Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
+ :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
+ """
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
+ return torch.mm(a_norm, b_norm.transpose(0, 1)) #TODO: this keeps allocating GPU memory
+
+def dot_score(a: torch.Tensor, b: torch.Tensor):
+ """
+ Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
+ :return: Matrix with res[i][j] = dot_prod(a[i], b[j])
+ """
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ return torch.mm(a, b.transpose(0, 1))
+
+def normalize(a: np.ndarray) -> np.ndarray:
+ return a/np.linalg.norm(a, ord=2, axis=1, keepdims=True)
+
+def save_dict_to_tsv(_dict, output_path, keys=[]):
+
+ with open(output_path, 'w') as fIn:
+ writer = csv.writer(fIn, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
+ if keys: writer.writerow(keys)
+ for key, value in _dict.items():
+ writer.writerow([key, value])
+
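+# Expected TSV layout for load_tsv_to_dict (illustrative; tab-separated):
+#   beir-docid    faiss-docid     <- header row, skipped when header=True
+#   doc123        0
+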
+def load_tsv_to_dict(input_path, header=True):
+
+ mappings = {}
+ reader = csv.reader(open(input_path, encoding="utf-8"),
+ delimiter="\t", quoting=csv.QUOTE_MINIMAL)
+ if header: next(reader)
+ for row in reader:
+ mappings[row[0]] = int(row[1])
+
+ return mappings
\ No newline at end of file
diff --git a/beir/retrieval/search/lexical/__init__.py b/beir/retrieval/search/lexical/__init__.py
new file mode 100644
index 0000000..efa5acb
--- /dev/null
+++ b/beir/retrieval/search/lexical/__init__.py
@@ -0,0 +1 @@
+from .bm25_search import BM25Search
\ No newline at end of file
diff --git a/beir/retrieval/search/lexical/bm25_search.py b/beir/retrieval/search/lexical/bm25_search.py
new file mode 100644
index 0000000..b5f105a
--- /dev/null
+++ b/beir/retrieval/search/lexical/bm25_search.py
@@ -0,0 +1,77 @@
+from .elastic_search import ElasticSearch
+import tqdm
+import time
+from typing import List, Dict, Union
+
+def sleep(seconds):
+ if seconds: time.sleep(seconds)
+
+class BM25Search:
+ def __init__(self, index_name: str, hostname: str = "localhost", keys: Dict[str, str] = {"title": "title", "body": "txt"}, language: str = "english",
+ batch_size: int = 128, timeout: int = 100, retry_on_timeout: bool = True, maxsize: int = 24, number_of_shards: Union[int, str] = "default",
+ initialize: bool = True, sleep_for: int = 2):
+ self.results = {}
+ self.batch_size = batch_size
+ self.initialize = initialize
+ self.sleep_for = sleep_for
+ self.config = {
+ "hostname": hostname,
+ "index_name": index_name,
+ "keys": keys,
+ "timeout": timeout,
+ "retry_on_timeout": retry_on_timeout,
+ "maxsize": maxsize,
+ "number_of_shards": number_of_shards,
+ "language": language
+ }
+ self.es = ElasticSearch(self.config)
+ if self.initialize:
+ self.initialise()
+
+ def initialise(self):
+ self.es.delete_index()
+ sleep(self.sleep_for)
+ self.es.create_index()
+
+ def search(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], top_k: int, *args, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ # Index the corpus within Elasticsearch
+ # (skipped if self.initialize is False, i.e. the corpus has already been indexed)
+ if self.initialize:
+ self.index(corpus)
+ # Sleep for a few seconds so that Elasticsearch indexes the docs properly
+ sleep(self.sleep_for)
+
+ #retrieve results from BM25
+ query_ids = list(queries.keys())
+ queries = [queries[qid] for qid in query_ids]
+
+ for start_idx in tqdm.trange(0, len(queries), self.batch_size, desc='que'):
+ query_ids_batch = query_ids[start_idx:start_idx+self.batch_size]
+ results = self.es.lexical_multisearch(
+ texts=queries[start_idx:start_idx+self.batch_size],
+ top_hits=top_k + 1) # Add 1 extra if query is present with documents
+
+ for (query_id, hit) in zip(query_ids_batch, results):
+ scores = {}
+ for corpus_id, score in hit['hits']:
+ if corpus_id != query_id: # do not return the query itself in the results
+ scores[corpus_id] = score
+ self.results[query_id] = scores
+
+ return self.results
+
+
+ def index(self, corpus: Dict[str, Dict[str, str]]):
+ progress = tqdm.tqdm(unit="docs", total=len(corpus))
+ # dictionary structure = {_id: {title_key: title, text_key: text}}
+ dictionary = {idx: {
+ self.config["keys"]["title"]: corpus[idx].get("title", None),
+ self.config["keys"]["body"]: corpus[idx].get("text", None)
+ } for idx in list(corpus.keys())
+ }
+ self.es.bulk_add_to_index(
+ generate_actions=self.es.generate_actions(
+ dictionary=dictionary, update=False),
+ progress=progress
+ )
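+
+# Usage sketch (illustrative; assumes an Elasticsearch instance reachable at the given hostname):
+#   retriever = BM25Search(index_name="scifact", hostname="localhost:9200", initialize=True)
+#   results = retriever.search(corpus, queries, top_k=100)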
diff --git a/beir/retrieval/search/lexical/elastic_search.py b/beir/retrieval/search/lexical/elastic_search.py
new file mode 100644
index 0000000..843c6c3
--- /dev/null
+++ b/beir/retrieval/search/lexical/elastic_search.py
@@ -0,0 +1,247 @@
+from elasticsearch import Elasticsearch
+from elasticsearch.helpers import streaming_bulk
+from typing import Dict, List, Tuple
+import logging
+import tqdm
+import sys
+
+tracer = logging.getLogger('elasticsearch')
+tracer.setLevel(logging.CRITICAL) # suppressing INFO messages for elastic-search
+
+class ElasticSearch(object):
+
+ def __init__(self, es_credentials: Dict[str, object]):
+
+ logging.info("Activating Elasticsearch....")
+ logging.info("Elastic Search Credentials: %s", es_credentials)
+ self.index_name = es_credentials["index_name"]
+ self.check_index_name()
+
+ # https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-lang-analyzer.html
+ self.languages = ["arabic", "armenian", "basque", "bengali", "brazilian", "bulgarian", "catalan",
+ "cjk", "czech", "danish", "dutch", "english","estonian","finnish","french",
+ "galician", "german", "greek", "hindi", "hungarian", "indonesian", "irish",
+ "italian", "latvian", "lithuanian", "norwegian", "persian", "portuguese",
+ "romanian", "russian", "sorani", "spanish", "swedish", "turkish", "thai"]
+
+ self.language = es_credentials["language"]
+ self.check_language_supported()
+
+ self.text_key = es_credentials["keys"]["body"]
+ self.title_key = es_credentials["keys"]["title"]
+ self.number_of_shards = es_credentials["number_of_shards"]
+
+ self.es = Elasticsearch(
+ [es_credentials["hostname"]],
+ timeout=es_credentials["timeout"],
+ retry_on_timeout=es_credentials["retry_on_timeout"],
+ maxsize=es_credentials["maxsize"])
+
+ def check_language_supported(self):
+ """Check Language Supported in Elasticsearch
+ """
+ if self.language.lower() not in self.languages:
+ raise ValueError("Invalid Language: {}, not supported by Elasticsearch. Languages Supported: \
+ {}".format(self.language, self.languages))
+
+ def check_index_name(self):
+ """Check Elasticsearch Index Name"""
+ # https://stackoverflow.com/questions/41585392/what-are-the-rules-for-index-names-in-elastic-search
+ # Check 1: Must not contain the characters ===> #:\/*?"<>|,
+ for char in '#:\\/*?"<>|,':
+ if char in self.index_name:
+ raise ValueError('Invalid Elasticsearch Index, must not contain the characters ===> #:\\/*?"<>|,')
+
+ # Check 2: Must not start with characters ===> _-+
+ if self.index_name.startswith(("_", "-", "+")):
+ raise ValueError('Invalid Elasticsearch Index, must not start with characters ===> _ or - or +')
+
+ # Check 3: must not be . or ..
+ if self.index_name in [".", ".."]:
+ raise ValueError('Invalid Elasticsearch Index, must not be . or ..')
+
+ # Check 4: must be lowercase
+ if not self.index_name.islower():
+ raise ValueError('Invalid Elasticsearch Index, must be lowercase')
+
+
+ def create_index(self):
+ """Create Elasticsearch Index
+ """
+ logging.info("Creating fresh Elasticsearch-Index named - {}".format(self.index_name))
+
+ try:
+ if self.number_of_shards == "default":
+ mapping = {
+ "mappings" : {
+ "properties" : {
+ self.title_key: {"type": "text", "analyzer": self.language},
+ self.text_key: {"type": "text", "analyzer": self.language}
+ }}}
+ else:
+ mapping = {
+ "settings": {
+ "number_of_shards": self.number_of_shards
+ },
+ "mappings" : {
+ "properties" : {
+ self.title_key: {"type": "text", "analyzer": self.language},
+ self.text_key: {"type": "text", "analyzer": self.language}
+ }}}
+
+ self.es.indices.create(index=self.index_name, body=mapping, ignore=[400]) #400: IndexAlreadyExistsException
+ except Exception as e:
+ logging.error("Unable to create Index in Elastic Search. Reason: {}".format(e))
+
+ def delete_index(self):
+ """Delete Elasticsearch Index"""
+
+ logging.info("Deleting previous Elasticsearch-Index named - {}".format(self.index_name))
+ try:
+ self.es.indices.delete(index=self.index_name, ignore=[400, 404]) # 404: IndexDoesntExistException
+ except Exception as e:
+ logging.error("Unable to create Index in Elastic Search. Reason: {}".format(e))
+
+ def bulk_add_to_index(self, generate_actions, progress):
+ """Bulk indexing to elastic search using generator actions
+
+ Args:
+ generate_actions (generator function): generator function must be provided
+ progress (tqdm.tqdm): tqdm progress_bar
+ """
+ for ok, action in streaming_bulk(
+ client=self.es, index=self.index_name, actions=generate_actions,
+ ):
+ progress.update(1)
+ progress.reset()
+ progress.close()
+
+ def lexical_search(self, text: str, top_hits: int, ids: List[str] = None, skip: int = 0) -> Dict[str, object]:
+ """[summary]
+
+ Args:
+ text (str): query text
+ top_hits (int): top k hits to be retrieved
+ ids (List[str], optional): Filter results for only specific ids. Defaults to None.
+
+ Returns:
+ Dict[str, object]: Hit results
+ """
+ req_body = {"query" : {"multi_match": {
+ "query": text,
+ "type": "best_fields",
+ "fields": [self.text_key, self.title_key],
+ "tie_breaker": 0.5
+ }}}
+
+ if ids: req_body = {"query": {"bool": {
+ "must": req_body["query"],
+ "filter": {"ids": {"values": ids}}
+ }}}
+
+ res = self.es.search(
+ search_type="dfs_query_then_fetch",
+ index = self.index_name,
+ body = req_body,
+ size = skip + top_hits
+ )
+
+ hits = []
+
+ for hit in res["hits"]["hits"][skip:]:
+ hits.append((hit["_id"], hit['_score']))
+
+ return self.hit_template(es_res=res, hits=hits)
+
+
+ def lexical_multisearch(self, texts: List[str], top_hits: int, skip: int = 0) -> Dict[str, object]:
+ """Multiple Query search in Elasticsearch
+
+ Args:
+ texts (List[str]): Multiple query texts
+ top_hits (int): top k hits to be retrieved
+ skip (int, optional): top hits to be skipped. Defaults to 0.
+
+ Returns:
+ Dict[str, object]: Hit results
+ """
+ request = []
+
+ assert skip + top_hits <= 10000, "Elastic-Search Window too large, Max-Size = 10000"
+
+ for text in texts:
+ req_head = {"index" : self.index_name, "search_type": "dfs_query_then_fetch"}
+ req_body = {
+ "_source": False, # No need to return source objects
+ "query": {
+ "multi_match": {
+ "query": text, # matching query with both text and title fields
+ "type": "best_fields",
+ "fields": [self.title_key, self.text_key],
+ "tie_breaker": 0.5
+ }
+ },
+ "size": skip + top_hits, # The same paragraph will occur in results
+ }
+
+ request.extend([req_head, req_body])
+
+ res = self.es.msearch(body = request)
+
+ result = []
+ for resp in res["responses"]:
+ responses = resp["hits"]["hits"][skip:]
+
+ hits = []
+ for hit in responses:
+ hits.append((hit["_id"], hit['_score']))
+
+ result.append(self.hit_template(es_res=resp, hits=hits))
+ return result
+
+
+ def generate_actions(self, dictionary: Dict[str, Dict[str, str]], update: bool = False):
+ """Iterator function for efficient addition to Elasticsearch
+ Ref: https://stackoverflow.com/questions/35182403/bulk-update-with-pythons-elasticsearch
+ """
+ for _id, value in dictionary.items():
+ if not update:
+ doc = {
+ "_id": str(_id),
+ "_op_type": "index",
+ "refresh": "wait_for",
+ self.text_key: value[self.text_key],
+ self.title_key: value[self.title_key],
+ }
+ else:
+ doc = {
+ "_id": str(_id),
+ "_op_type": "update",
+ "refresh": "wait_for",
+ "doc": {
+ self.text_key: value[self.text_key],
+ self.title_key: value[self.title_key],
+ }
+ }
+
+ yield doc
+
+ def hit_template(self, es_res: Dict[str, object], hits: List[Tuple[str, float]]) -> Dict[str, object]:
+ """Hit output results template
+
+ Args:
+ es_res (Dict[str, object]): Elasticsearch response
+ hits (List[Tuple[str, float]]): Hits from Elasticsearch
+
+ Returns:
+ Dict[str, object]: Hit results
+ """
+ result = {
+ 'meta': {
+ 'total': es_res['hits']['total']['value'],
+ 'took': es_res['took'],
+ 'num_hits': len(hits)
+ },
+ 'hits': hits,
+ }
+ return result
\ No newline at end of file
diff --git a/beir/retrieval/search/sparse/__init__.py b/beir/retrieval/search/sparse/__init__.py
new file mode 100644
index 0000000..0c4d3cb
--- /dev/null
+++ b/beir/retrieval/search/sparse/__init__.py
@@ -0,0 +1 @@
+from .sparse_search import SparseSearch
\ No newline at end of file
diff --git a/beir/retrieval/search/sparse/sparse_search.py b/beir/retrieval/search/sparse/sparse_search.py
new file mode 100644
index 0000000..5f08d87
--- /dev/null
+++ b/beir/retrieval/search/sparse/sparse_search.py
@@ -0,0 +1,46 @@
+from tqdm.autonotebook import trange
+from typing import List, Dict, Union, Tuple
+import logging
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+class SparseSearch:
+
+ def __init__(self, model, batch_size: int = 16, **kwargs):
+ self.model = model
+ self.batch_size = batch_size
+ self.sparse_matrix = None
+ self.results = {}
+
+ def search(self,
+ corpus: Dict[str, Dict[str, str]],
+ queries: Dict[str, str],
+ top_k: int,
+ score_function: str,
+ query_weights: bool = False,
+ *args, **kwargs) -> Dict[str, Dict[str, float]]:
+
+ doc_ids = list(corpus.keys())
+ query_ids = list(queries.keys())
+ documents = [corpus[doc_id] for doc_id in doc_ids]
+ logging.info("Computing document embeddings and creating sparse matrix")
+ self.sparse_matrix = self.model.encode_corpus(documents, batch_size=self.batch_size)
+
+ logging.info("Starting to Retrieve...")
+ for start_idx in trange(0, len(queries), desc='query'):
+ qid = query_ids[start_idx]
+ query_tokens = self.model.encode_query(queries[qid])
+
+ if query_weights:
+ # used for uniCOIL, query weights are considered!
+ scores = self.sparse_matrix.dot(query_tokens)
+ else:
+ # used for SPARTA, query weights are not considered (i.e. binary)!
+ scores = np.asarray(self.sparse_matrix[query_tokens, :].sum(axis=0)).squeeze(0)
+
+ top_k_ind = np.argpartition(scores, -top_k)[-top_k:]
+ self.results[qid] = {doc_ids[pid]: float(scores[pid]) for pid in top_k_ind if doc_ids[pid] != qid}
+
+ return self.results
+
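+# Usage sketch (illustrative; "model" is assumed to expose encode_corpus() and encode_query()):
+#   retriever = SparseSearch(model, batch_size=16)
+#   results = retriever.search(corpus, queries, top_k=100, score_function="dot", query_weights=True)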
diff --git a/beir/retrieval/train.py b/beir/retrieval/train.py
new file mode 100644
index 0000000..60af1c2
--- /dev/null
+++ b/beir/retrieval/train.py
@@ -0,0 +1,148 @@
+from sentence_transformers import SentenceTransformer, SentencesDataset, models, datasets
+from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator, InformationRetrievalEvaluator
+from sentence_transformers.readers import InputExample
+from transformers import AdamW
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.optim import Optimizer
+from tqdm.autonotebook import trange
+from typing import Dict, Type, List, Callable, Iterable, Tuple
+import logging
+import time
+import difflib
+import random
+
+logger = logging.getLogger(__name__)
+
+class TrainRetriever:
+
+ def __init__(self, model: Type[SentenceTransformer], batch_size: int = 64):
+ self.model = model
+ self.batch_size = batch_size
+
+ def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
+ qrels: Dict[str, Dict[str, int]]) -> List[Type[InputExample]]:
+
+ query_ids = list(queries.keys())
+ train_samples = []
+
+ for idx, start_idx in enumerate(trange(0, len(query_ids), self.batch_size, desc='Adding Input Examples')):
+ query_ids_batch = query_ids[start_idx:start_idx+self.batch_size]
+ for query_id in query_ids_batch:
+ for corpus_id, score in qrels[query_id].items():
+ if score >= 1: # if score = 0, we don't consider for training
+ try:
+ s1 = queries[query_id]
+ s2 = corpus[corpus_id].get("title") + " " + corpus[corpus_id].get("text")
+ train_samples.append(InputExample(guid=idx, texts=[s1, s2], label=1))
+ except KeyError:
+ logging.error("Error: Key {} not present in corpus!".format(corpus_id))
+
+ logger.info("Loaded {} training pairs.".format(len(train_samples)))
+ return train_samples
+
+ def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type[InputExample]]:
+
+ train_samples = []
+
+ for idx, start_idx in enumerate(trange(0, len(triplets), self.batch_size, desc='Adding Input Examples')):
+ triplets_batch = triplets[start_idx:start_idx+self.batch_size]
+ for triplet in triplets_batch:
+ guid = None
+ train_samples.append(InputExample(guid=guid, texts=triplet))
+
+ logger.info("Loaded {} training pairs.".format(len(train_samples)))
+ return train_samples
+
+ def prepare_train(self, train_dataset: List[Type[InputExample]], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
+
+ if not dataset_present:
+ train_dataset = SentencesDataset(train_dataset, model=self.model)
+
+ train_dataloader = DataLoader(train_dataset, shuffle=shuffle, batch_size=self.batch_size)
+ return train_dataloader
+
+ def prepare_train_triplets(self, train_dataset: List[Type[InputExample]]) -> DataLoader:
+
+ train_dataloader = datasets.NoDuplicatesDataLoader(train_dataset, batch_size=self.batch_size)
+ return train_dataloader
+
+ def load_ir_evaluator(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
+ qrels: Dict[str, Dict[str, int]], max_corpus_size: int = None, name: str = "eval") -> SentenceEvaluator:
+
+ if len(queries) <= 0:
+ raise ValueError("Dev Set Empty!, Cannot evaluate on Dev set.")
+
+ rel_docs = {}
+ corpus_ids = set()
+
+ # need to convert corpus to cid => doc
+ corpus = {idx: corpus[idx].get("title") + " " + corpus[idx].get("text") for idx in corpus}
+
+ # need to convert dev_qrels to qid => Set[cid]
+ for query_id, metadata in qrels.items():
+ rel_docs[query_id] = set()
+ for corpus_id, score in metadata.items():
+ if score >= 1:
+ corpus_ids.add(corpus_id)
+ rel_docs[query_id].add(corpus_id)
+
+ if max_corpus_size:
+ # check if length of corpus_ids > max_corpus_size
+ if len(corpus_ids) > max_corpus_size:
+ raise ValueError("Your maximum corpus size should atleast contain {} corpus ids".format(len(corpus_ids)))
+
+ # Add mandatory corpus documents
+ new_corpus = {idx: corpus[idx] for idx in corpus_ids}
+
+ # Remove mandatory corpus documents from original corpus
+ for corpus_id in corpus_ids:
+ corpus.pop(corpus_id, None)
+
+ # Sample randomly remaining corpus documents
+ for corpus_id in random.sample(list(corpus), max_corpus_size - len(corpus_ids)):
+ new_corpus[corpus_id] = corpus[corpus_id]
+
+ corpus = new_corpus
+
+ logger.info("{} set contains {} documents and {} queries".format(name, len(corpus), len(queries)))
+ return InformationRetrievalEvaluator(queries, corpus, rel_docs, name=name)
+
+ def load_dummy_evaluator(self) -> SentenceEvaluator:
+ return SequentialEvaluator([], main_score_function=lambda x: time.time())
+
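+ # Typical end-to-end flow (illustrative; the loss below is an assumption, taken from
+ # sentence_transformers.losses, not something this class fixes):
+ #   retriever = TrainRetriever(model=SentenceTransformer(model_name), batch_size=64)
+ #   train_samples = retriever.load_train(corpus, queries, qrels)
+ #   train_dataloader = retriever.prepare_train(train_samples)
+ #   train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+ #   retriever.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)
+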
+ def fit(self,
+ train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
+ evaluator: SentenceEvaluator = None,
+ epochs: int = 1,
+ steps_per_epoch = None,
+ scheduler: str = 'WarmupLinear',
+ warmup_steps: int = 10000,
+ optimizer_class: Type[Optimizer] = AdamW,
+ optimizer_params : Dict[str, object]= {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
+ weight_decay: float = 0.01,
+ evaluation_steps: int = 0,
+ output_path: str = None,
+ save_best_model: bool = True,
+ max_grad_norm: float = 1,
+ use_amp: bool = False,
+ callback: Callable[[float, int, int], None] = None,
+ **kwargs):
+
+ # Train the model
+ logger.info("Starting to Train...")
+
+ self.model.fit(train_objectives=train_objectives,
+ evaluator=evaluator,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ warmup_steps=warmup_steps,
+ optimizer_class=optimizer_class,
+ scheduler=scheduler,
+ optimizer_params=optimizer_params,
+ weight_decay=weight_decay,
+ output_path=output_path,
+ evaluation_steps=evaluation_steps,
+ save_best_model=save_best_model,
+ max_grad_norm=max_grad_norm,
+ use_amp=use_amp,
+ callback=callback, **kwargs)
\ No newline at end of file
diff --git a/beir/util.py b/beir/util.py
new file mode 100644
index 0000000..254a2cc
--- /dev/null
+++ b/beir/util.py
@@ -0,0 +1,121 @@
+from typing import Dict
+from tqdm.autonotebook import tqdm
+import csv
+import torch
+import json
+import logging
+import os
+import requests
+import zipfile
+
+logger = logging.getLogger(__name__)
+
+def dot_score(a: torch.Tensor, b: torch.Tensor):
+ """
+ Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
+ :return: Matrix with res[i][j] = dot_prod(a[i], b[j])
+ """
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ return torch.mm(a, b.transpose(0, 1))
+
+def cos_sim(a: torch.Tensor, b: torch.Tensor):
+ """
+ Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
+ :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
+ """
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
+ return torch.mm(a_norm, b_norm.transpose(0, 1))
+
+def download_url(url: str, save_path: str, chunk_size: int = 1024):
+ """Download url with progress bar using tqdm
+ https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+
+ Args:
+ url (str): downloadable url
+ save_path (str): local path to save the downloaded file
+ chunk_size (int, optional): chunking of files. Defaults to 1024.
+ """
+ r = requests.get(url, stream=True)
+ total = int(r.headers.get('Content-Length', 0))
+ with open(save_path, 'wb') as fd, tqdm(
+ desc=save_path,
+ total=total,
+ unit='iB',
+ unit_scale=True,
+ unit_divisor=chunk_size,
+ ) as bar:
+ for data in r.iter_content(chunk_size=chunk_size):
+ size = fd.write(data)
+ bar.update(size)
+
+def unzip(zip_file: str, out_dir: str):
+ zip_ = zipfile.ZipFile(zip_file, "r")
+ zip_.extractall(path=out_dir)
+ zip_.close()
+
+def download_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
+
+ os.makedirs(out_dir, exist_ok=True)
+ dataset = url.split("/")[-1]
+ zip_file = os.path.join(out_dir, dataset)
+
+ if not os.path.isfile(zip_file):
+ logger.info("Downloading {} ...".format(dataset))
+ download_url(url, zip_file, chunk_size)
+
+ if not os.path.isdir(zip_file.replace(".zip", "")):
+ logger.info("Unzipping {} ...".format(dataset))
+ unzip(zip_file, out_dir)
+
+ return os.path.join(out_dir, dataset.replace(".zip", ""))
+
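+# Usage sketch (illustrative; the URL is a placeholder for a hosted BEIR dataset zip):
+#   data_path = download_and_unzip("<dataset-url>.zip", out_dir="datasets")
+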
+def write_to_json(output_file: str, data: Dict[str, str]):
+ with open(output_file, 'w') as fOut:
+ for idx, meta in data.items():
+ if type(meta) == str:
+ json.dump({
+ "_id": idx,
+ "text": meta,
+ "metadata": {}
+ }, fOut)
+
+ elif type(meta) == dict:
+ json.dump({
+ "_id": idx,
+ "title": meta.get("title", ""),
+ "text": meta.get("text", ""),
+ "metadata": {}
+ }, fOut)
+ fOut.write('\n')
+
+def write_to_tsv(output_file: str, data: Dict[str, str]):
+ with open(output_file, 'w') as fOut:
+ writer = csv.writer(fOut, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
+ writer.writerow(["query-id", "corpus-id", "score"])
+ for query_id, corpus_dict in data.items():
+ for corpus_id, score in corpus_dict.items():
+ writer.writerow([query_id, corpus_id, score])
diff --git a/examples/beir-pyserini/Dockerfile b/examples/beir-pyserini/Dockerfile
new file mode 100644
index 0000000..3f31c47
--- /dev/null
+++ b/examples/beir-pyserini/Dockerfile
@@ -0,0 +1,26 @@
+FROM python:3.6-slim
+
+# Install Java first, to better take advantage of layer caching.
+#
+# Note (1): first mkdir line fixes the following error:
+# E: Sub-process /usr/bin/dpkg returned an error code (1)
+# https://stackoverflow.com/questions/58160597/docker-fails-with-sub-process-usr-bin-dpkg-returned-an-error-code-1
+#
+# Note (2): pyjnius appears to need JDK, JRE doesn't suffice.
+#
+RUN mkdir -p /usr/share/man/man1 && \
+ apt update && \
+ apt install -y bash \
+ build-essential \
+ curl \
+ ca-certificates \
+ openjdk-11-jdk-headless && \
+ rm -rf /var/lib/apt/lists
+
+
+RUN pip install pyserini==0.12.0 fastapi uvicorn python-multipart
+
+WORKDIR /home
+COPY main.py config.py /home/
+RUN mkdir /home/datasets
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0"]
diff --git a/examples/beir-pyserini/config.py b/examples/beir-pyserini/config.py
new file mode 100644
index 0000000..67a7253
--- /dev/null
+++ b/examples/beir-pyserini/config.py
@@ -0,0 +1,15 @@
+from pydantic import BaseSettings
+
+class IndexSettings(BaseSettings):
+ index_name: str = "beir/test"
+ data_folder: str = "/home/datasets/"
+
+def hit_template(hits):
+ results = {}
+
+ for qid, hit in hits.items():
+ results[qid] = {}
+ for i in range(0, len(hit)):
+ results[qid][hit[i].docid] = hit[i].score
+
+ return results
\ No newline at end of file
diff --git a/examples/beir-pyserini/dockerhub.sh b/examples/beir-pyserini/dockerhub.sh
new file mode 100755
index 0000000..e9609ef
--- /dev/null
+++ b/examples/beir-pyserini/dockerhub.sh
@@ -0,0 +1,10 @@
+#!/bin/sh
+# This script builds and pushes the Docker Hub container image
+
+# TAGNAME="1.0"
+
+# docker build --no-cache -t beir/pyserini-fastapi:${TAGNAME} .
+# docker push beir/pyserini-fastapi:${TAGNAME}
+
+docker build --no-cache -t beir/pyserini-fastapi:latest .
+docker push beir/pyserini-fastapi:latest
\ No newline at end of file
diff --git a/examples/beir-pyserini/main.py b/examples/beir-pyserini/main.py
new file mode 100644
index 0000000..329ca36
--- /dev/null
+++ b/examples/beir-pyserini/main.py
@@ -0,0 +1,78 @@
+import sys, os
+import config
+
+from fastapi import FastAPI, File, UploadFile
+from pyserini.search import SimpleSearcher
+from typing import Optional, List, Dict, Union
+
+settings = config.IndexSettings()
+app = FastAPI()
+
+@app.post("/upload/")
+async def upload(file: UploadFile = File(...)):
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ filename = f'{dir_path}/datasets/{file.filename}'
+ settings.data_folder = f'{dir_path}/datasets/'
+ with open(filename, 'wb') as f:
+     content = await file.read()
+     f.write(content)
+ return {"filename": file.filename}
+
+@app.get("/index/")
+def index(index_name: str, threads: Optional[int] = 8):
+ settings.index_name = index_name
+
+ command = f"python -m pyserini.index -collection JsonCollection \
+ -generator DefaultLuceneDocumentGenerator -threads {threads} \
+ -input {settings.data_folder} -index {settings.index_name} -storeRaw \
+ -storePositions -storeDocvectors"
+
+ os.system(command)
+
+ return {200: "OK"}
+
+@app.get("/lexical/search/")
+def search(q: str,
+ k: Optional[int] = 1000,
+ bm25: Optional[Dict[str, float]] = {"k1": 0.9, "b": 0.4},
+ fields: Optional[Dict[str, float]] = {"contents": 1.0, "title": 1.0}):
+
+ searcher = SimpleSearcher(settings.index_name)
+ searcher.set_bm25(k1=bm25["k1"], b=bm25["b"])
+
+ hits = searcher.search(q, k=k, fields=fields)
+ results = []
+ for i in range(0, len(hits)):
+ results.append({'docid': hits[i].docid, 'score': hits[i].score})
+
+ return {'results': results}
+
+@app.post("/lexical/batch_search/")
+def batch_search(queries: List[str],
+ qids: List[str],
+ k: Optional[int] = 1000,
+ threads: Optional[int] = 8,
+ bm25: Optional[Dict[str, float]] = {"k1": 0.9, "b": 0.4},
+ fields: Optional[Dict[str, float]] = {"contents": 1.0, "title": 1.0}):
+
+ searcher = SimpleSearcher(settings.index_name)
+ searcher.set_bm25(k1=bm25["k1"], b=bm25["b"])
+
+ hits = searcher.batch_search(queries=queries, qids=qids, k=k, threads=threads, fields=fields)
+ return {'results': config.hit_template(hits)}
+
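+# Example request body for /lexical/batch_search/ above (illustrative; bm25 and fields use
+# their defaults when omitted):
+#   {"queries": ["what is diabetes?"], "qids": ["q1"]}
+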
+@app.post("/lexical/rm3/batch_search/")
+def batch_search_rm3(queries: List[str],
+ qids: List[str],
+ k: Optional[int] = 1000,
+ threads: Optional[int] = 8,
+ bm25: Optional[Dict[str, float]] = {"k1": 0.9, "b": 0.4},
+ fields: Optional[Dict[str, float]] = {"contents": 1.0, "title": 1.0},
+ rm3: Optional[Dict[str, Union[int, float]]] = {"fb_terms": 10, "fb_docs": 10, "original_query_weight": 0.5}):
+
+ searcher = SimpleSearcher(settings.index_name)
+ searcher.set_bm25(k1=bm25["k1"], b=bm25["b"])
+ searcher.set_rm3(fb_terms=rm3["fb_terms"], fb_docs=rm3["fb_docs"], original_query_weight=rm3["original_query_weight"])
+
+ hits = searcher.batch_search(queries=queries, qids=qids, k=k, threads=threads, fields=fields)
+ return {'results': config.hit_template(hits)}
\ No newline at end of file
diff --git a/examples/benchmarking/benchmark_bm25.py b/examples/benchmarking/benchmark_bm25.py
new file mode 100644
index 0000000..9458ffd
--- /dev/null
+++ b/examples/benchmarking/benchmark_bm25.py
@@ -0,0 +1,72 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+
+import pathlib, os
+import datetime
+import logging
+import random
+
+random.seed(42)
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+
+#### /print debug information to stdout
+
+#### Download dbpedia-entity.zip dataset and unzip the dataset
+dataset = "dbpedia-entity"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Loading test queries and corpus in DBPedia
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+corpus_ids, query_ids = list(corpus), list(queries)
+
+#### Randomly sample 1M documents from the original corpus (4.63M documents)
+#### First include all relevant documents (i.e. present in qrels)
+corpus_set = set()
+for query_id in qrels:
+ corpus_set.update(list(qrels[query_id].keys()))
+corpus_new = {corpus_id: corpus[corpus_id] for corpus_id in corpus_set}
+
+#### Remove already seen k relevant documents and sample (1M - k) docs randomly
+remaining_corpus = list(set(corpus_ids) - corpus_set)
+sample = 1000000 - len(corpus_set)
+
+for corpus_id in random.sample(remaining_corpus, sample):
+ corpus_new[corpus_id] = corpus[corpus_id]
+
+#### Provide parameters for Elasticsearch
+hostname = "desktop-158.ukp.informatik.tu-darmstadt.de:9200"
+index_name = dataset
+model = BM25(index_name=index_name, hostname=hostname)
+bm25 = EvaluateRetrieval(model)
+
+#### Index the 1M sampled passages into Elasticsearch (separately)
+bm25.retriever.index(corpus_new)
+
+#### Saving benchmark times
+time_taken_all = {}
+
+for query_id in query_ids:
+ query = queries[query_id]
+
+    #### Measure single-query latency for retrieving the top-10 BM25 documents
+ start = datetime.datetime.now()
+ results = bm25.retriever.es.lexical_search(text=query, top_hits=10)
+ end = datetime.datetime.now()
+
+ #### Measuring time taken in ms (milliseconds)
+ time_taken = (end - start)
+ time_taken = time_taken.total_seconds() * 1000
+ time_taken_all[query_id] = time_taken
+ logging.info("{}: {} {:.2f}ms".format(query_id, query, time_taken))
+
+time_taken = list(time_taken_all.values())
+logging.info("Average time taken: {:.2f}ms".format(sum(time_taken)/len(time_taken_all)))
\ No newline at end of file
diff --git a/examples/benchmarking/benchmark_bm25_ce_reranking.py b/examples/benchmarking/benchmark_bm25_ce_reranking.py
new file mode 100644
index 0000000..b5aec14
--- /dev/null
+++ b/examples/benchmarking/benchmark_bm25_ce_reranking.py
@@ -0,0 +1,84 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+from beir.reranking.models import CrossEncoder
+from operator import itemgetter
+
+import pathlib, os
+import datetime
+import logging
+import random
+
+random.seed(42)
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+
+#### /print debug information to stdout
+
+#### Download dbpedia-entity.zip dataset and unzip the dataset
+dataset = "dbpedia-entity"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Loading test queries and corpus in DBPedia
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+corpus_ids, query_ids = list(corpus), list(queries)
+corpus_texts = {corpus_id: corpus[corpus_id]["title"] + " " + corpus[corpus_id]["text"] for corpus_id in corpus}
+
+#### Randomly sample 1M documents from the original corpus (4.63M documents)
+#### First include all relevant documents (i.e. present in qrels)
+corpus_set = set()
+for query_id in qrels:
+ corpus_set.update(list(qrels[query_id].keys()))
+corpus_new = {corpus_id: corpus[corpus_id] for corpus_id in corpus_set}
+
+#### Remove already seen k relevant documents and sample (1M - k) docs randomly
+remaining_corpus = list(set(corpus_ids) - corpus_set)
+sample = 1000000 - len(corpus_set)
+
+for corpus_id in random.sample(remaining_corpus, sample):
+ corpus_new[corpus_id] = corpus[corpus_id]
+
+#### Provide parameters for Elasticsearch
+hostname = "desktop-158.ukp.informatik.tu-darmstadt.de:9200"
+index_name = dataset
+model = BM25(index_name=index_name, hostname=hostname)
+bm25 = EvaluateRetrieval(model)
+
+#### Index the 1M sampled passages into Elasticsearch (separately)
+bm25.retriever.index(corpus_new)
+
+#### Reranking using Cross-Encoder model
+reranker = CrossEncoder('cross-encoder/ms-marco-electra-base')
+
+#### Saving benchmark times
+time_taken_all = {}
+
+for query_id in query_ids:
+ query = queries[query_id]
+
+    #### Measure single-query latency for retrieving the top-100 BM25 documents
+ start = datetime.datetime.now()
+ results = bm25.retriever.es.lexical_search(text=query, top_hits=100)
+
+ #### Measure time to rerank top-100 BM25 documents using CE
+ sentence_pairs = [[queries[query_id], corpus_texts[hit[0]]] for hit in results["hits"]]
+ scores = reranker.predict(sentence_pairs, batch_size=100, show_progress_bar=False)
+ hits = {results["hits"][idx][0]: scores[idx] for idx in range(len(scores))}
+ sorted_results = {k: v for k,v in sorted(hits.items(), key=itemgetter(1), reverse=True)}
+ end = datetime.datetime.now()
+
+ #### Measuring time taken in ms (milliseconds)
+ time_taken = (end - start)
+ time_taken = time_taken.total_seconds() * 1000
+ time_taken_all[query_id] = time_taken
+ logging.info("{}: {} {:.2f}ms".format(query_id, query, time_taken))
+
+time_taken = list(time_taken_all.values())
+logging.info("Average time taken: {:.2f}ms".format(sum(time_taken)/len(time_taken_all)))
\ No newline at end of file
diff --git a/examples/benchmarking/benchmark_sbert.py b/examples/benchmarking/benchmark_sbert.py
new file mode 100644
index 0000000..79d396a
--- /dev/null
+++ b/examples/benchmarking/benchmark_sbert.py
@@ -0,0 +1,83 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.search.dense import util as utils
+
+import pathlib, os, sys
+import numpy as np
+import torch
+import logging
+import datetime
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download dbpedia-entity.zip dataset and unzip the dataset
+dataset = "dbpedia-entity"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where dbpedia-entity has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+corpus_ids, query_ids = list(corpus), list(queries)
+
+#### For benchmarking dense models, you can take any 1M documents, as it does not matter which documents you choose.
+#### For simplicity, we take the first 1M documents.
+number_docs = 1000000
+reduced_corpus = [corpus[corpus_id] for corpus_id in corpus_ids[:number_docs]]
+
+#### Dense retriever models
+#### For ANCE (msmarco-roberta-base-ance-firstp), no normalization of the embeddings is required (normalize=False).
+#### For DPR (facebook-dpr-question_encoder-multiset-base, facebook-dpr-ctx_encoder-multiset-base), no normalization of the embeddings is required (normalize=False).
+#### For SBERT (msmarco-distilbert-base-v3), normalization of the embeddings is required (normalize=True).
+
+model_path = "msmarco-distilbert-base-v3"
+model = models.SentenceBERT(model_path=model_path)
+normalize = True
+
+#### Pre-compute all document embeddings (with or without normalization)
+#### We do not count the time required to compute document embeddings, at inference we assume to have document embeddings in-memory.
+logging.info("Computing Document Embeddings...")
+if normalize:
+ corpus_embs = model.encode_corpus(reduced_corpus, batch_size=128, convert_to_tensor=True, normalize_embeddings=True)
+else:
+ corpus_embs = model.encode_corpus(reduced_corpus, batch_size=128, convert_to_tensor=True)
+
+#### Saving benchmark times
+time_taken_all = {}
+
+for query_id in query_ids:
+ query = queries[query_id]
+
+ #### Compute query embedding and retrieve similar scores using dot-product
+ start = datetime.datetime.now()
+ if normalize:
+ query_emb = model.encode_queries([query], batch_size=1, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False)
+ else:
+ query_emb = model.encode_queries([query], batch_size=1, convert_to_tensor=True, show_progress_bar=False)
+
+ #### Dot product for normalized embeddings is equal to cosine similarity
+ sim_scores = utils.dot_score(query_emb, corpus_embs)
+ sim_scores_top_k_values, sim_scores_top_k_idx = torch.topk(sim_scores, 10, dim=1, largest=True, sorted=True)
+ end = datetime.datetime.now()
+
+ #### Measuring time taken in ms (milliseconds)
+ time_taken = (end - start)
+ time_taken = time_taken.total_seconds() * 1000
+ time_taken_all[query_id] = time_taken
+ logging.info("{}: {} {:.2f}ms".format(query_id, query, time_taken))
+
+time_taken = list(time_taken_all.values())
+logging.info("Average time taken: {:.2f}ms".format(sum(time_taken)/len(time_taken_all)))
+
+#### Measuring Index size consumed by document embeddings
+corpus_embs = corpus_embs.cpu()
+cpu_memory = sys.getsizeof(np.asarray([emb.numpy() for emb in corpus_embs]))
+
+logging.info("Number of documents: {}, Dim: {}".format(len(corpus_embs), len(corpus_embs[0])))
+logging.info("Index size (in MB): {:.2f}MB".format(cpu_memory*0.000001))
\ No newline at end of file
diff --git a/examples/dataset/README.md b/examples/dataset/README.md
new file mode 100644
index 0000000..447d687
--- /dev/null
+++ b/examples/dataset/README.md
@@ -0,0 +1,54 @@
+# Dataset Information
+
+Generally, all public datasets can be easily downloaded as zip archives.
+
+Below we describe how to reproduce retrieval on the datasets which are not publicly available:
+
+## 1. TREC-NEWS
+
+### Corpus
+
+1. Fill out the application to use the Washington Post (WaPo) Corpus: https://trec.nist.gov/data/wapost/
+2. Loop through the contents of each document: collect all ``paragraph`` subtypes, extract text from HTML when the mime type is ``text/html``, or include the text directly when it is ``text/plain``.
+3. We used the ``html2text`` (https://pypi.org/project/html2text/) Python package to extract text from the HTML (see the sketch below).
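+
+A minimal sketch of the HTML-to-text step (the snippet and variable names are illustrative, not the original pipeline):
+
+```python
+import html2text
+
+h = html2text.HTML2Text()
+h.ignore_links = True  # keep only the article text, drop hyperlinks
+
+plain_text = h.handle("<p>Some <b>WaPo</b> paragraph content.</p>")
+```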
+
+### Queries and Qrels
+1. Download background linking topics and qrels from 2019 News Track: https://trec.nist.gov/data/news2019.html
+2. We consider the document title as the query for our experiments.
+
+## 2. BioASQ
+
+### Corpus
+
+1. Register yourself at BioASQ: http://www.bioasq.org/
+2. Download documents from BioASQ task 9a (Training v.2020 ~ 14,913,939 docs) and extract the title and abstractText for each document.
+3. There are a few documents which are not present in this corpus but appear in the test qrels, so we add them manually.
+4. Find these manual documents here: https://docs.google.com/spreadsheets/d/1GZghfN5RT8h01XzIlejuwhBIGe8f-VaGf-yGaq11U-k/edit#gid=2015463710
+
+### Queries and Qrels
+1. Download Training and Test dataset from BioASQ 8B datasets which were published in 2020.
+2. Consider all documents with answers as relevant (binary label) for a given question.
+
+## 3. Robust04
+
+### Corpus
+
+1. Fill out the application to use the TREC disks 4 and 5: https://trec.nist.gov/data/cd45/index.html
+2. Download the data, format it according to ``ir_datasets``, and obtain the preprocessed corpus: https://ir-datasets.com/trec-robust04.html#trec-robust04
+
+### Queries and Qrels
+1. Download the queries and qrels from ``ir_datasets`` with the key ``trec-robust04`` here - https://ir-datasets.com/trec-robust04.html#trec-robust04
+2. For our experiments, we used the description field of the query for retrieval (see the sketch below).
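+
+A minimal sketch of loading the corpus, queries and qrels via ``ir_datasets`` (assuming the standard ``trec-robust04`` fields ``doc_id``/``text``, ``query_id``/``description`` and ``relevance``):
+
+```python
+import ir_datasets
+
+dataset = ir_datasets.load("trec-robust04")
+
+corpus = {doc.doc_id: doc.text for doc in dataset.docs_iter()}
+queries = {q.query_id: q.description for q in dataset.queries_iter()}  # description used as the query
+
+qrels = {}
+for qrel in dataset.qrels_iter():
+    qrels.setdefault(qrel.query_id, {})[qrel.doc_id] = qrel.relevance
+```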
+
+## 4. Signal-1M
+
+### Corpus
+1. Scrape tweets from Twitter manually for the ids here: https://github.com/igorbrigadir/newsir16-data/tree/master/twitter/curated
+2. We used the ``tweepy`` (https://www.tweepy.org/) Python package to scrape tweets. You can find the script here: [scrape_tweets.py](https://github.com/UKPLab/beir/blob/main/examples/dataset/scrape_tweets.py).
+3. We preprocess the retrieved text by removing emojis and links from the original tweets. You can find the function implementations in the script linked above.
+4. Remove tweets which are empty or do not contain any text.
+
+### Queries and Qrels
+1. Sign up at Signal1M website to download qrels: https://research.signal-ai.com/datasets/signal1m-tweetir.html
+2. Sign up at Signal1M website to download queries: https://research.signal-ai.com/datasets/signal1m.html
+3. We consider the title of the query for our experiments.
\ No newline at end of file
diff --git a/examples/dataset/download_dataset.py b/examples/dataset/download_dataset.py
new file mode 100644
index 0000000..2132858
--- /dev/null
+++ b/examples/dataset/download_dataset.py
@@ -0,0 +1,27 @@
+import os
+import pathlib
+from beir import util
+
+def main():
+
+ out_dir = pathlib.Path(__file__).parent.absolute()
+
+ dataset_files = ["msmarco.zip", "trec-covid.zip", "nfcorpus.zip",
+ "nq.zip", "hotpotqa.zip", "fiqa.zip", "arguana.zip",
+ "webis-touche2020.zip", "cqadupstack.zip", "quora.zip",
+ "dbpedia-entity.zip", "scidocs.zip", "fever.zip",
+ "climate-fever.zip", "scifact.zip", "germanquad.zip"]
+
+ for dataset in dataset_files:
+
+ zip_file = os.path.join(out_dir, dataset)
+ url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}".format(dataset)
+
+ print("Downloading {} ...".format(dataset))
+ util.download_url(url, zip_file)
+
+ print("Unzipping {} ...".format(dataset))
+ util.unzip(zip_file, out_dir)
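+
+        # Optionally, one could verify each downloaded zip against examples/dataset/md5.csv
+        # using hashlib from the standard library, e.g.:
+        #   import hashlib
+        #   with open(zip_file, "rb") as f:
+        #       assert hashlib.md5(f.read()).hexdigest() == expected_md5[dataset]
+        # where expected_md5 is a hypothetical dict parsed from md5.csv (not defined here).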
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/examples/dataset/md5.csv b/examples/dataset/md5.csv
new file mode 100644
index 0000000..b8b518f
--- /dev/null
+++ b/examples/dataset/md5.csv
@@ -0,0 +1,17 @@
+dataset,md5
+msmarco.zip,444067daf65d982533ea17ebd59501e4
+trec-covid.zip,ce62140cb23feb9becf6270d0d1fe6d1
+nfcorpus.zip,a89dba18a62ef92f7d323ec890a0d38d
+nq.zip,d4d3d2e48787a744b6f6e691ff534307
+hotpotqa.zip,f412724f78b0d91183a0e86805e16114
+fiqa.zip,17918ed23cd04fb15047f73e6c3bd9d9
+arguana.zip,8ad3e3c2a5867cdced806d6503f29b99
+webis-touche2020.zip,6e29ac6c57e2d227fb57501872cac45f
+cqadupstack.zip,4e41456d7df8ee7760a7f866133bda78
+quora.zip,18fb154900ba42a600f84b839c173167
+dbpedia-entity.zip,c2a39eb420a3164af735795df012ac2c
+scidocs.zip,38121350fc3a4d2f48850f6aff52e4a9
+fever.zip,5f7d1de60b170fc8027bb7898e2efca1
+climate-fever.zip,8b66f0a9126c521bae2bde127b4dc99d
+scifact.zip,5f7d1de60b170fc8027bb7898e2efca1
+germanquad.zip,95a581c3162d10915a418609bcce851b
diff --git a/examples/dataset/scrape_tweets.py b/examples/dataset/scrape_tweets.py
new file mode 100644
index 0000000..806b6a2
--- /dev/null
+++ b/examples/dataset/scrape_tweets.py
@@ -0,0 +1,95 @@
+'''
+The following is a basic Twitter scraper using tweepy.
+We preprocess the text - 1. Remove emojis 2. Remove URLs from the tweet.
+We store the output tweets with tweet-id and tweet-text on each line, tab separated.
+
+You will need to have an active Twitter account and provide your consumer key, secret and a callback url.
+You can get your keys from here: https://developer.twitter.com/en/portal/projects-and-apps
+Twitter implements rate limiting on API calls by default: https://developer.twitter.com/en/docs/twitter-api/rate-limits
+The default limit is 300 calls per 15-minute window.
+
+Install tweepy (pip install tweepy) to run the code below.
+python scrape_tweets.py
+'''
+
+import tweepy
+import csv
+import pickle
+import tqdm
+import re
+
+#### Twitter Account Details
+consumer_key = 'XXXXXXXX' # Your twitter consumer key
+consumer_secret = 'XXXXXXXX' # Your twitter consumer secret
+callback_url = 'XXXXXXXX' # callback url
+
+#### Input/Output Details
+input_file = "input-tweets.tsv" # Tab-separated file containing a twitter tweet-id on each line
+output_file = "201509-tweet-scraped-ids-test.txt" # Output file in which the scraped tweets are saved
+
+def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+def de_emojify(text):
+    regex_pattern = re.compile(pattern = "["
+        u"\U0001F600-\U0001F64F" # emoticons
+        u"\U0001F300-\U0001F5FF" # symbols & pictographs
+        u"\U0001F680-\U0001F6FF" # transport & map symbols
+        u"\U0001F1E0-\U0001F1FF" # flags (iOS)
+        "]+", flags = re.UNICODE)
+    return regex_pattern.sub(r'', text)
+
+def preprocessing(text):
+ return re.sub(r"http\S+", "", de_emojify(text).replace("\n", "")).strip()
+
+def update_tweet_dict(tweets, tweet_dict):
+ for tweet in tweets:
+ if tweet:
+ try:
+ idx = tweet.id_str.strip()
+ tweet_dict[idx] = preprocessing(tweet.text)
+            except Exception:
+ continue
+
+ return tweet_dict
+
+def write_dict_to_file(filename, dic):
+ with open(filename, "w") as outfile:
+ outfile.write("\n".join((idx + "\t" + text) for idx, text in dic.items()))
+
+### Main Code starts here
+auth = tweepy.OAuthHandler(consumer_key, consumer_secret, callback_url)
+try:
+ redirect_url = auth.get_authorization_url()
+except tweepy.TweepError:
+ print('Error! Failed to get request token.')
+
+api = tweepy.API(auth, wait_on_rate_limit=True, wait_on_rate_limit_notify=True)
+
+all_tweets = []
+tweets = []
+tweet_dict = {}
+
+reader = csv.reader(open(input_file, encoding="utf-8"), delimiter="\t", quoting=csv.QUOTE_NONE)
+for row in reader:
+ all_tweets.append(row[0])
+
+generator = chunks(all_tweets, 100)
+batches = int(len(all_tweets)/100)
+total = batches if len(all_tweets) % 100 == 0 else batches + 1
+
+print("Retrieving Tweets...")
+for idx, tweet_id_chunks in enumerate(tqdm.tqdm(generator, total=total)):
+
+ if idx >= 300 and idx % 300 == 0: # Rate-limiting every 300 calls (in 15 mins)
+ print("Preprocessing Text...")
+ tweet_dict = update_tweet_dict(tweets, tweet_dict)
+ write_dict_to_file(output_file, tweet_dict)
+
+ tweets += api.statuses_lookup(id_=tweet_id_chunks, include_entities=True, trim_user=True, map=None)
+
+print("Preprocessing Text...")
+tweet_dict = update_tweet_dict(tweets, tweet_dict)
+write_dict_to_file(output_file, tweet_dict)
\ No newline at end of file
diff --git a/examples/generation/passage_expansion_tilde.py b/examples/generation/passage_expansion_tilde.py
new file mode 100644
index 0000000..1719d05
--- /dev/null
+++ b/examples/generation/passage_expansion_tilde.py
@@ -0,0 +1,51 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.generation import PassageExpansion as PassageExp
+from beir.generation.models import TILDE
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where scifact has been downloaded and unzipped
+corpus = GenericDataLoader(data_path).load_corpus()
+
+#############################
+#### TILDE Model Loading ####
+#############################
+
+#### Model Loading
+model_path = "ielab/TILDE"
+generator = PassageExp(model=TILDE(model_path))
+
+#### TILDE passage expansion using top-k most likely expansion tokens from BERT Vocabulary ####
+#### Only supports bert-base-uncased (TILDE) model for now
+#### Prefix is required to store the final expanded passages as a corpus.jsonl file
+prefix = "tilde-exp"
+
+#### Expand useful tokens per passage from docs in corpus and save them in a new corpus
+#### Check your datasets folder to find the expanded passages appended to the originals; you will find:
+#### 1. datasets/scifact/tilde-exp-corpus.jsonl
+
+#### Batch size denotes the number of passages getting expanded at once
+batch_size = 64
+
+#### top-k value will retrieve the top-k expansion terms with highest softmax probability
+#### These tokens are individually appended once to the passage
+#### We remove stopwords, bad words (punctuation, etc.) and words already present in the original passage.
+top_k = 200
+
+generator.expand(corpus, output_dir=data_path, prefix=prefix, batch_size=batch_size, top_k=top_k)
\ No newline at end of file
diff --git a/examples/generation/query_gen.py b/examples/generation/query_gen.py
new file mode 100644
index 0000000..d92e26c
--- /dev/null
+++ b/examples/generation/query_gen.py
@@ -0,0 +1,50 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.generation import QueryGenerator as QGen
+from beir.generation.models import QGenModel
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where scifact has been downloaded and unzipped
+corpus = GenericDataLoader(data_path).load_corpus()
+
+###########################
+#### Query-Generation ####
+###########################
+
+#### Model Loading
+model_path = "BeIR/query-gen-msmarco-t5-base-v1"
+generator = QGen(model=QGenModel(model_path))
+
+#### Query-Generation using Nucleus Sampling (top_k=25, top_p=0.95) ####
+#### https://huggingface.co/blog/how-to-generate
+#### Prefix is required to separate out synthetic queries and qrels from the originals
+prefix = "gen-3"
+
+#### Generating 3 questions per document for all documents in the corpus
+#### Note: a higher value might produce more diverse questions but also more duplicates
+ques_per_passage = 3
+
+#### Generate queries per passage from docs in corpus and save them in original corpus
+#### Check your datasets folder to find the generated questions; you will find:
+#### 1. datasets/scifact/gen-3-queries.jsonl
+#### 2. datasets/scifact/gen-3-qrels/train.tsv
+
+batch_size = 64
+
+generator.generate(corpus, output_dir=data_path, ques_per_passage=ques_per_passage, prefix=prefix, batch_size=batch_size)
\ No newline at end of file
diff --git a/examples/generation/query_gen_and_train.py b/examples/generation/query_gen_and_train.py
new file mode 100644
index 0000000..6464f6e
--- /dev/null
+++ b/examples/generation/query_gen_and_train.py
@@ -0,0 +1,98 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.generation import QueryGenerator as QGen
+from beir.generation.models import QGenModel
+from beir.retrieval.train import TrainRetriever
+from sentence_transformers import SentenceTransformer, losses, models
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+dataset = "nfcorpus"
+
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where nfcorpus has been downloaded and unzipped
+corpus = GenericDataLoader(data_path).load_corpus()
+
+
+##############################
+#### 1. Query-Generation ####
+##############################
+
+#### question-generation model loading
+model_path = "BeIR/query-gen-msmarco-t5-base-v1"
+generator = QGen(model=QGenModel(model_path))
+
+#### Query-Generation using Nucleus Sampling (top_k=25, top_p=0.95) ####
+#### https://huggingface.co/blog/how-to-generate
+#### Prefix is required to separate out synthetic queries and qrels from the originals
+prefix = "gen"
+
+#### Generating 3 questions per passage.
+#### Note: a higher value might produce lots of duplicates
+ques_per_passage = 3
+
+#### Generate queries per passage from docs in corpus and save them in data_path
+generator.generate(corpus, output_dir=data_path, ques_per_passage=ques_per_passage, prefix=prefix)
+
+################################
+#### 2. Train Dense-Encoder ####
+################################
+
+
+#### Training on Generated Queries ####
+corpus, gen_queries, gen_qrels = GenericDataLoader(data_path, prefix=prefix).load(split="train")
+#### Please note: not all datasets contain a dev split; comment out the line below if that is the case
+dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
+
+#### Provide any HuggingFace model and fine-tune from scratch
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=350)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Or provide already fine-tuned sentence-transformer model
+# model = SentenceTransformer("msmarco-distilbert-base-v3")
+
+#### Model name used here only to name the output directory below
+model_path = "bert-base-uncased" # or "msmarco-distilbert-base-v3"
+retriever = TrainRetriever(model=model, batch_size=64)
+
+#### Prepare training samples
+train_samples = retriever.load_train(corpus, gen_queries, gen_qrels)
+train_dataloader = retriever.prepare_train(train_samples, shuffle=True)
+train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+
+#### Prepare dev evaluator
+ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
+
+#### If no dev set is present evaluate using dummy evaluator
+# ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-GenQ-nfcorpus".format(model_path))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 1
+evaluation_steps = 5000
+warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+              evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/examples/generation/query_gen_multi_gpu.py b/examples/generation/query_gen_multi_gpu.py
new file mode 100644
index 0000000..04a7be8
--- /dev/null
+++ b/examples/generation/query_gen_multi_gpu.py
@@ -0,0 +1,79 @@
+"""
+This code shows how to generate queries using multiple GPUs in parallel for a very large corpus.
+Multiple GPUs can be used to generate faster!
+
+We use the torch.multiprocessing module and define multiple pools, one for each GPU.
+Then we chunk our big corpus into multiple smaller corpora and generate simultaneously.
+
+It is important to keep the code within the __main__ module!
+
+Usage: CUDA_VISIBLE_DEVICES=0,1 python query_gen_multi_gpu.py
+"""
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.generation import QueryGenerator as QGen
+from beir.generation.models import QGenModel
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#Important, you need to shield your code with if __name__. Otherwise, CUDA runs into issues when spawning new processes.
+if __name__ == '__main__':
+
+    #### Download trec-covid.zip dataset and unzip the dataset
+ dataset = "trec-covid"
+
+ url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+ out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+ data_path = util.download_and_unzip(url, out_dir)
+
+    #### Provide the data_path where trec-covid has been downloaded and unzipped
+ corpus = GenericDataLoader(data_path).load_corpus()
+
+ ###########################
+ #### Query-Generation ####
+ ###########################
+
+ #Define the model
+ model = QGenModel("BeIR/query-gen-msmarco-t5-base-v1")
+
+ #Start the multi-process pool on all available CUDA devices
+ pool = model.start_multi_process_pool()
+
+ generator = QGen(model=model)
+
+ #### Query-Generation using Nucleus Sampling (top_k=25, top_p=0.95) ####
+ #### https://huggingface.co/blog/how-to-generate
+    #### Prefix is required to separate out synthetic queries and qrels from the originals
+ prefix = "gen-3"
+
+ #### Generating 3 questions per document for all documents in the corpus
+    #### Note: a higher value might produce more diverse questions but also more duplicates
+ ques_per_passage = 3
+
+ #### Generate queries per passage from docs in corpus and save them in original corpus
+    #### Check your datasets folder to find the generated questions; you will find:
+    #### 1. datasets/trec-covid/gen-3-queries.jsonl
+    #### 2. datasets/trec-covid/gen-3-qrels/train.tsv
+
+ chunk_size = 5000 # chunks to split within each GPU
+ batch_size = 64 # batch size within a single GPU
+
+ generator.generate_multi_process(
+ corpus=corpus,
+ pool=pool,
+ output_dir=data_path,
+ ques_per_passage=ques_per_passage,
+ prefix=prefix,
+ batch_size=batch_size)
+
+    # Optional: Stop the processes in the pool
+ # model.stop_multi_process_pool(pool)
\ No newline at end of file
diff --git a/examples/retrieval/README.md b/examples/retrieval/README.md
new file mode 100644
index 0000000..f37a531
--- /dev/null
+++ b/examples/retrieval/README.md
@@ -0,0 +1,6 @@
+# Retrieval
+
+This folder contains various examples for evaluating and training retriever models on the datasets in BEIR.
+
+## Overall Leaderboard
+
diff --git a/examples/retrieval/evaluation/README.md b/examples/retrieval/evaluation/README.md
new file mode 100644
index 0000000..906cb25
--- /dev/null
+++ b/examples/retrieval/evaluation/README.md
@@ -0,0 +1,4 @@
+## Deep Dive into Evaluation of Retrieval Models
+
+### Leaderboard Overall
+
diff --git a/examples/retrieval/evaluation/custom/evaluate_custom_dataset.py b/examples/retrieval/evaluation/custom/evaluate_custom_dataset.py
new file mode 100644
index 0000000..b2da16c
--- /dev/null
+++ b/examples/retrieval/evaluation/custom/evaluate_custom_dataset.py
@@ -0,0 +1,67 @@
+from beir import LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Corpus ####
+# Load the corpus in this format of Dict[str, Dict[str, str]]
+# Keep the title key; use an empty string if a title is not available
+
+corpus = {
+ "doc1" : {
+ "title": "Albert Einstein",
+ "text": "Albert Einstein was a German-born theoretical physicist. who developed the theory of relativity, \
+ one of the two pillars of modern physics (alongside quantum mechanics). His work is also known for \
+ its influence on the philosophy of science. He is best known to the general public for his mass–energy \
+ equivalence formula E = mc2, which has been dubbed 'the world's most famous equation'. He received the 1921 \
+ Nobel Prize in Physics 'for his services to theoretical physics, and especially for his discovery of the law \
+ of the photoelectric effect', a pivotal step in the development of quantum theory."
+ },
+ "doc2" : {
+ "title": "", # Keep title an empty string if not present
+ "text": "Wheat beer is a top-fermented beer which is brewed with a large proportion of wheat relative to the amount of \
+ malted barley. The two main varieties are German Weißbier and Belgian witbier; other types include Lambic (made\
+ with wild yeast), Berliner Weisse (a cloudy, sour beer), and Gose (a sour, salty beer)."
+ },
+}
+
+#### Queries ####
+# Load the queries in this format of Dict[str, str]
+
+queries = {
+ "q1" : "Who developed the mass-energy equivalence formula?",
+ "q2" : "Which beer is brewed with a large proportion of wheat?"
+}
+
+#### Qrels ####
+# Load the Qrels in this format of Dict[str, Dict[str, int]]
+# First query_id and then dict with doc_id with gold score (int)
+
+qrels = {
+ "q1" : {"doc1": 1},
+ "q2" : {"doc2": 1},
+}
+
+#### Sentence-Transformer ####
+#### Provide any pretrained sentence-transformers model path
+#### Complete list - https://www.sbert.net/docs/pretrained_models.html
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"))
+
+retriever = EvaluateRetrieval(model, score_function="cos_sim")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/custom/evaluate_custom_dataset_files.py b/examples/retrieval/evaluation/custom/evaluate_custom_dataset_files.py
new file mode 100644
index 0000000..66445ec
--- /dev/null
+++ b/examples/retrieval/evaluation/custom/evaluate_custom_dataset_files.py
@@ -0,0 +1,65 @@
+from beir import LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### METHOD 2 ####
+
+# Provide the path to your CORPUS file, it should be jsonlines format (ref: https://jsonlines.org/)
+# Saved corpus file must have .jsonl extension (for eg: your_corpus_file.jsonl)
+# Corpus file structure (one JSON object per line):
+# {"_id": "doc1", "title": "Albert Einstein", "text": "Albert Einstein was a German-born...."}
+# {"_id": "doc2", "title": "", "text": "Wheat beer is a top-fermented beer...."}
+# ....
+corpus_path = "/home/thakur/your-custom-dataset/your_corpus_file.jsonl"
+
+# Provide the path to your QUERY file, it should be jsonlines format (ref: https://jsonlines.org/)
+# Saved query file must have .jsonl extension (for eg: your_query_file.jsonl)
+# Query file structure (one JSON object per line):
+# {"_id": "q1", "text": "Who developed the mass-energy equivalence formula?"}
+# {"_id": "q2", "text": "Which beer is brewed with a large proportion of wheat?"}
+# ....
+query_path = "/home/thakur/your-custom-dataset/your_query_file.jsonl"
+
+# Provide the path to your QRELS file, it should be in tsv (tab-separated) format.
+# Saved qrels file must have .tsv extension (for eg: your_qrels_file.tsv)
+# Qrels file structure: (Keep 1st row as header)
+# query-id corpus-id score
+# q1 doc1 1
+# q2 doc2 1
+# ....
+qrels_path = "/home/thakur/your-custom-dataset/your_qrels_file.tsv"
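+
+# A minimal sketch of how such files could be written from in-memory dicts using only the
+# standard library (documents and ids below are illustrative placeholders):
+#
+#   import json, csv
+#   with open(corpus_path, "w") as f:
+#       f.write(json.dumps({"_id": "doc1", "title": "Albert Einstein", "text": "..."}) + "\n")
+#   with open(query_path, "w") as f:
+#       f.write(json.dumps({"_id": "q1", "text": "Who developed the mass-energy equivalence formula?"}) + "\n")
+#   with open(qrels_path, "w") as f:
+#       writer = csv.writer(f, delimiter="\t")
+#       writer.writerow(["query-id", "corpus-id", "score"])
+#       writer.writerow(["q1", "doc1", 1])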
+
+# Load using load_custom function in GenericDataLoader
+corpus, queries, qrels = GenericDataLoader(
+ corpus_file=corpus_path,
+ query_file=query_path,
+ qrels_file=qrels_path).load_custom()
+
+#### Sentence-Transformer ####
+#### Provide any pretrained sentence-transformers model path
+#### Complete list - https://www.sbert.net/docs/pretrained_models.html
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"))
+
+retriever = EvaluateRetrieval(model, score_function="cos_sim")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/custom/evaluate_custom_metrics.py b/examples/retrieval/evaluation/custom/evaluate_custom_metrics.py
new file mode 100644
index 0000000..899c076
--- /dev/null
+++ b/examples/retrieval/evaluation/custom/evaluate_custom_metrics.py
@@ -0,0 +1,65 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using SBERT (Sentence-BERT) ####
+#### Provide any pretrained sentence-transformers model
+#### The model was fine-tuned using cosine-similarity.
+#### Complete list - https://www.sbert.net/docs/pretrained_models.html
+
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"), batch_size=16)
+retriever = EvaluateRetrieval(model, score_function="cos_sim")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K, Recall@K and P@K
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Evaluate your retrieval using MRR@K, Recall_cap@K, Hole@K
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+top_k_accuracy = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="top_k_accuracy")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/custom/evaluate_custom_model.py b/examples/retrieval/evaluation/custom/evaluate_custom_model.py
new file mode 100644
index 0000000..6c1f244
--- /dev/null
+++ b/examples/retrieval/evaluation/custom/evaluate_custom_model.py
@@ -0,0 +1,54 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+from typing import List, Dict
+
+import logging
+import numpy as np
+import pathlib, os
+import random
+
+class YourCustomModel:
+ def __init__(self, model_path=None, **kwargs):
+ self.model = None # ---> HERE Load your custom model
+ # self.model = SentenceTransformer(model_path)
+
+ # Write your own encoding query function (Returns: Query embeddings as numpy array)
+ # For eg ==> return np.asarray(self.model.encode(queries, batch_size=batch_size, **kwargs))
+ def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
+ pass
+
+ # Write your own encoding corpus function (Returns: Document embeddings as numpy array)
+ # For eg ==> sentences = [(doc["title"] + " " + doc["text"]).strip() for doc in corpus]
+ # ==> return np.asarray(self.model.encode(sentences, batch_size=batch_size, **kwargs))
+ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
+ pass
+
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download nq.zip dataset and unzip the dataset
+dataset = "nq.zip"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where nq has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Provide your custom model class name --> HERE
+model = DRES(YourCustomModel(model_path="your-custom-model-path"))
+
+retriever = EvaluateRetrieval(model, score_function="cos_sim") # or "dot" if you wish dot-product
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
diff --git a/examples/retrieval/evaluation/dense/evaluate_ance.py b/examples/retrieval/evaluation/dense/evaluate_ance.py
new file mode 100644
index 0000000..35d65e4
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_ance.py
@@ -0,0 +1,59 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download NFCorpus dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using ANCE ####
+# https://www.sbert.net/docs/pretrained-models/msmarco-v3.html
+# MSMARCO Dev Passage Retrieval ANCE(FirstP) 600K model from ANCE.
+# The ANCE model was fine-tuned using dot-product (dot) function.
+
+model = DRES(models.SentenceBERT("msmarco-roberta-base-ance-firstp"))
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
diff --git a/examples/retrieval/evaluation/dense/evaluate_bpr.py b/examples/retrieval/evaluation/dense/evaluate_bpr.py
new file mode 100644
index 0000000..01625ad
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_bpr.py
@@ -0,0 +1,139 @@
+"""
+The pre-trained models produce embeddings of size 512 - 1024. However, when storing a large
+number of embeddings, this requires quite a lot of memory / storage.
+
+In this example, we convert float embeddings to binary hashes using binary passage retriever (BPR).
+This significantly reduces the required memory / storage while maintaining nearly the same performance.
+
+For more information, please refer to the publication by Yamada et al. in ACL 2021 -
+Efficient Passage Retrieval with Hashing for Open-domain Question Answering, (https://arxiv.org/abs/2106.00882)
+
+For computing binary hashes, we need to train a model with bpr loss function (Margin Ranking Loss + Cross Entropy Loss).
+For more details on training, check train_msmarco_v3_bpr.py on how to train a binary retriever model.
+
+The BPR model encodes vectors into 768 dimensions of binary values {0, 1}. We pack 8 bits into one byte, which
+turns a 768-dim (bit) vector into a 96-dim byte (int8) vector.
+For more details on packing, refer to: https://numpy.org/doc/stable/reference/generated/numpy.packbits.html
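+
+For instance, a small illustration with numpy (independent of the BPR model itself):
+
+    >>> import numpy as np
+    >>> np.packbits([1, 0, 1, 1, 0, 0, 1, 0])
+    array([178], dtype=uint8)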
+
+Hence, the new BPR model will directly produce binary hash embeddings without further changes needed. We
+evaluate the BPR model using a BinaryFlat index in faiss, which computes the hamming distance between bits to find
+the top-k results. We also rerank the top-1000 documents retrieved from faiss with the original (float) query embedding!
+
+The reranking step is very efficient and fast (as reranking is done by a bi-encoder), hence we advise reranking
+the top-1000 docs retrieved by hamming distance to reduce the loss in performance!
+
+'''
+model = models.BinarySentenceBERT("msmarco-distilbert-base-tas-b")
+test_corpus = [{"title": "", "text": "Python is a programming language"}]
+print(model.encode_corpus(test_corpus))
+
+>> [[195 86 160 203 135 39 155 173 89 100 107 159 112 94 144 60 57 148
+ 205 15 204 221 181 132 183 242 122 48 108 200 74 221 48 250 12 4
+ 182 165 36 72 101 169 137 227 192 109 136 18 145 5 104 5 221 195
+ 45 254 226 235 109 3 209 156 75 238 143 56 52 227 39 1 144 214
+ 142 120 181 204 166 221 179 88 142 223 110 255 105 44 108 88 47 67
+ 124 126 117 159 37 217]]
+'''
+
+Usage: python evaluate_bpr.py
+"""
+
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import BinaryFaissSearch
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+# Dense Retrieval with Hamming Distance with Binary-Code-SBERT (Sentence-BERT) ####
+# Provide any Binary Passage Retriever Trained model.
+# The model was fine-tuned using CLS Pooling and dot-product!
+# Open-sourced binary code SBERT model trained on MSMARCO to be made available soon!
+
+model = models.BinarySentenceBERT("msmarco-distilbert-base-tas-b") # Proxy for now, soon coming up BPR models trained on MSMARCO!
+faiss_search = BinaryFaissSearch(model, batch_size=128)
+
+#### Load faiss index from file or disk ####
+# We need two files to be present within the input_dir!
+# 1. input_dir/my-index.bin.faiss ({prefix}.{ext}.faiss) which loads the faiss index.
+# 2. input_dir/my-index.bin.tsv ({prefix}.{ext}.tsv) which loads the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index" # (default value)
+ext = "bin" # bin for binary (default value)
+input_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+
+if os.path.isdir(input_dir):
+ faiss_search.load(input_dir=input_dir, prefix=prefix, ext=ext)
+
+# BPR first retrieves binary_k (default 1000) documents based on query hash and document hash similarity with hamming distance!
+# The hamming distance similarity is constructed using IndexBinaryFlat in Faiss.
+# BPR then reranks with dot similarity between the query embedding and the document hashes for these binary_k documents.
+# Please note, reranking here is done with a bi-encoder, which is much faster than a cross-encoder.
+# Reranking is advised by the original paper as it is quite fast, efficient and leads to decent performance.
+
+score_function = "dot" # or cos_sim for cosine similarity
+retriever = EvaluateRetrieval(faiss_search, score_function=score_function)
+
+rerank = True # False would only retrieve top-k documents based on hamming distance.
+binary_k = 1000 # binary_k value denotes documents reranked for each query.
+
+results = retriever.retrieve(corpus, queries, rerank=rerank, binary_k=binary_k)
+
+### Save faiss index into file or disk ####
+# Unfortunately faiss only supports integer doc-ids!
+# This means we need to save two files in your output_dir path:
+# 1. output_dir/my-index.bin.faiss ({prefix}.{ext}.faiss) which saves the faiss index.
+# 2. output_dir/my-index.bin.tsv ({prefix}.{ext}.tsv) which saves the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index"
+ext = "bin"
+output_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+
+os.makedirs(output_dir, exist_ok=True)
+faiss_search.save(output_dir=output_dir, prefix=prefix, ext=ext)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_dim_reduction.py b/examples/retrieval/evaluation/dense/evaluate_dim_reduction.py
new file mode 100644
index 0000000..2ba326b
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_dim_reduction.py
@@ -0,0 +1,137 @@
+"""
+The pre-trained models produce embeddings of size 512 - 1024. However, when storing a large
+number of embeddings, this requires quite a lot of memory / storage.
+
+In this example, we reduce the dimensionality of the embeddings to e.g. 128 dimensions. This significantly
+reduces the required memory / storage while maintaining nearly the same performance.
+
+For dimensionality reduction, we compute embeddings for a large set of (representative) sentences. Then,
+we use PCA to find e.g. 128 principal components of our vector space. This allows us to maintain
+as much information as possible with only 128 dimensions.
+
+PCA gives us a matrix that down-projects vectors to 128 dimensions. We use this matrix
+and extend our original SentenceTransformer model with this linear down-projection. Hence,
+the new SentenceTransformer model will directly produce embeddings with 128 dimensions
+without further changes needed.
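+
+A minimal sketch of this idea with sentence-transformers and scikit-learn (illustration only;
+the script below relies on PCAFaissSearch instead, and the sentence sample is a placeholder):
+
+    from sklearn.decomposition import PCA
+    from sentence_transformers import SentenceTransformer, models
+    import torch
+
+    model = SentenceTransformer("msmarco-distilbert-base-tas-b")
+    embeddings = model.encode(representative_sentences, convert_to_numpy=True)
+
+    pca = PCA(n_components=128)
+    pca.fit(embeddings)
+
+    dense = models.Dense(in_features=model.get_sentence_embedding_dimension(),
+                         out_features=128, bias=False,
+                         activation_function=torch.nn.Identity())
+    dense.linear.weight = torch.nn.Parameter(torch.tensor(pca.components_))
+    model.add_module("dense", dense)  # model now outputs 128-dim embeddings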
+
+Usage: python evaluate_dim_reduction.py
+"""
+
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import PCAFaissSearch
+
+import logging
+import pathlib, os
+import random
+import faiss
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "scifact"
+
+#### Download scifact.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where scifact has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) scifact/corpus.jsonl  (format: jsonlines)
+# (2) scifact/queries.jsonl (format: jsonlines)
+# (3) scifact/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+# Dense Retrieval using Different Faiss Indexes (Flat or ANN) ####
+# Provide any Sentence-Transformer or Dense Retriever model.
+
+model_path = "msmarco-distilbert-base-tas-b"
+model = models.SentenceBERT(model_path)
+
+###############################################################
+#### PCA: Principal Component Analysis (Exhaustive Search) ####
+###############################################################
+# Reduce Input Dimension (768) to output dimension of (128)
+
+output_dimension = 128
+base_index = faiss.IndexFlatIP(output_dimension)
+faiss_search = PCAFaissSearch(model,
+ base_index=base_index,
+ output_dimension=output_dimension,
+ batch_size=128)
+
+#######################################################################
+#### PCA: Principal Component Analysis (with Product Quantization) ####
+#######################################################################
+# Reduce Input Dimension (768) to output dimension of (96)
+
+# output_dimension = 96
+# base_index = faiss.IndexPQ(output_dimension, # output dimension
+# 96, # number of centroids
+# 8, # code size
+# faiss.METRIC_INNER_PRODUCT) # similarity function
+
+# faiss_search = PCAFaissSearch(model,
+# base_index=base_index,
+# output_dimension=output_dimension,
+# batch_size=128)
+
+#### Load faiss index from file or disk ####
+# We need two files to be present within the input_dir!
+# 1. input_dir/{prefix}.{ext}.faiss => which loads the faiss index.
+# 2. input_dir/{prefix}.{ext}.tsv => which loads the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index" # (default value)
+ext = "pca" # extension
+
+input_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+
+if os.path.exists(os.path.join(input_dir, "{}.{}.faiss".format(prefix, ext))):
+ faiss_search.load(input_dir=input_dir, prefix=prefix, ext=ext)
+
+#### Retrieve dense results (format of results is identical to qrels)
+retriever = EvaluateRetrieval(faiss_search, score_function="dot") # or "cos_sim"
+results = retriever.retrieve(corpus, queries)
+
+#### Save faiss index to file or disk ####
+# Unfortunately, faiss only supports integer doc-ids, so we need to save two files in output_dir.
+# 1. output_dir/{prefix}.{ext}.faiss => which saves the faiss index.
+# 2. output_dir/{prefix}.{ext}.tsv => which saves the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index" # (default value)
+ext = "pca" # extension
+
+output_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+os.makedirs(output_dir, exist_ok=True)
+
+if not os.path.exists(os.path.join(output_dir, "{}.{}.faiss".format(prefix, ext))):
+ faiss_search.save(output_dir=output_dir, prefix=prefix, ext=ext)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_dpr.py b/examples/retrieval/evaluation/dense/evaluate_dpr.py
new file mode 100644
index 0000000..b88dde0
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_dpr.py
@@ -0,0 +1,87 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download NFCorpus dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using Dense Passage Retriever (DPR) ####
+# DPR implements a two-tower strategy, i.e. it encodes the query and document separately.
+# The DPR model was fine-tuned using dot-product (dot) function.
+
+#########################################################
+#### 1. Loading DPR model using SentenceTransformers ####
+#########################################################
+# You need to provide a ' [SEP] ' token to separate titles and passages in documents
+# Ref: (https://www.sbert.net/docs/pretrained-models/dpr.html)
+
+model = DRES(models.SentenceBERT((
+ "facebook-dpr-question_encoder-multiset-base",
+ "facebook-dpr-ctx_encoder-multiset-base",
+ " [SEP] "), batch_size=128))
+
+################################################################
+#### 2. Loading Original HuggingFace DPR models by Facebook ####
+################################################################
+# If you do not have your model saved on Sentence Transformers,
+# you can load HF-based DPR models in BEIR.
+# No need to provide a separator token; the model handles it automatically!
+
+# model = DRES(models.DPR((
+# "facebook/dpr-question_encoder-multiset-base",
+# "facebook/dpr-ctx_encoder-multiset-base"), batch_size=128))
+
+# You can also load similarly trained DPR models available on Hugging Face.
+# For example, GermanDPR (https://deepset.ai/germanquad)
+
+# model = DRES(models.DPR((
+# "deepset/gbert-base-germandpr-question_encoder",
+# "deepset/gbert-base-germandpr-ctx_encoder"), batch_size=128))
+
+
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_faiss_dense.py b/examples/retrieval/evaluation/dense/evaluate_faiss_dense.py
new file mode 100644
index 0000000..c24e444
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_faiss_dense.py
@@ -0,0 +1,139 @@
+"""
+In this example, we show how to utilize different faiss indexes for evaluation in BEIR. We currently support
+IndexFlatIP, IndexPQ and IndexHNSW faiss indexes. The faiss indexes are stored and searched using the CPU.
+
+Some good notes for information on different faiss indexes can be found here:
+1. https://github.com/facebookresearch/faiss/wiki/Faiss-indexes#supported-operations
+2. https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks:-clustering,-PCA,-quantization
+
+For more information, please refer here: https://github.com/facebookresearch/faiss/wiki
+
+PS: You can also save/load your corpus embeddings as a faiss index! Instead of exact search, use FlatIPFaissSearch
+which implements exhaustive search using a faiss index.
+
+Usage: python evaluate_faiss_dense.py
+"""
+
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import PQFaissSearch, HNSWFaissSearch, FlatIPFaissSearch, HNSWSQFaissSearch
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using Different Faiss Indexes (Flat or ANN) ####
+# Provide any Sentence-Transformer or Dense Retriever model.
+
+model_path = "msmarco-distilbert-base-tas-b"
+model = models.SentenceBERT(model_path)
+
+########################################################
+#### FLATIP: Flat Inner Product (Exhaustive Search) ####
+########################################################
+
+faiss_search = FlatIPFaissSearch(model,
+ batch_size=128)
+
+######################################################
+#### PQ: Product Quantization (Exhaustive Search) ####
+######################################################
+
+# faiss_search = PQFaissSearch(model,
+# batch_size=128,
+# num_of_centroids=96,
+# code_size=8)
+
+#####################################################
+#### HNSW: Approximate Nearest Neighbours Search ####
+#####################################################
+
+# faiss_search = HNSWFaissSearch(model,
+# batch_size=128,
+# hnsw_store_n=512,
+# hnsw_ef_search=128,
+# hnsw_ef_construction=200)
+
+###############################################################
+#### HNSWSQ: Approximate Nearest Neighbours Search with SQ ####
+###############################################################
+
+# faiss_search = HNSWSQFaissSearch(model,
+# batch_size=128,
+# hnsw_store_n=128,
+# hnsw_ef_search=128,
+# hnsw_ef_construction=200)
+
+#### Load faiss index from file or disk ####
+# We need two files to be present within the input_dir!
+# 1. input_dir/{prefix}.{ext}.faiss => which loads the faiss index.
+# 2. input_dir/{prefix}.{ext}.tsv => which loads the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index" # (default value)
+ext = "flat" # or "pq", "hnsw", "hnsw-sq"
+input_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+
+if os.path.exists(os.path.join(input_dir, "{}.{}.faiss".format(prefix, ext))):
+ faiss_search.load(input_dir=input_dir, prefix=prefix, ext=ext)
+
+#### Retrieve dense results (format of results is identical to qrels)
+retriever = EvaluateRetrieval(faiss_search, score_function="dot") # or "cos_sim"
+results = retriever.retrieve(corpus, queries)
+
+#### Save faiss index to file or disk ####
+# Unfortunately, faiss only supports integer doc-ids, so we need to save two files in output_dir.
+# 1. output_dir/{prefix}.{ext}.faiss => which saves the faiss index.
+# 2. output_dir/{prefix}.{ext}.tsv => which saves the mapping of ids i.e. (beir-doc-id \t faiss-doc-id).
+
+prefix = "my-index" # (default value)
+ext = "flat" # or "pq", "hnsw"
+output_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
+os.makedirs(output_dir, exist_ok=True)
+
+if not os.path.exists(os.path.join(output_dir, "{}.{}.faiss".format(prefix, ext))):
+ faiss_search.save(output_dir=output_dir, prefix=prefix, ext=ext)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_sbert.py b/examples/retrieval/evaluation/dense/evaluate_sbert.py
new file mode 100644
index 0000000..b590fe8
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_sbert.py
@@ -0,0 +1,66 @@
+from time import time
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "trec-covid"
+
+#### Download trec-covid.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where trec-covid has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) trec-covid/corpus.jsonl (format: jsonlines)
+# (2) trec-covid/queries.jsonl (format: jsonlines)
+# (3) trec-covid/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using SBERT (Sentence-BERT) ####
+#### Provide any pretrained sentence-transformers model
+#### Make sure the score function matches how the model was fine-tuned (msmarco-distilbert-base-tas-b uses dot-product).
+#### Complete list - https://www.sbert.net/docs/pretrained_models.html
+
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-tas-b"), batch_size=256, corpus_chunk_size=512*9999)
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+start_time = time()
+results = retriever.retrieve(corpus, queries)
+end_time = time()
+print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_sbert_hf_loader.py b/examples/retrieval/evaluation/dense/evaluate_sbert_hf_loader.py
new file mode 100644
index 0000000..812de6a
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_sbert_hf_loader.py
@@ -0,0 +1,80 @@
+from collections import defaultdict
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader_hf import HFDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalParallelExactSearch as DRPES
+import time
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(level=logging.INFO)
+#### /print debug information to stdout
+
+
+# Important: you need to shield your code with if __name__ == "__main__". Otherwise, CUDA runs into issues when spawning new processes.
+if __name__ == "__main__":
+
+ dataset = "fiqa"
+
+ #### Download fiqa.zip dataset and unzip the dataset
+ url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+ out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+ data_path = util.download_and_unzip(url, out_dir)
+
+ #### Provide the data path where fiqa has been downloaded and unzipped to the data loader
+ # data folder would contain these files:
+ # (1) fiqa/corpus.jsonl (format: jsonlines)
+ # (2) fiqa/queries.jsonl (format: jsonlines)
+ # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
+
+ #### Load our locally downloaded datasets via HFDataLoader to save RAM (i.e. do not load the whole corpus in RAM)
+ corpus, queries, qrels = HFDataLoader(data_folder=data_path, streaming=False).load(split="test")
+
+    #### You can also use our custom BEIR datasets hosted on HuggingFace; set streaming=True to further save RAM ####
+ # corpus, queries, qrels = HFDataLoader(hf_repo=f"BeIR/{dataset}", streaming=False, keep_in_memory=False).load(split="test")
+
+ #### Dense Retrieval using SBERT (Sentence-BERT) ####
+ #### Provide any pretrained sentence-transformers model
+    #### Make sure the score function matches how the model was fine-tuned (msmarco-distilbert-base-tas-b uses dot-product).
+ #### Complete list - https://www.sbert.net/docs/pretrained_models.html
+ beir_model = models.SentenceBERT("msmarco-distilbert-base-tas-b")
+
+ #### Start with Parallel search and evaluation
+ model = DRPES(beir_model, batch_size=512, target_devices=None, corpus_chunk_size=512*2)
+ retriever = EvaluateRetrieval(model, score_function="dot")
+
+ #### Retrieve dense results (format of results is identical to qrels)
+ start_time = time.time()
+ results = retriever.retrieve(corpus, queries)
+ end_time = time.time()
+ print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
+
+    #### Optional: Stop the processes in the pool
+ # beir_model.doc_model.stop_multi_process_pool(pool)
+
+ #### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+ logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+ mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+ recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+ hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+ #### Print top-k documents retrieved ####
+ top_k = 10
+
+ query_id, ranking_scores = random.choice(list(results.items()))
+ scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+ query = queries.filter(lambda x: x['id']==query_id)[0]['text']
+ logging.info("Query : %s\n" % query)
+
+ for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ doc = corpus.filter(lambda x: x['id']==doc_id)[0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, doc.get("title"), doc.get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_sbert_multi_gpu.py b/examples/retrieval/evaluation/dense/evaluate_sbert_multi_gpu.py
new file mode 100644
index 0000000..7eca3bb
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_sbert_multi_gpu.py
@@ -0,0 +1,90 @@
+'''
+This sample python script shows how to evaluate a BEIR dataset quickly using multiple GPUs (useful for large datasets).
+To run this code, you need Python >= 3.7 (not 3.6) and need to install the evaluate library separately: ``pip install evaluate``
+Enabling multi-GPU evaluation is thanks to the tremendous efforts of Nouamane Tazi (https://github.com/NouamaneTazi)
+
+IMPORTANT: The following code will not run with Python 3.6!
+1. Please install Python 3.7 using Anaconda (conda create -n myenv python=3.7)
+2. Next, install Evaluate (https://github.com/huggingface/evaluate) using ``pip install evaluate``.
+
+You are good to go!
+
+To run this code, you preferably need access to multiple GPUs; it is faster than running on a single GPU.
+CUDA_VISIBLE_DEVICES=0,1,2,3 python evaluate_sbert_multi_gpu.py
+'''
+
+from collections import defaultdict
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader_hf import HFDataLoader
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalParallelExactSearch as DRPES
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+import time
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(level=logging.INFO)
+#### /print debug information to stdout
+
+
+# Important: you need to shield your code with if __name__ == "__main__". Otherwise, CUDA runs into issues when spawning new processes.
+if __name__ == "__main__":
+
+ tick = time.time()
+
+ dataset = "nfcorpus"
+ keep_in_memory = False
+ streaming = False
+ corpus_chunk_size = 2048
+ batch_size = 256 # sentence bert model batch size
+ model_name = "msmarco-distilbert-base-tas-b"
+ target_devices = None # ['cpu']*2
+
+ corpus, queries, qrels = HFDataLoader(hf_repo=f"BeIR/{dataset}", streaming=streaming, keep_in_memory=keep_in_memory).load(split="test")
+
+ #### Dense Retrieval using SBERT (Sentence-BERT) ####
+ #### Provide any pretrained sentence-transformers model
+    #### Make sure the score function matches how the model was fine-tuned (msmarco-distilbert-base-tas-b uses dot-product).
+ #### Complete list - https://www.sbert.net/docs/pretrained_models.html
+ beir_model = models.SentenceBERT(model_name)
+
+ #### Start with Parallel search and evaluation
+ model = DRPES(beir_model, batch_size=batch_size, target_devices=target_devices, corpus_chunk_size=corpus_chunk_size)
+ retriever = EvaluateRetrieval(model, score_function="dot")
+
+ #### Retrieve dense results (format of results is identical to qrels)
+ start_time = time.time()
+ results = retriever.retrieve(corpus, queries)
+ end_time = time.time()
+ print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
+
+ #### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+ logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+ mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+ recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+ hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+ tock = time.time()
+ print("--- Total time taken: {:.2f} seconds ---".format(tock - tick))
+
+ #### Print top-k documents retrieved ####
+ top_k = 10
+
+ query_id, ranking_scores = random.choice(list(results.items()))
+ scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+ query = queries.filter(lambda x: x['id']==query_id)[0]['text']
+ logging.info("Query : %s\n" % query)
+
+ for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ doc = corpus.filter(lambda x: x['id']==doc_id)[0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, doc.get("title"), doc.get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_tldr.py b/examples/retrieval/evaluation/dense/evaluate_tldr.py
new file mode 100644
index 0000000..a567493
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_tldr.py
@@ -0,0 +1,112 @@
+'''
+In this example, we show how to evaluate TLDR: Twin Learning for Dimensionality Reduction using the BEIR Benchmark.
+TLDR is an unsupervised dimensionality reduction technique, which performs better in comparison with the commonly known PCA.
+
+In order to run and evaluate the model, it's important to first install the original tldr repository.
+This can be installed conveniently using "pip install tldr".
+
+However, please refer here: https://github.com/naver/tldr for all requirements!
+'''
+
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from sentence_transformers import SentenceTransformer
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os, sys
+import numpy as np
+import torch
+import random
+import importlib.util
+
+if importlib.util.find_spec("tldr") is not None:
+ from tldr import TLDR
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+# Get all the corpus documents as a list for tldr training
+corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
+corpus_list = [corpus[cid] for cid in corpus_ids]
+
+#### Dense Retrieval with Dimensionality Reduction using TLDR: Twin Learning for Dimensionality Reduction ####
+# TLDR is a dimensionality reduction technique which has been shown to perform better than PCA
+# For more details, please refer to the publication: (https://arxiv.org/pdf/2110.09455.pdf)
+# https://europe.naverlabs.com/research/publications/tldr-twin-learning-for-dimensionality-reduction/
+
+# First load the SBERT model, which will be used to create embeddings
+model_path = "sentence-transformers/msmarco-distilbert-base-tas-b"
+
+# Create the TLDR model instance providing the SBERT model path
+tldr = models.TLDR(
+ encoder_model=SentenceTransformer(model_path),
+ n_components=128,
+ n_neighbors=5,
+ encoder="linear",
+ projector="mlp-1-2048",
+ verbose=2,
+ knn_approximation=None,
+ output_folder="data/"
+)
+
+# Starting to train the TLDR model with TAS-B model on the target dataset: nfcorpus
+tldr.fit(corpus=corpus_list, batch_size=128, epochs=100, warmup_epochs=10, train_batch_size=1024, print_every=100)
+logging.info("TLDR model training completed\n")
+
+# You can also save the trained model in the following path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "tldr", "inference_model.pt")
+logging.info("TLDR model saved here: %s\n" % model_save_path)
+tldr.save(model_save_path)
+
+# You can again load back the trained model using the code below:
+tldr = TLDR()
+tldr.load(model_save_path, init=True) # Loads both model parameters and weights
+
+# Now we evaluate the TLDR model using dense retrieval with dot product
+retriever = EvaluateRetrieval(DRES(tldr), score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
+recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
+hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/dense/evaluate_useqa.py b/examples/retrieval/evaluation/dense/evaluate_useqa.py
new file mode 100644
index 0000000..20e7fb8
--- /dev/null
+++ b/examples/retrieval/evaluation/dense/evaluate_useqa.py
@@ -0,0 +1,60 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download NFCorpus dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Dense Retrieval using USE-QA ####
+# https://tfhub.dev/google/universal-sentence-encoder-qa/3
+# We use the English USE-QA v3 and provide the tf-hub url
+# USE-QA implements a two-tower strategy, i.e. it encodes the query and document separately.
+# USE-QA provides normalized embeddings, so you can use either dot product or cosine-similarity
+
+model = DRES(models.UseQA("https://tfhub.dev/google/universal-sentence-encoder-qa/3"))
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/late-interaction/README.md b/examples/retrieval/evaluation/late-interaction/README.md
new file mode 100644
index 0000000..a321960
--- /dev/null
+++ b/examples/retrieval/evaluation/late-interaction/README.md
@@ -0,0 +1,102 @@
+# BEIR Evaluation with ColBERT
+
+In this example, we show how to evaluate the ColBERT zero-shot model on the BEIR Benchmark.
+
+We modify the original [ColBERT](https://github.com/stanford-futuredata/ColBERT) repository to allow for evaluation of ColBERT across any BEIR dataset.
+
+Please follow the required steps to evaluate ColBERT easily across any BEIR dataset.
+
+## Installation with BEIR
+
+- **Step 1**: Clone this beir-ColBERT repository (forked from the original), which has been modified for evaluating models on the BEIR benchmark:
+```bash
+git clone https://github.com/NThakur20/beir-ColBERT.git
+```
+
+- **Step 2**: Create a new Conda virtual environment using the environment file provided: [conda_env.yml](https://github.com/NThakur20/beir-ColBERT/blob/master/conda_env.yml). It includes a pip installation of the beir repository.
+```bash
+# https://github.com/NThakur20/beir-ColBERT#installation
+
+conda env create -f conda_env.yml
+conda activate colbert-v0.2
+```
+ - **Please Note**: We found some issues with ``_swigfaiss`` with both ``faiss-cpu`` and ``faiss-gpu`` installed on Ubuntu. If you face such issues please refer to: https://github.com/facebookresearch/faiss/issues/821#issuecomment-573531694
+
+## ``evaluate_beir.sh``
+
+Run the ``evaluate_beir.sh`` script for the complete evaluation of the ColBERT model on any BEIR dataset. The script has five steps:
+
+**1. BEIR Preprocessing**: We preprocess our BEIR data into a ColBERT-friendly data format using ``colbert/data_prep.py``. The script converts the original ``jsonl`` format to ``tsv``.
+
+```bash
+python -m colbert.data_prep \
+ --dataset ${dataset} \ # BEIR dataset you want to evaluate, for e.g. nfcorpus
+ --split "test" \ # Split to evaluate on
+ --collection $COLLECTION \ # Path to store collection tsv file
+ --queries $QUERIES \ # Path to store queries tsv file
+```
+
+**2. ColBERT Indexing**: For fast retrieval, indexing precomputes the ColBERT representations of passages.
+
+**NOTE**: You will need to download the trained ColBERT model for inference.
+
+```bash
+python -m torch.distributed.launch \
+ --nproc_per_node=2 -m colbert.index \
+ --root $OUTPUT_DIR \ # Directory to store the output logs and ranking files
+ --doc_maxlen 300 \ # We work with 300 sequence length for document (unlike 180 set originally)
+ --mask-punctuation \ # Mask the Punctuation
+ --bsize 128 \ # Batch-size of 128 for encoding documents/tokens.
+ --amp \ # Using Automatic-Mixed Precision (AMP) fp32 -> fp16
+ --checkpoint $CHECKPOINT \ # Path to the checkpoint to the trained ColBERT model
+ --index_root $INDEX_ROOT \ # Path of the root index to store document embeddings
+ --index_name $INDEX_NAME \ # Name of index under which the document embeddings will be stored
+ --collection $COLLECTION \ # Path of the stored collection tsv file
+ --experiment ${dataset} # Keep an experiment name
+```
+**3. FAISS IVFPQ Index**: We store and train the index using an IVFPQ faiss index for end-to-end retrieval.
+
+**NOTE**: You need to choose a different number of partitions ``k`` for the IVFPQ index for each dataset.
+
+```bash
+python -m colbert.index_faiss \
+    --index_root $INDEX_ROOT \ # Path of the root index where the faiss embeddings will be stored
+ --index_name $INDEX_NAME \ # Name of index under which the faiss embeddings will be stored
+    --partitions $NUM_PARTITIONS \ # Number of partitions for the IVFPQ index (separate for each dataset; you need to choose), e.g. 96 for NFCorpus
+ --sample 0.3 \ # sample: 0.3
+ --root $OUTPUT_DIR \ # Directory to store the output logs and ranking files
+ --experiment ${dataset} # Keep an experiment name
+```
+
+**4. Query Retrieval using ColBERT**: Retrieves top-_k_ documents, where depth = _k_ for each query.
+
+**NOTE**: The output ``ranking.tsv`` file produced has integer document ids (because of faiss). Each int corresponds to the doc_id position in the original collection tsv file (see the mapping sketch after the command below).
+
+```bash
+python -m colbert.retrieve \
+ --amp \ # Using Automatic-Mixed Precision (AMP) fp32 -> fp16
+ --doc_maxlen 300 \ # We work with 300 sequence length for document (unlike 180 set originally)
+ --mask-punctuation \ # Mask the Punctuation
+ --bsize 256 \ # 256 batch-size for evaluation
+ --queries $QUERIES \ # Path which contains the store queries tsv file
+    --nprobe 32 \ # Number of IVF partitions (cells) probed during search
+ --partitions $NUM_PARTITIONS \ # Number of Partitions for IVFPQ index
+ --faiss_depth 100 \ # faiss_depth of 100 is used for evaluation (Roughly 100 top-k nearest neighbours are used for retrieval)
+ --depth 100 \ # Depth is kept at 100 to keep 100 documents per query in ranking file
+ --index_root $INDEX_ROOT \ # Path of the root index of the stored IVFPQ index of the faiss embeddings
+ --index_name $INDEX_NAME \ # Name of index under which the faiss embeddings will be stored
+ --checkpoint $CHECKPOINT \ # Path to the checkpoint to the trained ColBERT model
+ --root $OUTPUT_DIR \ # Directory to store the output logs and ranking files
+ --experiment ${dataset} \ # Keep an experiment name
+ --ranking_dir $RANKING_DIR # Ranking Directory will store the final ranking results as ranking.tsv file
+```
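+
+As a rough sketch, mapping the integer ids in ``ranking.tsv`` back to the original BEIR document ids could look like the snippet below. It assumes the original BEIR doc_id is stored in the first column of the collection tsv (please check the format actually produced by ``colbert/data_prep.py`` for your setup):
+
+```python
+import csv
+
+# position in the collection tsv -> original BEIR doc_id (assumed to sit in the first column)
+with open("collection.tsv") as f:              # the file passed as $COLLECTION above
+    position_to_docid = [row[0] for row in csv.reader(f, delimiter="\t")]
+
+beir_docid = position_to_docid[42]             # 42 = an integer id taken from ranking.tsv
+```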
+
+**5. Evaluation using BEIR**: Evaluate the ``ranking.tsv`` file using the BEIR evaluation script for any dataset.
+
+```bash
+python -m colbert.beir_eval \
+ --dataset ${dataset} \ # BEIR dataset you want to evaluate, for e.g. nfcorpus
+ --split "test" \ # Split to evaluate on
+ --collection $COLLECTION \ # Path of the stored collection tsv file
+ --rankings "${RANKING_DIR}/ranking.tsv" # Path to store the final ranking tsv file
+```
diff --git a/examples/retrieval/evaluation/lexical/evaluate_anserini_bm25.py b/examples/retrieval/evaluation/lexical/evaluate_anserini_bm25.py
new file mode 100644
index 0000000..3df3194
--- /dev/null
+++ b/examples/retrieval/evaluation/lexical/evaluate_anserini_bm25.py
@@ -0,0 +1,92 @@
+"""
+This example shows how to evaluate Anserini-BM25 in BEIR.
+Since Anserini uses Java 11, we advise you to use Docker to run Pyserini.
+To be able to run the code below, you must have Docker installed on your machine.
+To install Docker on your local machine, please refer here: https://docs.docker.com/get-docker/
+
+After installing Docker, please follow the steps below to get the Docker container up and running:
+
+1. docker pull beir/pyserini-fastapi
+2. docker build -t pyserini-fastapi .
+3. docker run -p 8000:8000 -it --rm pyserini-fastapi
+
+Once the Docker container is up and running locally, run the code below.
+This code doesn't require GPU to run.
+
+Usage: python evaluate_anserini_bm25.py
+"""
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+
+import pathlib, os, json
+import logging
+import requests
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#### Convert BEIR corpus to Pyserini Format #####
+pyserini_jsonl = "pyserini.jsonl"
+with open(os.path.join(data_path, pyserini_jsonl), 'w', encoding="utf-8") as fOut:
+ for doc_id in corpus:
+ title, text = corpus[doc_id].get("title", ""), corpus[doc_id].get("text", "")
+ data = {"id": doc_id, "title": title, "contents": text}
+ json.dump(data, fOut)
+ fOut.write('\n')
+
+#### Download Docker Image beir/pyserini-fastapi ####
+#### Locally run the docker Image + FastAPI ####
+docker_beir_pyserini = "http://127.0.0.1:8000"
+
+#### Upload Multipart-encoded files ####
+with open(os.path.join(data_path, "pyserini.jsonl"), "rb") as fIn:
+ r = requests.post(docker_beir_pyserini + "/upload/", files={"file": fIn}, verify=False)
+
+#### Index documents to Pyserini #####
+index_name = "beir/you-index-name" # beir/scifact
+r = requests.get(docker_beir_pyserini + "/index/", params={"index_name": index_name})
+
+#### Retrieve documents from Pyserini #####
+retriever = EvaluateRetrieval()
+qids = list(queries)
+query_texts = [queries[qid] for qid in qids]
+payload = {"queries": query_texts, "qids": qids, "k": max(retriever.k_values)}
+
+#### Retrieve pyserini results (format of results is identical to qrels)
+results = json.loads(requests.post(docker_beir_pyserini + "/lexical/batch_search/", json=payload).text)["results"]
+
+#### Retrieve RM3 expanded pyserini results (format of results is identical to qrels)
+# results = json.loads(requests.post(docker_beir_pyserini + "/lexical/rm3/batch_search/", json=payload).text)["results"]
+
+#### Check if query_id is in results, i.e. remove it from the retrieved docs in case it appears ####
+#### Quite Important for ArguAna and Quora ####
+for query_id in results:
+ if query_id in results[query_id]:
+ results[query_id].pop(query_id, None)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Retrieval Example ####
+query_id, scores_dict = random.choice(list(results.items()))
+logging.info("Query : %s\n" % queries[query_id])
+
+scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/lexical/evaluate_bm25.py b/examples/retrieval/evaluation/lexical/evaluate_bm25.py
new file mode 100644
index 0000000..7417ab0
--- /dev/null
+++ b/examples/retrieval/evaluation/lexical/evaluate_bm25.py
@@ -0,0 +1,84 @@
+"""
+This example shows how to evaluate the BM25 model (Elasticsearch) in BEIR.
+To be able to run Elasticsearch, you should have it installed locally (on your desktop) along with ``pip install beir``.
+Depending on your OS, you can find instructions on how to install Elasticsearch. I like this guide for Ubuntu 18.04 -
+https://linuxize.com/post/how-to-install-elasticsearch-on-ubuntu-18-04/
+For more details, please refer here - https://www.elastic.co/downloads/elasticsearch.
+
+This code doesn't require GPU to run.
+
+If you are unable to get it running locally, you could try the Google Colab demo, where we first install Elasticsearch locally and retrieve using BM25:
+https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing#scrollTo=nqotyXuIBPt6
+
+
+Usage: python evaluate_bm25.py
+"""
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+
+import pathlib, os, random
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where scifact has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) scifact/corpus.jsonl (format: jsonlines)
+# (2) scifact/queries.jsonl (format: jsonlines)
+# (3) scifact/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#### Lexical Retrieval using BM25 (Elasticsearch) ####
+#### Provide a hostname (localhost) to connect to ES instance
+#### Define a new index name or use an already existing one.
+#### We use default ES settings for retrieval
+#### https://www.elastic.co/
+
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # scifact
+
+#### Initialize ####
+# (1) True - Delete existing index and re-index all documents from scratch
+# (2) False - Load existing index
+initialize = True # False
+
+#### Sharding ####
+# (1) For datasets with small corpus (datasets ~ < 5k docs) => limit shards = 1
+# SciFact is a relatively small dataset! (limit shards to 1)
+number_of_shards = 1
+model = BM25(index_name=index_name, hostname=hostname, initialize=initialize, number_of_shards=number_of_shards)
+
+# (2) For datasets with big corpus ==> keep default configuration
+# model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+retriever = EvaluateRetrieval(model)
+
+#### Retrieve BM25 results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Retrieval Example ####
+query_id, scores_dict = random.choice(list(results.items()))
+logging.info("Query : %s\n" % queries[query_id])
+
+scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/lexical/evaluate_multilingual_bm25.py b/examples/retrieval/evaluation/lexical/evaluate_multilingual_bm25.py
new file mode 100644
index 0000000..ef72f52
--- /dev/null
+++ b/examples/retrieval/evaluation/lexical/evaluate_multilingual_bm25.py
@@ -0,0 +1,92 @@
+"""
+This example shows how to evaluate the BM25 model (Elasticsearch) in BEIR for German.
+This script can be used to evaluate any language by just changing the language name.
+To find languages supported by Elasticsearch, please refer below:
+https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-lang-analyzer.html
+
+To be able to run Elasticsearch, you should have it installed locally (on your desktop) along with ``pip install beir``.
+Depending on your OS, you can find instructions on how to install Elasticsearch. I like this guide for Ubuntu 18.04 -
+https://linuxize.com/post/how-to-install-elasticsearch-on-ubuntu-18-04/
+For more details, please refer here - https://www.elastic.co/downloads/elasticsearch.
+
+This code doesn't require GPU to run.
+
+If you are unable to get it running locally, you could try the Google Colab demo, where we first install Elasticsearch locally and retrieve using BM25:
+https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing#scrollTo=nqotyXuIBPt6
+
+
+Usage: python evaluate_multilingual_bm25.py
+"""
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+
+import pathlib, os, random
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download germanquad.zip dataset and unzip the dataset
+dataset = "germanquad"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where germanquad has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) germanquad/corpus.jsonl (format: jsonlines)
+# (2) germanquad/queries.jsonl (format: jsonlines)
+# (3) germanquad/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#### Lexical Retrieval using BM25 (Elasticsearch) ####
+#### Provide a hostname (localhost) to connect to ES instance
+#### Define a new index name or use an already existing one.
+#### We use default ES settings for retrieval
+#### https://www.elastic.co/
+
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # germanquad
+
+#### Initialize ####
+# (1) True - Delete existing index and re-index all documents from scratch
+# (2) False - Load existing index
+initialize = True # False
+
+#### Language ####
+# For languages supported by Elasticsearch by default, check here ->
+# https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-lang-analyzer.html
+language = "german" # Please provide full names in lowercase for eg. english, hindi ...
+
+#### Sharding ####
+# (1) For datasets with small corpus (datasets ~ < 5k docs) => limit shards = 1
+number_of_shards = 1
+model = BM25(index_name=index_name, hostname=hostname, language=language, initialize=initialize, number_of_shards=number_of_shards)
+
+# (2) For datasets with big corpus ==> keep default configuration
+# model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+retriever = EvaluateRetrieval(model)
+
+#### Retrieve BM25 results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Retrieval Example ####
+query_id, scores_dict = random.choice(list(results.items()))
+logging.info("Query : %s\n" % queries[query_id])
+
+scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/reranking/README.md b/examples/retrieval/evaluation/reranking/README.md
new file mode 100644
index 0000000..6ed5aa1
--- /dev/null
+++ b/examples/retrieval/evaluation/reranking/README.md
@@ -0,0 +1,34 @@
+### Re-ranking BM25 top-100 using Cross-Encoder (Leaderboard)
+
+In the tables below, we evaluate various reranking architectures based on their performance and speed. We include the following model architectures -
+
+- [MiniLM](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html)
+- [TinyBERT](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html)
+
+
+| Reranking-Model |Docs / Sec| MSMARCO | TREC-COVID | BIOASQ |NFCORPUS| NQ |HOTPOT-QA| FIQA |SIGNAL-1M|
+| ---------------------------------- |:------: | :-----: | :--------: | :-----:|:------:| :--: |:------:| :--: |:--------:|
+| **MiniLM Models** | |
+| cross-encoder/ms-marco-MiniLM-L-2-v2 | 4100 | 0.373 | 0.669 | 0.471 | 0.337 |0.465 | 0.655 | 0.278| 0.334 |
+| cross-encoder/ms-marco-MiniLM-L-4-v2 | 2500 | 0.392 | 0.720 | 0.516 | 0.358 |0.509 | 0.699 | 0.327| 0.350 |
+| cross-encoder/ms-marco-MiniLM-L-6-v2 | 1800 | 0.401 | 0.722 | 0.529 | 0.360 |0.530 | 0.712 | 0.334| 0.351 |
+| cross-encoder/ms-marco-MiniLM-L-12-v2 | 960 | 0.401 | 0.737 | 0.532 | 0.339 |0.531 | 0.717 | 0.336| 0.348 |
+| **TinyBERT Models** | |
+| cross-encoder/ms-marco-TinyBERT-L-2-v2 | 9000 | 0.354 | 0.689 | 0.466 | 0.346 |0.444 | 0.650 | 0.270| 0.338 |
+| cross-encoder/ms-marco-TinyBERT-L-4 | 2900 | 0.371 | 0.640 | 0.470 | 0.323 | | 0.679 | 0.260| 0.312 |
+| cross-encoder/ms-marco-TinyBERT-L-6 | 680 | 0.380 | 0.652 | 0.473 | 0.339 | | 0.682 | 0.305| 0.314 |
+| cross-encoder/ms-marco-electra-base | 340 | 0.384 | 0.667 | 0.489 | 0.303 |0.516 | 0.701 | 0.326| 0.308 |
+
+
+| Reranking-Model |Docs / Sec| TREC-NEWS |ArguAna| Touche'20| DBPedia |SCIDOCS| FEVER |Clim.-FEVER| SciFact |
+| ----------------------------------- |:-------: | :-------: |:-----:| :-----: | :-----: |:-----:| :---: |:--------: | :-----: |
+| **MiniLM Models** | |
+| cross-encoder/ms-marco-MiniLM-L-2-v2 | 4100 | 0.417 | 0.157 | 0.363 | 0.502 | 0.145 | 0.759 | 0.215 | 0.607 |
+| cross-encoder/ms-marco-MiniLM-L-4-v2 | 2500 | 0.431 | 0.430 | 0.371 | 0.531 | 0.156 | 0.775 | 0.228 | 0.680 |
+| cross-encoder/ms-marco-MiniLM-L-6-v2 | 1800 | 0.436 | 0.415 | 0.349 | 0.542 | 0.164 | 0.802 | 0.240 | 0.682 |
+| cross-encoder/ms-marco-MiniLM-L-12-v2 | 960 | 0.451 | 0.333 | 0.378 | 0.541 | 0.165 | 0.814 | 0.250 | 0.680 |
+| **TinyBERT Models** | |
+| cross-encoder/ms-marco-TinyBERT-L-2-v2 | 9000 | 0.385 | 0.341 | 0.311 | 0.497 | 0.151 | 0.647 | 0.173 | 0.662 |
+| cross-encoder/ms-marco-TinyBERT-L-4 | 2900 | 0.377 | 0.398 | 0.333 | 0.350 | 0.149 | 0.760 | 0.194 | 0.658 |
+| cross-encoder/ms-marco-TinyBERT-L-6 | 680 | 0.418 | 0.480 | 0.375 | 0.371 | 0.143 | 0.789 | 0.237 | 0.645 |
+| cross-encoder/ms-marco-electra-base | 340 | 0.430 | 0.313 | 0.378 | 0.380 | 0.154 | 0.793 | 0.246 | 0.524 |
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/reranking/evaluate_bm25_ce_reranking.py b/examples/retrieval/evaluation/reranking/evaluate_bm25_ce_reranking.py
new file mode 100644
index 0000000..bc144ff
--- /dev/null
+++ b/examples/retrieval/evaluation/reranking/evaluate_bm25_ce_reranking.py
@@ -0,0 +1,78 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+from beir.reranking.models import CrossEncoder
+from beir.reranking import Rerank
+
+import pathlib, os
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download trec-covid.zip dataset and unzip the dataset
+dataset = "trec-covid"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where trec-covid has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) trec-covid/corpus.jsonl (format: jsonlines)
+# (2) trec-covid/queries.jsonl (format: jsonlines)
+# (3) trec-covid/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#########################################
+#### (1) RETRIEVE Top-100 docs using BM25
+#########################################
+
+#### Provide parameters for Elasticsearch
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # trec-covid
+initialize = True # False
+
+model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+retriever = EvaluateRetrieval(model)
+
+#### Retrieve BM25 results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
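+#### results is a nested dict of the form {query_id: {doc_id: score, ...}, ...},
+#### i.e. the same structure as qrels, with retrieval scores in place of relevance labels.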
+
+################################################
+#### (2) RERANK Top-100 docs using Cross-Encoder
+################################################
+
+#### Reranking using Cross-Encoder models #####
+#### https://www.sbert.net/docs/pretrained_cross-encoders.html
+cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-electra-base')
+
+#### Or use MiniLM, TinyBERT etc. CE models (https://www.sbert.net/docs/pretrained-models/ce-msmarco.html)
+# cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+# cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')
+
+reranker = Rerank(cross_encoder_model, batch_size=128)
+
+# Rerank top-100 results using the reranker provided
+rerank_results = reranker.rerank(corpus, queries, results, top_k=100)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(rerank_results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/reranking/evaluate_bm25_monot5_reranking.py b/examples/retrieval/evaluation/reranking/evaluate_bm25_monot5_reranking.py
new file mode 100644
index 0000000..0900336
--- /dev/null
+++ b/examples/retrieval/evaluation/reranking/evaluate_bm25_monot5_reranking.py
@@ -0,0 +1,95 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+from beir.reranking.models import MonoT5
+from beir.reranking import Rerank
+
+import pathlib, os
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download trec-covid.zip dataset and unzip the dataset
+dataset = "trec-covid"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where trec-covid has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) trec-covid/corpus.jsonl (format: jsonlines)
+# (2) trec-covid/queries.jsonl (format: jsonlines)
+# (3) trec-covid/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#########################################
+#### (1) RETRIEVE Top-100 docs using BM25
+#########################################
+
+#### Provide parameters for Elasticsearch
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # trec-covid
+initialize = True # False
+
+model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+retriever = EvaluateRetrieval(model)
+
+#### Retrieve BM25 results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+##############################################
+#### (2) RERANK Top-100 docs using MonoT5 ####
+##############################################
+
+#### Reranking using MonoT5 model #####
+# Document Ranking with a Pretrained Sequence-to-Sequence Model
+# https://aclanthology.org/2020.findings-emnlp.63/
+
+#### Check below for reference parameters for different MonoT5 models
+#### Two tokens: token_false, token_true
+# 1. 'castorini/monot5-base-msmarco': ['▁false', '▁true']
+# 2. 'castorini/monot5-base-msmarco-10k': ['▁false', '▁true']
+# 3. 'castorini/monot5-large-msmarco': ['▁false', '▁true']
+# 4. 'castorini/monot5-large-msmarco-10k': ['▁false', '▁true']
+# 5. 'castorini/monot5-base-med-msmarco': ['▁false', '▁true']
+# 6. 'castorini/monot5-3b-med-msmarco': ['▁false', '▁true']
+# 7. 'unicamp-dl/mt5-base-en-msmarco': ['▁no' , '▁yes']
+# 8. 'unicamp-dl/ptt5-base-pt-msmarco-10k-v2': ['▁não' , '▁sim']
+# 9. 'unicamp-dl/ptt5-base-pt-msmarco-100k-v2': ['▁não' , '▁sim']
+# 10.'unicamp-dl/ptt5-base-en-pt-msmarco-100k-v2':['▁não' , '▁sim']
+# 11.'unicamp-dl/mt5-base-en-pt-msmarco-v2': ['▁no' , '▁yes']
+# 12.'unicamp-dl/mt5-base-mmarco-v2': ['▁no' , '▁yes']
+# 13.'unicamp-dl/mt5-base-en-pt-msmarco-v1': ['▁no' , '▁yes']
+# 14.'unicamp-dl/mt5-base-mmarco-v1': ['▁no' , '▁yes']
+# 15.'unicamp-dl/ptt5-base-pt-msmarco-10k-v1': ['▁não' , '▁sim']
+# 16.'unicamp-dl/ptt5-base-pt-msmarco-100k-v1': ['▁não' , '▁sim']
+# 17.'unicamp-dl/ptt5-base-en-pt-msmarco-10k-v1': ['▁não' , '▁sim']
+
+cross_encoder_model = MonoT5('castorini/monot5-base-msmarco', token_false='▁false', token_true='▁true')
+reranker = Rerank(cross_encoder_model, batch_size=128)
+
+# Rerank top-100 results using the reranker provided
+rerank_results = reranker.rerank(corpus, queries, results, top_k=100)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(rerank_results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/reranking/evaluate_bm25_sbert_reranking.py b/examples/retrieval/evaluation/reranking/evaluate_bm25_sbert_reranking.py
new file mode 100644
index 0000000..b6b13d4
--- /dev/null
+++ b/examples/retrieval/evaluation/reranking/evaluate_bm25_sbert_reranking.py
@@ -0,0 +1,59 @@
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.lexical import BM25Search as BM25
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+from beir.retrieval import models
+
+import pathlib, os
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download trec-covid.zip dataset and unzip the dataset
+dataset = "trec-covid"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where trec-covid has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#### Provide parameters for elastic-search
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # nfcorpus
+initialize = True
+
+model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+retriever = EvaluateRetrieval(model)
+
+#### Retrieve BM25 results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Reranking top-100 docs using Dense Retriever model
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"), batch_size=128)
+dense_retriever = EvaluateRetrieval(model, score_function="cos_sim", k_values=[1,3,5,10,100])
+
+#### Rerank the BM25 results using the dense retriever (format of rerank_results is identical to qrels)
+rerank_results = dense_retriever.rerank(corpus, queries, results, top_k=100)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+ndcg, _map, recall, precision = dense_retriever.evaluate(qrels, rerank_results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(rerank_results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query.py b/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query.py
new file mode 100644
index 0000000..9d3e3be
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query.py
@@ -0,0 +1,128 @@
+"""
+This example shows how to evaluate docTTTTTquery in BEIR.
+
+Since Anserini uses Java 11, we advise you to use docker for running Pyserini.
+To be able to run the code below, you must have docker installed locally on your machine.
+To install docker on your local machine, please refer here: https://docs.docker.com/get-docker/
+
+After docker installation, you can start the needed docker container with the following command:
+docker run -p 8000:8000 -it --rm beir/pyserini-fastapi
+
+Once the docker container is up and running locally, run the code below.
+
+For this example, we use the "castorini/doc2query-t5-base-msmarco" model for query generation.
+We generate 3 questions per passage and append them to the passage used for BM25 retrieval.
+
+Usage: python evaluate_anserini_docT5query.py
+"""
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.generation.models import QGenModel
+from tqdm.autonotebook import trange
+
+import pathlib, os, json
+import logging
+import requests
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+corpus_ids = list(corpus.keys())
+corpus_list = [corpus[doc_id] for doc_id in corpus_ids]
+
+################################
+#### 1. Question-Generation ####
+################################
+
+#### docTTTTTquery model to generate synthetic questions.
+#### Synthetic questions will get prepended with document.
+#### Ref: https://github.com/castorini/docTTTTTquery
+
+model_path = "castorini/doc2query-t5-base-msmarco"
+qgen_model = QGenModel(model_path, use_fast=False)
+
+gen_queries = {}
+num_return_sequences = 3 # In our experience, 3-5 generated questions per passage are reasonably diverse!
+batch_size = 80 # The bigger the batch size, the faster the generation!
+
+for start_idx in trange(0, len(corpus_list), batch_size, desc='question-generation'):
+
+ size = len(corpus_list[start_idx:start_idx + batch_size])
+ ques = qgen_model.generate(
+ corpus=corpus_list[start_idx:start_idx + batch_size],
+ ques_per_passage=num_return_sequences,
+ max_length=64,
+ top_p=0.95,
+ top_k=10)
+
+ assert len(ques) == size * num_return_sequences
+
+ for idx in range(size):
+ start_id = idx * num_return_sequences
+ end_id = start_id + num_return_sequences
+ gen_queries[corpus_ids[start_idx + idx]] = ques[start_id: end_id]
+
+#### Convert BEIR corpus to Pyserini Format #####
+pyserini_jsonl = "pyserini.jsonl"
+with open(os.path.join(data_path, pyserini_jsonl), 'w', encoding="utf-8") as fOut:
+ for doc_id in corpus:
+ title, text = corpus[doc_id].get("title", ""), corpus[doc_id].get("text", "")
+ query_text = " ".join(gen_queries[doc_id])
+ data = {"id": doc_id, "title": title, "contents": text, "queries": query_text}
+ json.dump(data, fOut)
+ fOut.write('\n')
+
+#### Download Docker Image beir/pyserini-fastapi ####
+#### Locally run the docker Image + FastAPI ####
+docker_beir_pyserini = "http://127.0.0.1:8000"
+
+#### Upload Multipart-encoded files ####
+with open(os.path.join(data_path, "pyserini.jsonl"), "rb") as fIn:
+ r = requests.post(docker_beir_pyserini + "/upload/", files={"file": fIn}, verify=False)
+
+#### Index documents to Pyserini #####
+index_name = "beir/you-index-name" # beir/scifact
+r = requests.get(docker_beir_pyserini + "/index/", params={"index_name": index_name})
+
+######################################
+#### 2. Pyserini-Retrieval (BM25) ####
+######################################
+
+#### Retrieve documents from Pyserini #####
+retriever = EvaluateRetrieval()
+qids = list(queries)
+query_texts = [queries[qid] for qid in qids]
+payload = {"queries": query_texts, "qids": qids, "k": max(retriever.k_values),
+ "fields": {"contents": 1.0, "title": 1.0, "queries": 1.0}}
+
+#### Retrieve pyserini results (format of results is identical to qrels)
+results = json.loads(requests.post(docker_beir_pyserini + "/lexical/batch_search/", json=payload).text)["results"]
+
+#### Retrieve RM3 expanded pyserini results (format of results is identical to qrels)
+# results = json.loads(requests.post(docker_beir_pyserini + "/lexical/rm3/batch_search/", json=payload).text)["results"]
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Retrieval Example ####
+query_id, scores_dict = random.choice(list(results.items()))
+logging.info("Query : %s\n" % queries[query_id])
+
+scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
diff --git a/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py b/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py
new file mode 100644
index 0000000..2dc4a33
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py
@@ -0,0 +1,211 @@
+"""
+This example shows how to evaluate docTTTTTquery in BEIR.
+
+Since Anserini uses Java 11, we advise you to use docker for running Pyserini.
+To be able to run the code below, you must have docker installed locally on your machine.
+To install docker on your local machine, please refer here: https://docs.docker.com/get-docker/
+
+After docker installation, please follow the steps below to get the docker container up and running:
+
+1. docker pull beir/pyserini-fastapi
+2. docker build -t pyserini-fastapi .
+3. docker run -p 8000:8000 -it --rm pyserini-fastapi
+
+Once the docker container is up and running locally, run the code below.
+
+For this example, we use the "BeIR/query-gen-msmarco-t5-base-v1" model (the default, configurable via --model-id) for query generation.
+We generate 5 questions per passage (NUM_QUERIES_PER_PASSAGE) and append them to the passage used for BM25 retrieval.
+
+Usage: python evaluate_anserini_docT5query_parallel.py --dataset scifact
+"""
+
+import argparse
+import json
+import logging
+import os
+import pathlib
+import random
+import requests
+import torch
+import torch.multiprocessing as mp
+
+from tqdm import tqdm
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.generation.models import QGenModel
+
+CHUNK_SIZE_MP = 100
+CHUNK_SIZE_GPU = 64 # memory-bound, this should work for most GPUs
+DEVICE_CPU = 'cpu'
+DEVICE_GPU = 'cuda'
+NUM_QUERIES_PER_PASSAGE = 5
+PYSERINI_URL = "http://127.0.0.1:8000"
+
+DEFAULT_MODEL_ID = 'BeIR/query-gen-msmarco-t5-base-v1' # https://huggingface.co/BeIR/query-gen-msmarco-t5-base-v1
+DEFAULT_DEVICE = DEVICE_GPU
+
+# noinspection PyArgumentList
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO, handlers=[LoggingHandler()])
+
+
+def init_process(device, model_id):
+ """Initializes a worker process."""
+
+ global model
+
+ if device == DEVICE_GPU:
+ # Assign the GPU process ID to bind this process to a specific GPU
+ # This is a bit fragile and relies on CUDA ordinals being the same
+ # See: https://stackoverflow.com/questions/63564028/multiprocess-pool-initialization-with-sequential-initializer-argument
+ proc_id = int(mp.current_process().name.split('-')[1]) - 1
+ device = f'{DEVICE_GPU}:{proc_id}'
+
+ model = QGenModel(model_id, use_fast=True, device=device)
+
+
+def _decide_device(cpu_procs):
+ """Based on command line arguments, sets the device and number of processes to use."""
+
+ if cpu_procs:
+ return DEVICE_CPU, cpu_procs
+ else:
+ assert torch.cuda.is_available(), "No GPUs available. Please set --cpu-procs or make GPUs available"
+ try:
+ mp.set_start_method('spawn')
+ except RuntimeError:
+ pass
+ return DEVICE_GPU, torch.cuda.device_count()
+
+
+def _download_dataset(dataset):
+ """Downloads a dataset and unpacks it on disk."""
+
+ url = 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip'.format(dataset)
+ out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), 'datasets')
+ return util.download_and_unzip(url, out_dir)
+
+
+def _generate_query(corpus_list):
+ """Generates a set of queries for a given document."""
+
+ documents = [document for _, document in corpus_list]
+ generated_queries = model.generate(corpus=documents,
+ ques_per_passage=NUM_QUERIES_PER_PASSAGE,
+ max_length=64,
+ temperature=1,
+ top_k=10)
+
+ for i, (_, document) in enumerate(corpus_list):
+ start_index = i * NUM_QUERIES_PER_PASSAGE
+ end_index = start_index + NUM_QUERIES_PER_PASSAGE
+ document["queries"] = generated_queries[start_index:end_index]
+
+ return dict(corpus_list)
+
+
+def _add_generated_queries_to_corpus(num_procs, device, model_id, corpus):
+ """Using a pool of workers, generate queries to add to each document in the corpus."""
+
+ # Chunk input so we can maximize the use of our GPUs
+ corpus_list = list(corpus.items())
+ chunked_corpus = [corpus_list[pos:pos + CHUNK_SIZE_GPU] for pos in range(0, len(corpus_list), CHUNK_SIZE_GPU)]
+
+ pool = mp.Pool(num_procs, initializer=init_process, initargs=(device, model_id))
+ for partial_corpus in tqdm(pool.imap_unordered(_generate_query, chunked_corpus, chunksize=CHUNK_SIZE_MP), total=len(chunked_corpus)):
+ corpus.update(partial_corpus)
+
+ return corpus
+
+
+def _write_pyserini_corpus(pyserini_index_file, corpus):
+ """Writes the in-memory corpus to disk in the Pyserini format."""
+
+ with open(pyserini_index_file, 'w', encoding='utf-8') as fOut:
+ for doc_id, document in corpus.items():
+ data = {
+ 'id': doc_id,
+ 'title': document.get('title', ''),
+ 'contents': document.get('text', ''),
+ 'queries': ' '.join(document.get('queries', '')),
+ }
+ json.dump(data, fOut)
+ fOut.write('\n')
+
+
+def _index_pyserini(pyserini_index_file, dataset):
+ """Uploads a Pyserini index file and indexes it into Lucene."""
+
+ with open(pyserini_index_file, 'rb') as fIn:
+ r = requests.post(f'{PYSERINI_URL}/upload/', files={'file': fIn}, verify=False)
+
+ r = requests.get(f'{PYSERINI_URL}/index/', params={'index_name': f'beir/{dataset}'})
+
+
+def _search_pyserini(queries, k):
+ """Searches an index in Pyserini in bulk."""
+
+ qids = list(queries)
+ query_texts = [queries[qid] for qid in qids]
+ payload = {
+ 'queries': query_texts,
+ 'qids': qids,
+ 'k': k,
+ 'fields': {'contents': 1.0, 'title': 1.0, 'queries': 1.0},
+ }
+
+ r = requests.post(f'{PYSERINI_URL}/lexical/batch_search/', json=payload)
+ return json.loads(r.text)['results']
+
+
+def _print_retrieval_examples(corpus, queries, results):
+ """Prints retrieval examples for inspection."""
+
+ query_id, scores_dict = random.choice(list(results.items()))
+ logging.info(f"Query: {queries[query_id]}\n")
+
+ scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+ for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info(
+ "Doc %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get('title'), corpus[doc_id].get('text')))
+
+
+def main():
+ parser = argparse.ArgumentParser(prog='evaluate_anserini_docT5query_parallel')
+ parser.add_argument('--dataset', required=True, help=f"The dataset to use. Example: scifact")
+ parser.add_argument('--model-id',
+ default=DEFAULT_MODEL_ID, help=f"The model ID to use. Default: {DEFAULT_MODEL_ID}")
+ parser.add_argument('--cpu-procs', default=None, type=int,
+ help=f"Use CPUs instead of GPUs and use this number of cores. Leaving this unset (default) "
+ "will use all available GPUs. Default: None")
+ args = parser.parse_args()
+
+ device, num_procs = _decide_device(args.cpu_procs)
+
+ # Download and load the dataset into memory
+ data_path = _download_dataset(args.dataset)
+ pyserini_index_file = os.path.join(data_path, 'pyserini.jsonl')
+ corpus, queries, qrels = GenericDataLoader(data_path).load(split='test')
+
+    # Generate queries per document and create the Pyserini index file if it does not exist yet
+ if not os.path.isfile(pyserini_index_file):
+ _add_generated_queries_to_corpus(num_procs, device, args.model_id, corpus)
+ _write_pyserini_corpus(pyserini_index_file, corpus)
+
+ # Index into Pyserini
+ _index_pyserini(pyserini_index_file, args.dataset)
+
+ # Retrieve and evaluate
+ retriever = EvaluateRetrieval()
+ results = _search_pyserini(queries, k=max(retriever.k_values))
+ retriever.evaluate(qrels, results, retriever.k_values)
+
+ _print_retrieval_examples(corpus, queries, results)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/retrieval/evaluation/sparse/evaluate_deepct.py b/examples/retrieval/evaluation/sparse/evaluate_deepct.py
new file mode 100644
index 0000000..d4a5dfe
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_deepct.py
@@ -0,0 +1,136 @@
+"""
+This example shows how to evaluate DeepCT (using Anserini) in BEIR.
+For more details on DeepCT, refer here: https://arxiv.org/abs/1910.10687
+
+The original DeepCT repository is not modularised and only works with TensorFlow 1.x (1.15).
+We modified the DeepCT repository to work with the latest TensorFlow (2.x).
+We do not change the core prediction code, only a few input/output file formats and structures to adapt to the BEIR format.
+For more details on the changes, check https://github.com/NThakur20/DeepCT and compare it with the original repo!
+
+Please follow the steps below to install DeepCT:
+
+1. git clone https://github.com/NThakur20/DeepCT.git
+
+Since Anserini uses Java 11, we advise you to use docker for running Pyserini.
+To be able to run the code below, you must have docker installed locally on your machine.
+To install docker on your local machine, please refer here: https://docs.docker.com/get-docker/
+
+After docker installation, please follow the steps below to get the docker container up and running:
+
+1. docker pull beir/pyserini-fastapi
+2. docker build -t pyserini-fastapi .
+3. docker run -p 8000:8000 -it --rm pyserini-fastapi
+
+Usage: python evaluate_deepct.py
+"""
+from DeepCT.deepct import run_deepct # git clone https://github.com/NThakur20/DeepCT.git
+
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.generation.models import QGenModel
+from tqdm.autonotebook import trange
+
+import pathlib, os, json
+import logging
+import requests
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+#### 1. Download Google BERT-BASE, Uncased model ####
+# Ref: https://github.com/google-research/bert
+
+base_model_url = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "models")
+bert_base_dir = util.download_and_unzip(base_model_url, out_dir)
+
+#### 2. Download DeepCT MSMARCO Trained BERT checkpoint ####
+# Credits to DeepCT authors: Zhuyun Dai, Jamie Callan, (https://github.com/AdeDZY/DeepCT)
+
+model_url = "http://boston.lti.cs.cmu.edu/appendices/arXiv2019-DeepCT-Zhuyun-Dai/outputs/marco.zip"
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "models")
+checkpoint_dir = util.download_and_unzip(model_url, out_dir)
+
+##################################################
+#### 3. Configure Params for DeepCT inference ####
+##################################################
+# We cannot use the original repo (https://github.com/AdeDZY/DeepCT) as it only runs with TF 1.15.
+# We reformatted the code (https://github.com/NThakur20/DeepCT) and made it work with the latest TF 2.x!
+
+if not os.path.isfile(os.path.join(data_path, "deepct.jsonl")):
+ ################################
+    #### Command-Line Arguments ####
+ ################################
+    run_deepct.FLAGS.task_name = "beir" # We defined a separate BEIR task in DeepCT. Check out run_deepct.
+ run_deepct.FLAGS.do_train = False # We only want to use the code for inference.
+ run_deepct.FLAGS.do_eval = False # No evaluation.
+ run_deepct.FLAGS.do_predict = True # True, as we would use DeepCT model for only prediction.
+ run_deepct.FLAGS.data_dir = os.path.join(data_path, "corpus.jsonl") # Provide original path to corpus data, follow beir format.
+ run_deepct.FLAGS.vocab_file = os.path.join(bert_base_dir, "vocab.txt") # Provide bert-base-uncased model vocabulary.
+ run_deepct.FLAGS.bert_config_file = os.path.join(bert_base_dir, "bert_config.json") # Provide bert-base-uncased config.json file.
+ run_deepct.FLAGS.init_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-65816") # Provide DeepCT MSMARCO model (bert-base-uncased) checkpoint file.
+    run_deepct.FLAGS.max_seq_length = 350 # Max sequence length to consider per document (max: 512).
+    run_deepct.FLAGS.train_batch_size = 128 # Inference batch size; larger needs more GPU memory but runs faster!
+ run_deepct.FLAGS.output_dir = data_path # Output directory, this will contain two files: deepct.jsonl (output-file) and predict.tf_record
+ run_deepct.FLAGS.output_file = "deepct.jsonl" # Output file for storing final DeepCT produced corpus.
+ run_deepct.FLAGS.m = 100 # Scaling parameter for DeepCT weights: scaling parameter > 0, recommend 100
+ run_deepct.FLAGS.smoothing = "sqrt" # Use sqrt to smooth weights. DeepCT Paper uses None.
+ run_deepct.FLAGS.keep_all_terms = True # Do not allow DeepCT to delete terms.
+
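+    # (Roughly, as we understand the DeepCT output written to deepct.jsonl: each predicted term weight in [0, 1]
+    #  is smoothed (sqrt) and scaled by m into an integer pseudo term-frequency, which the BM25 index then uses.)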
+ # Runs DeepCT model on the corpus.jsonl
+ run_deepct.main()
+
+#### Download Docker Image beir/pyserini-fastapi ####
+#### Locally run the docker Image + FastAPI ####
+docker_beir_pyserini = "http://127.0.0.1:8000"
+
+#### Upload Multipart-encoded files ####
+with open(os.path.join(data_path, "deepct.jsonl"), "rb") as fIn:
+ r = requests.post(docker_beir_pyserini + "/upload/", files={"file": fIn}, verify=False)
+
+#### Index documents to Pyserini #####
+index_name = "beir/you-index-name" # beir/scifact
+r = requests.get(docker_beir_pyserini + "/index/", params={"index_name": index_name})
+
+######################################
+#### 4. Pyserini-Retrieval (BM25) ####
+######################################
+
+#### Retrieve documents from Pyserini #####
+retriever = EvaluateRetrieval()
+qids = list(queries)
+query_texts = [queries[qid] for qid in qids]
+payload = {"queries": query_texts, "qids": qids, "k": max(retriever.k_values),
+ "fields": {"contents": 1.0}, "bm25": {"k1": 18, "b": 0.7}}
+
+#### Retrieve pyserini results (format of results is identical to qrels)
+results = json.loads(requests.post(docker_beir_pyserini + "/lexical/batch_search/", json=payload).text)["results"]
+
+#### Retrieve RM3 expanded pyserini results (format of results is identical to qrels)
+# results = json.loads(requests.post(docker_beir_pyserini + "/lexical/rm3/batch_search/", json=payload).text)["results"]
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Retrieval Example ####
+query_id, scores_dict = random.choice(list(results.items()))
+logging.info("Query : %s\n" % queries[query_id])
+
+scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
+for rank in range(10):
+ doc_id = scores[rank][0]
+ logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
diff --git a/examples/retrieval/evaluation/sparse/evaluate_sparta.py b/examples/retrieval/evaluation/sparse/evaluate_sparta.py
new file mode 100644
index 0000000..f01c0f3
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_sparta.py
@@ -0,0 +1,56 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.sparse import SparseSearch
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "scifact"
+
+#### Download scifact dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where scifact has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) scifact/corpus.jsonl (format: jsonlines)
+# (2) scifact/queries.jsonl (format: jsonlines)
+# (3) scifact/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Sparse Retrieval using SPARTA ####
+model_path = "BeIR/sparta-msmarco-distilbert-base-v1"
+sparse_model = SparseSearch(models.SPARTA(model_path), batch_size=128)
+retriever = EvaluateRetrieval(sparse_model)
+
+#### Retrieve sparse results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/sparse/evaluate_splade.py b/examples/retrieval/evaluation/sparse/evaluate_splade.py
new file mode 100644
index 0000000..ad119d8
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_splade.py
@@ -0,0 +1,68 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download NFCorpus dataset and unzip the dataset
+dataset = "nfcorpus"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### SPARSE Retrieval using SPLADE ####
+# The SPLADE model provides a weight for each query token and document token
+# The final score is taken using a dot-product between the weights of the common tokens.
+# To learn more, please refer to the link below:
+# https://europe.naverlabs.com/blog/splade-a-sparse-bi-encoder-bert-based-model-achieves-effective-and-efficient-first-stage-ranking/
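+
+# A minimal sketch of that scoring scheme (illustrative only; not used by the pipeline below).
+# It assumes the query and document weights are given as {token: weight} dicts:
+def splade_style_score(query_weights, doc_weights):
+    # Dot product over the tokens shared by both sparse representations.
+    return sum(weight * doc_weights.get(token, 0.0) for token, weight in query_weights.items())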
+
+#################################################
+#### 1. Loading SPLADE model from NAVER LABS ####
+#################################################
+# Sadly, the model weights from SPLADE are not on huggingface etc.
+# The SPLADE v1 model weights are available on their original repo: (https://github.com/naver/splade)
+
+# First clone SPLADE GitHub repo: git clone https://github.com/naver/splade.git
+# NOTE: this version only works for max agg in SPLADE!
+
+model_path = "splade/weights/distilsplade_max"
+model = DRES(models.SPLADE(model_path), batch_size=128)
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve dense results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/evaluation/sparse/evaluate_unicoil.py b/examples/retrieval/evaluation/sparse/evaluate_unicoil.py
new file mode 100644
index 0000000..168f2c0
--- /dev/null
+++ b/examples/retrieval/evaluation/sparse/evaluate_unicoil.py
@@ -0,0 +1,66 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.sparse import SparseSearch
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+dataset = "nfcorpus"
+
+#### Download NFCorpus dataset and unzip the dataset
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
+# data folder would contain these files:
+# (1) nfcorpus/corpus.jsonl (format: jsonlines)
+# (2) nfcorpus/queries.jsonl (format: jsonlines)
+# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### SPARSE Retrieval using uniCOIL ####
+# uniCOIL implements an architecture similar to COIL and SPLADE.
+# It computes a weight for each token in the query and in the document.
+# Finally, a dot product over matching query and document tokens is used as the similarity score.
+
+####################################################
+#### 1. Loading uniCOIL model using HuggingFace ####
+####################################################
+# We download the publicly available uniCOIL model from the HF repository
+# For more details on how the model works, please refer: (https://arxiv.org/abs/2106.14807)
+
+model_path = "castorini/unicoil-d2q-msmarco-passage"
+model = SparseSearch(models.UniCOIL(model_path=model_path), batch_size=32)
+retriever = EvaluateRetrieval(model, score_function="dot")
+
+#### Retrieve sparse results (format of results is identical to qrels)
+results = retriever.retrieve(corpus, queries, query_weights=True)
+
+#### Evaluate your retrieval using NDCG@k, MAP@K ...
+
+logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
+ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
+
+#### Print top-k documents retrieved ####
+top_k = 10
+
+query_id, ranking_scores = random.choice(list(results.items()))
+scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
+logging.info("Query : %s\n" % queries[query_id])
+
+for rank in range(top_k):
+ doc_id = scores_sorted[rank][0]
+ # Format: Rank x: ID [Title] Body
+ logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
\ No newline at end of file
diff --git a/examples/retrieval/training/train_msmarco_v2.py b/examples/retrieval/training/train_msmarco_v2.py
new file mode 100644
index 0000000..a8f1499
--- /dev/null
+++ b/examples/retrieval/training/train_msmarco_v2.py
@@ -0,0 +1,107 @@
+'''
+This example shows how to train a Bi-Encoder for the MS MARCO dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
+The model is trained with BM25-sampled (lexical-only) hard negatives provided by the SentenceTransformers repo.
+
+This example has been taken from here with a few modifications to train SBERT (MSMARCO-v2) models:
+(https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder-v2.py)
+
+The queries and passages are passed independently to the transformer network to produce fixed sized embeddings.
+These embeddings can then be compared using cosine-similarity to find matching passages for a given query.
+
+For training, we use MultipleNegativesRankingLoss. There, we pass triplets in the format:
+(query, positive_passage, negative_passage)
+
+Negative passages are hard negative examples that were retrieved by lexical search. We use the negative
+passages (the triplets) that are provided by the MS MARCO dataset.
+
+Running this script:
+python train_msmarco_v2.py
+'''
+
+from sentence_transformers import SentenceTransformer, models, losses
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.train import TrainRetriever
+import pathlib, os, gzip
+import logging
+import json
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download msmarco.zip dataset and unzip the dataset
+dataset = "msmarco"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Please note: not all datasets contain a dev split; comment out the line below if that is the case
+dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
+
+########################################
+#### Download MSMARCO Triplets File ####
+########################################
+
+train_batch_size = 75 # Increasing the train batch size improves the model performance, but requires more GPU memory (O(n))
+max_seq_length = 350 # Max length for passages. Increasing it requires more GPU memory (O(n^2))
+
+# The triplets file contains 5,028,051 sentence pairs (ref: https://sbert.net/datasets/paraphrases)
+triplets_url = "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/msmarco-query_passage_negative.jsonl.gz"
+msmarco_triplets_filepath = os.path.join(data_path, "msmarco-triplets.jsonl.gz")
+
+if not os.path.isfile(msmarco_triplets_filepath):
+ util.download_url(triplets_url, msmarco_triplets_filepath)
+
+#### Each line of the triplets file contains a JSON-encoded triplet =>
+# 1. train query (text), 2. positive doc (text), 3. hard negative doc (text)
+triplets = []
+with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
+ for line in fIn:
+ triplet = json.loads(line)
+ triplets.append(triplet)
+
+#### Provide any sentence-transformers or HF model
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Provide a high batch-size to train better with triplets!
+retriever = TrainRetriever(model=model, batch_size=train_batch_size)
+
+#### Prepare triplets samples
+train_samples = retriever.load_train_triplets(triplets=triplets)
+train_dataloader = retriever.prepare_train_triplets(train_samples)
+
+#### Training SBERT with cosine-similarity (default)
+train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+#### Training SBERT with dot-product
+# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
+
+#### Prepare dev evaluator
+ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
+
+#### If no dev set is present from above use dummy evaluator
+# ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v2-{}".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 1
+evaluation_steps = 10000
+warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1) # warm up over the first 10% of the training steps
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
diff --git a/examples/retrieval/training/train_msmarco_v3.py b/examples/retrieval/training/train_msmarco_v3.py
new file mode 100644
index 0000000..05c49f7
--- /dev/null
+++ b/examples/retrieval/training/train_msmarco_v3.py
@@ -0,0 +1,171 @@
+'''
+This example shows how to train a SOTA Bi-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
+The model is trained using hard negatives which were specially mined with different dense and lexical search methods for MSMARCO.
+
+This example has been taken from here with a few modifications to train SBERT (MSMARCO-v3) models:
+(https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder-v3.py)
+
+The queries and passages are passed independently to the transformer network to produce fixed sized embeddings.
+These embeddings can then be compared using cosine-similarity to find matching passages for a given query.
+
+For training, we use MultipleNegativesRankingLoss. There, we pass triplets in the format:
+(query, positive_passage, negative_passage)
+
+Negative passages are hard negative examples that were mined using different dense embedding methods and lexical search methods.
+Each positive and negative passage comes with a score from a Cross-Encoder. This allows denoising, i.e. removing false negative
+passages that are actually relevant for the query.
+
+Running this script:
+python train_msmarco_v3.py
+'''
+
+from sentence_transformers import SentenceTransformer, models, losses, InputExample
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.train import TrainRetriever
+from torch.utils.data import Dataset
+from tqdm.autonotebook import tqdm
+import pathlib, os, gzip, json
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download msmarco.zip dataset and unzip the dataset
+dataset = "msmarco"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+### Load the BEIR MSMARCO training dataset; it provides the reference queries and corpus.
+corpus, queries, _ = GenericDataLoader(data_path).load(split="train")
+
+#################################
+#### Parameters for Training ####
+#################################
+
+train_batch_size = 75 # Increasing the train batch size improves the model performance, but requires more GPU memory (O(n))
+max_seq_length = 350 # Max length for passages. Increasing it requires more GPU memory (O(n^2))
+ce_score_margin = 3 # Margin for the CrossEncoder score between negative and positive passages
+num_negs_per_system = 5 # We used different systems to mine hard negatives. Number of hard negatives to add from each system
+
+##################################################
+#### Download MSMARCO Hard Negs Triplets File ####
+##################################################
+
+triplets_url = "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz"
+msmarco_triplets_filepath = os.path.join(data_path, "msmarco-hard-negatives.jsonl.gz")
+if not os.path.isfile(msmarco_triplets_filepath):
+ util.download_url(triplets_url, msmarco_triplets_filepath)
+
+#### Load the hard negative MSMARCO jsonl triplets from SBERT
+#### These contain a ce-score which denotes the cross-encoder score for the query and passage.
+#### We apply a margin below the lowest positive passage score => negatives scoring below that threshold are treated as hard negatives.
+#### Finally, to limit the number of negatives per query, we keep at most num_negs_per_system negatives from each system.
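+#### (Worked example with made-up numbers: if the least-relevant positive passage has ce-score 8.2 and ce_score_margin = 3,
+#### then ce_score_threshold = 5.2, and only negatives with ce-score <= 5.2 are kept, at most num_negs_per_system from each system.)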
+
+logging.info("Loading MSMARCO hard-negatives...")
+
+train_queries = {}
+with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
+ for line in tqdm(fIn, total=502939):
+ data = json.loads(line)
+
+ #Get the positive passage ids
+ pos_pids = [item['pid'] for item in data['pos']]
+ pos_min_ce_score = min([item['ce-score'] for item in data['pos']])
+ ce_score_threshold = pos_min_ce_score - ce_score_margin
+
+ #Get the hard negatives
+ neg_pids = set()
+ for system_negs in data['neg'].values():
+ negs_added = 0
+ for item in system_negs:
+ if item['ce-score'] > ce_score_threshold:
+ continue
+
+ pid = item['pid']
+ if pid not in neg_pids:
+ neg_pids.add(pid)
+ negs_added += 1
+ if negs_added >= num_negs_per_system:
+ break
+
+ if len(pos_pids) > 0 and len(neg_pids) > 0:
+ train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'hard_neg': list(neg_pids)}
+
+logging.info("Train queries: {}".format(len(train_queries)))
+
+# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
+# on-the-fly based on the information from the mined-hard-negatives jsonl file.
+
+class MSMARCODataset(Dataset):
+ def __init__(self, queries, corpus):
+ self.queries = queries
+ self.queries_ids = list(queries.keys())
+ self.corpus = corpus
+
+ for qid in self.queries:
+ self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
+ self.queries[qid]['hard_neg'] = list(self.queries[qid]['hard_neg'])
+ random.shuffle(self.queries[qid]['hard_neg'])
+
+ def __getitem__(self, item):
+ query = self.queries[self.queries_ids[item]]
+ query_text = query['query']
+
+ pos_id = query['pos'].pop(0) #Pop positive and add at end
+ pos_text = self.corpus[pos_id]["text"]
+ query['pos'].append(pos_id)
+
+ neg_id = query['hard_neg'].pop(0) #Pop negative and add at end
+ neg_text = self.corpus[neg_id]["text"]
+ query['hard_neg'].append(neg_id)
+
+ return InputExample(texts=[query_text, pos_text, neg_text])
+
+ def __len__(self):
+ return len(self.queries)
+
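+# For illustration: indexing an instance of this dataset (e.g. train_dataset[0], created below) returns an
+# InputExample with texts=[query_text, pos_text, neg_text]; positives and hard negatives are rotated
+# (pop + append), so repeated passes over the same query cycle through them.
+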
+# We construct the SentenceTransformer bi-encoder from scratch with mean-pooling
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Provide a high batch-size to train better with triplets!
+retriever = TrainRetriever(model=model, batch_size=train_batch_size)
+
+# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
+train_dataset = MSMARCODataset(train_queries, corpus=corpus)
+train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
+
+#### Training SBERT with cosine-similarity (default)
+train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+
+#### training SBERT with dot-product
+# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score, scale=1)
+
+#### If no dev set is present from above use dummy evaluator
+ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v3-{}".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 10
+evaluation_steps = 10000
+warmup_steps = 1000
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/examples/retrieval/training/train_msmarco_v3_bpr.py b/examples/retrieval/training/train_msmarco_v3_bpr.py
new file mode 100644
index 0000000..67f042a
--- /dev/null
+++ b/examples/retrieval/training/train_msmarco_v3_bpr.py
@@ -0,0 +1,174 @@
+'''
+This example shows how to train a Binary-Code (Binary Passage Retriever) based Bi-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
+The model is trained using hard negatives which were specially mined with different dense and lexical search methods for MSMARCO.
+
+The idea for the Binary Passage Retriever originates from Yamada et al., 2021: "Efficient Passage Retrieval with Hashing for Open-domain Question Answering".
+For more details, please refer here: https://arxiv.org/abs/2106.00882
+
+This example has been taken from here with a few modifications to train SBERT (MSMARCO-v3) models:
+(https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder-v3.py)
+
+The queries and passages are passed independently to the transformer network to produce fixed-sized binary codes (hashes).
+These codes can then be compared using the Hamming distance to find matching passages for a given query.
+
+For training, we use BPRLoss (MarginRankingLoss + MultipleNegativesRankingLoss). There, we pass triplets in the format:
+(query, positive_passage, negative_passage)
+
+Negative passages are hard negative examples that were mined using different dense embedding methods and lexical search methods.
+Each positive and negative passage comes with a score from a Cross-Encoder. This allows denoising, i.e. removing false negative
+passages that are actually relevant for the query.
+
+Running this script:
+python train_msmarco_v3_bpr.py
+'''
+
+from sentence_transformers import SentenceTransformer, models, InputExample
+from beir import util, LoggingHandler
+from beir.losses import BPRLoss
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.train import TrainRetriever
+from torch.utils.data import Dataset
+from tqdm.autonotebook import tqdm
+import pathlib, os, gzip, json
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download msmarco.zip dataset and unzip the dataset
+dataset = "msmarco"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+### Load the BEIR MSMARCO training dataset; it provides the reference queries and corpus.
+corpus, queries, _ = GenericDataLoader(data_path).load(split="train")
+
+#################################
+#### Parameters for Training ####
+#################################
+
+train_batch_size = 75 # Increasing the train batch size improves the model performance, but requires more GPU memory (O(n))
+max_seq_length = 350 # Max length for passages. Increasing it requires more GPU memory (O(n^2))
+ce_score_margin = 3 # Margin for the CrossEncoder score between negative and positive passages
+num_negs_per_system = 5 # We used different systems to mine hard negatives. Number of hard negatives to add from each system
+
+##################################################
+#### Download MSMARCO Hard Negs Triplets File ####
+##################################################
+
+triplets_url = "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz"
+msmarco_triplets_filepath = os.path.join(data_path, "msmarco-hard-negatives.jsonl.gz")
+if not os.path.isfile(msmarco_triplets_filepath):
+ util.download_url(triplets_url, msmarco_triplets_filepath)
+
+#### Load the hard negative MSMARCO jsonl triplets from SBERT
+#### These contain a ce-score which denotes the cross-encoder score for the query and passage.
+#### We apply a margin below the lowest positive passage score => negatives scoring below that threshold are treated as hard negatives.
+#### Finally, to limit the number of negatives per query, we keep at most num_negs_per_system negatives from each system.
+
+logging.info("Loading MSMARCO hard-negatives...")
+
+train_queries = {}
+with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
+ for line in tqdm(fIn, total=502939):
+ data = json.loads(line)
+
+ #Get the positive passage ids
+ pos_pids = [item['pid'] for item in data['pos']]
+ pos_min_ce_score = min([item['ce-score'] for item in data['pos']])
+ ce_score_threshold = pos_min_ce_score - ce_score_margin
+
+ #Get the hard negatives
+ neg_pids = set()
+ for system_negs in data['neg'].values():
+ negs_added = 0
+ for item in system_negs:
+ if item['ce-score'] > ce_score_threshold:
+ continue
+
+ pid = item['pid']
+ if pid not in neg_pids:
+ neg_pids.add(pid)
+ negs_added += 1
+ if negs_added >= num_negs_per_system:
+ break
+
+ if len(pos_pids) > 0 and len(neg_pids) > 0:
+ train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'hard_neg': list(neg_pids)}
+
+logging.info("Train queries: {}".format(len(train_queries)))
+
+# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
+# on-the-fly based on the information from the mined-hard-negatives jsonl file.
+
+class MSMARCODataset(Dataset):
+ def __init__(self, queries, corpus):
+ self.queries = queries
+ self.queries_ids = list(queries.keys())
+ self.corpus = corpus
+
+ for qid in self.queries:
+ self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
+ self.queries[qid]['hard_neg'] = list(self.queries[qid]['hard_neg'])
+ random.shuffle(self.queries[qid]['hard_neg'])
+
+ def __getitem__(self, item):
+ query = self.queries[self.queries_ids[item]]
+ query_text = query['query']
+
+ pos_id = query['pos'].pop(0) #Pop positive and add at end
+ pos_text = self.corpus[pos_id]["text"]
+ query['pos'].append(pos_id)
+
+ neg_id = query['hard_neg'].pop(0) #Pop negative and add at end
+ neg_text = self.corpus[neg_id]["text"]
+ query['hard_neg'].append(neg_id)
+
+ return InputExample(texts=[query_text, pos_text, neg_text])
+
+ def __len__(self):
+ return len(self.queries)
+
+# We construct the SentenceTransformer bi-encoder from scratch with CLS token Pooling
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+ pooling_mode_cls_token=True,
+ pooling_mode_mean_tokens=False)
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Provide a high batch-size to train better with triplets!
+retriever = TrainRetriever(model=model, batch_size=train_batch_size)
+
+# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
+train_dataset = MSMARCODataset(train_queries, corpus=corpus)
+train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
+
+#### Training SBERT with dot-product (default)
+train_loss = BPRLoss(model=retriever.model)
+
+#### If no dev set is present, use a dummy evaluator
+ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v3-{}".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 10
+evaluation_steps = 10000
+warmup_steps = 1000
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/examples/retrieval/training/train_msmarco_v3_margin_MSE.py b/examples/retrieval/training/train_msmarco_v3_margin_MSE.py
new file mode 100644
index 0000000..7784304
--- /dev/null
+++ b/examples/retrieval/training/train_msmarco_v3_margin_MSE.py
@@ -0,0 +1,170 @@
+'''
+This example shows how to train a SOTA Bi-Encoder with Margin-MSE loss for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
+
+In this example we use a knowledge distillation setup: Sebastian Hofstätter et al. (https://arxiv.org/abs/2010.02666) trained
+an ensemble of large Transformer models for the MS MARCO dataset and combined the scores from a BERT-base, a BERT-large, and an ALBERT-large model.
+
+We use the MSMARCO Hard Negatives File (Provided by Nils Reimers): https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz
+Negative passages are hard negative examples that were mined using different dense-embedding, cross-encoder, and lexical search methods.
+The file contains up to 50 negatives from each of the four retrieval systems: [bm25, msmarco-distilbert-base-tas-b, msmarco-MiniLM-L-6-v3, msmarco-distilbert-base-v3]
+Each positive and negative passage comes with a score from a Cross-Encoder (msmarco-MiniLM-L-6-v3). This allows denoising, i.e. removing false negative
+passages that are actually relevant for the query.
+
+This example has been taken from here with a few modifications to train SBERT (MSMARCO-v3) models:
+(https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder-v3.py)
+
+The queries and passages are passed independently to the transformer network to produce fixed-size embeddings.
+These embeddings can then be compared using dot-product to find matching passages for a given query.
+
+For training, we use Margin MSE Loss. There, we pass triplets in the format:
+triplet: (query, positive_passage, negative_passage)
+label: positive_ce_score - negative_ce_score, i.e. the difference between the cross-encoder scores of the positive and the negative passage for the query.
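+For example (illustrative scores, not taken from the dataset): if ce(query, positive_passage) = 8.2 and
+ce(query, negative_passage) = 3.1, the training label is 8.2 - 3.1 = 5.1, and the bi-encoder is trained so
+that the difference of its dot-product scores for the two passages approximates this margin.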
+
+PS: Margin MSE Loss requires neither a score threshold nor a maximum number of negatives per system
+(both of which are needed for MultipleNegativesRankingLoss), so we avoid the cumbersome search for an optimal threshold.
+
+Running this script:
+python train_msmarco_v3_margin_MSE.py
+'''
+
+from sentence_transformers import SentenceTransformer, models, InputExample
+from beir import util, LoggingHandler, losses
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.train import TrainRetriever
+from torch.utils.data import Dataset
+from tqdm.autonotebook import tqdm
+import pathlib, os, gzip, json
+import logging
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download msmarco.zip dataset and unzip the dataset
+dataset = "msmarco"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Load the BEIR MSMARCO training dataset; it provides the queries and corpus used below.
+corpus, queries, _ = GenericDataLoader(data_path).load(split="train")
+
+#################################
+#### Parameters for Training ####
+#################################
+
+train_batch_size = 75 # Increasing the train batch size improves model performance, but requires more GPU memory (O(n))
+max_seq_length = 350 # Max sequence length for passages. Increasing it requires more GPU memory (O(n^2))
+
+##################################################
+#### Download MSMARCO Hard Negs Triplets File ####
+##################################################
+
+triplets_url = "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz"
+msmarco_triplets_filepath = os.path.join(data_path, "msmarco-hard-negatives.jsonl.gz")
+if not os.path.isfile(msmarco_triplets_filepath):
+ util.download_url(triplets_url, msmarco_triplets_filepath)
+
+#### Load the hard negative MSMARCO jsonl triplets from SBERT
+#### These contain a ce-score which denotes the cross-encoder score for the query and passage.
+
+logging.info("Loading MSMARCO hard-negatives...")
+
+train_queries = {}
+with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
+ for line in tqdm(fIn, total=502939):
+ data = json.loads(line)
+
+ #Get the positive passage ids
+ pos_pids = [item['pid'] for item in data['pos']]
+ pos_scores = dict(zip(pos_pids, [item['ce-score'] for item in data['pos']]))
+
+ #Get all the hard negatives
+ neg_pids = set()
+ neg_scores = {}
+ for system_negs in data['neg'].values():
+ for item in system_negs:
+ pid = item['pid']
+ score = item['ce-score']
+ if pid not in neg_pids:
+ neg_pids.add(pid)
+ neg_scores[pid] = score
+
+ if len(pos_pids) > 0 and len(neg_pids) > 0:
+ train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'pos_scores': pos_scores,
+ 'hard_neg': neg_pids, 'hard_neg_scores': neg_scores}
+
+logging.info("Train queries: {}".format(len(train_queries)))
+
+# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
+# on-the-fly based on the information from the mined-hard-negatives jsonl file.
+
+class MSMARCODataset(Dataset):
+ def __init__(self, queries, corpus):
+ self.queries = queries
+ self.queries_ids = list(queries.keys())
+ self.corpus = corpus
+
+ for qid in self.queries:
+ self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
+ self.queries[qid]['hard_neg'] = list(self.queries[qid]['hard_neg'])
+ random.shuffle(self.queries[qid]['hard_neg'])
+
+ def __getitem__(self, item):
+ query = self.queries[self.queries_ids[item]]
+ query_text = query['query']
+
+ pos_id = query['pos'].pop(0) #Pop positive and add at end
+ pos_text = self.corpus[pos_id]["text"]
+ query['pos'].append(pos_id)
+ pos_score = float(query['pos_scores'][pos_id])
+
+ neg_id = query['hard_neg'].pop(0) #Pop negative and add at end
+ neg_text = self.corpus[neg_id]["text"]
+ query['hard_neg'].append(neg_id)
+ neg_score = float(query['hard_neg_scores'][neg_id])
+
+ return InputExample(texts=[query_text, pos_text, neg_text], label=(pos_score - neg_score))
+
+ def __len__(self):
+ return len(self.queries)
+
+# We construct the SentenceTransformer bi-encoder from scratch with mean-pooling
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Provide a high batch-size to train better with triplets!
+retriever = TrainRetriever(model=model, batch_size=train_batch_size)
+
+# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
+train_dataset = MSMARCODataset(train_queries, corpus=corpus)
+train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
+
+#### Training SBERT with dot-product (default) using Margin MSE Loss
+train_loss = losses.MarginMSELoss(model=retriever.model)
+
+#### If no dev set is present, use a dummy evaluator
+ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v3-margin-MSE-loss-{}".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 11
+evaluation_steps = 10000
+warmup_steps = 1000
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/examples/retrieval/training/train_sbert.py b/examples/retrieval/training/train_sbert.py
new file mode 100644
index 0000000..375b5a9
--- /dev/null
+++ b/examples/retrieval/training/train_sbert.py
@@ -0,0 +1,83 @@
+'''
+This example shows how to train a basic Bi-Encoder for any BEIR dataset without any mined hard negatives or triplets.
+
+The queries and passages are passed independently to the transformer network to produce fixed-size embeddings.
+These embeddings can then be compared using cosine-similarity to find matching passages for a given query.
+
+For training, we use MultipleNegativesRankingLoss. There, we pass pairs in the format:
+(query, positive_passage). The positive passages of the other queries in the same batch serve as in-batch negatives.
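+For example (illustrative pairs): with an in-batch set {(q1, p1), (q2, p2), (q3, p3)}, the loss for q1 treats
+p1 as the positive and p2, p3 as negatives, so a larger batch provides more in-batch negatives per query.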
+
+We do not mine hard negatives or train triplets in this example.
+
+Running this script:
+python train_sbert.py
+'''
+
+from sentence_transformers import losses, models, SentenceTransformer
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.train import TrainRetriever
+import pathlib, os
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download nfcorpus.zip dataset and unzip the dataset
+dataset = "nfcorpus"
+
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where nfcorpus has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")
+#### Please note: not all datasets contain a dev split; comment out the line below if that is the case
+dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
+
+#### Provide any sentence-transformers or HF model
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=350)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Or provide pretrained sentence-transformer model
+# model = SentenceTransformer("msmarco-distilbert-base-v3")
+
+retriever = TrainRetriever(model=model, batch_size=16)
+
+#### Prepare training samples
+train_samples = retriever.load_train(corpus, queries, qrels)
+train_dataloader = retriever.prepare_train(train_samples, shuffle=True)
+
+#### Training SBERT with cosine-similarity (default)
+train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+#### training SBERT with dot-product
+# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
+
+#### Prepare dev evaluator
+ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
+
+#### If no dev set is present above, use a dummy evaluator instead
+# ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v1-{}".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 1
+evaluation_steps = 5000
+warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)
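+# Illustrative example (hypothetical numbers): with 3,200 training samples, 1 epoch, and batch size 16,
+# warmup_steps = int(3200 * 1 / 16 * 0.1) = 20, i.e. roughly 10% of the total training steps.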
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/examples/retrieval/training/train_sbert_BM25_hardnegs.py b/examples/retrieval/training/train_sbert_BM25_hardnegs.py
new file mode 100644
index 0000000..c7006a4
--- /dev/null
+++ b/examples/retrieval/training/train_sbert_BM25_hardnegs.py
@@ -0,0 +1,129 @@
+'''
+This example shows how to train a Bi-Encoder for any BEIR dataset.
+
+The queries and passages are passed independently to the transformer network to produce fixed-size embeddings.
+These embeddings can then be compared using cosine-similarity to find matching passages for a given query.
+
+For training, we use MultipleNegativesRankingLoss. There, we pass triplets in the format:
+(query, positive_passage, negative_passage)
+
+Negative passages are hard negative examples that were retrieved by lexical search. We use Elasticsearch
+to get up to 10 hard negative examples given a positive passage.
+
+Running this script:
+python train_sbert_BM25_hardnegs.py
+'''
+
+from sentence_transformers import losses, models, SentenceTransformer
+from beir import util, LoggingHandler
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.search.lexical import BM25Search as BM25
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.train import TrainRetriever
+import pathlib, os, tqdm
+import logging
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where scifact has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")
+
+# #### Please note: not all datasets contain a dev split; uncomment the line below if a dev split is available
+# dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
+
+#### Lexical Retrieval using BM25 (Elasticsearch) ####
+#### Provide a hostname (localhost) to connect to ES instance
+#### Define a new index name or use an already existing one.
+#### We use default ES settings for retrieval
+#### https://www.elastic.co/
+
+hostname = "your-hostname" #localhost
+index_name = "your-index-name" # scifact
+
+#### Initialize ####
+# (1) True - Delete existing index and re-index all documents from scratch
+# (2) False - Load existing index
+initialize = True # False
+
+#### Sharding ####
+# (1) For datasets with a small corpus (roughly < 5k docs) => limit shards = 1
+# SciFact is a relatively small dataset! (limit shards to 1)
+number_of_shards = 1
+model = BM25(index_name=index_name, hostname=hostname, initialize=initialize, number_of_shards=number_of_shards)
+
+# (2) For datasets with big corpus ==> keep default configuration
+# model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
+bm25 = EvaluateRetrieval(model)
+
+#### Index passages into the index (separately)
+bm25.retriever.index(corpus)
+
+triplets = []
+qids = list(qrels)
+hard_negatives_max = 10
+
+#### Retrieve BM25 hard negatives => Given a positive document, find the most lexically similar documents
+for idx in tqdm.tqdm(range(len(qids)), desc="Retrieve Hard Negatives using BM25"):
+ query_id, query_text = qids[idx], queries[qids[idx]]
+ pos_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+ pos_doc_texts = [corpus[doc_id]["title"] + " " + corpus[doc_id]["text"] for doc_id in pos_docs]
+ hits = bm25.retriever.es.lexical_multisearch(texts=pos_doc_texts, top_hits=hard_negatives_max+1)
+ for (pos_text, hit) in zip(pos_doc_texts, hits):
+ for (neg_id, _) in hit.get("hits"):
+ if neg_id not in pos_docs:
+ neg_text = corpus[neg_id]["title"] + " " + corpus[neg_id]["text"]
+ triplets.append([query_text, pos_text, neg_text])
+
+#### Provide any sentence-transformers or HF model
+model_name = "distilbert-base-uncased"
+word_embedding_model = models.Transformer(model_name, max_seq_length=300)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+#### Provide a high batch-size to train better with triplets!
+retriever = TrainRetriever(model=model, batch_size=32)
+
+#### Prepare triplets samples
+train_samples = retriever.load_train_triplets(triplets=triplets)
+train_dataloader = retriever.prepare_train_triplets(train_samples)
+
+#### Training SBERT with cosine-similarity (default)
+train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
+
+#### training SBERT with dot-product
+# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
+
+#### Prepare dev evaluator
+# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
+
+#### If no dev set is present, use a dummy evaluator
+ir_evaluator = retriever.load_dummy_evaluator()
+
+#### Provide model save path
+model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v2-{}-bm25-hard-negs".format(model_name, dataset))
+os.makedirs(model_save_path, exist_ok=True)
+
+#### Configure Train params
+num_epochs = 1
+evaluation_steps = 10000
+warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)
+
+retriever.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=ir_evaluator,
+ epochs=num_epochs,
+ output_path=model_save_path,
+ warmup_steps=warmup_steps,
+ evaluation_steps=evaluation_steps,
+ use_amp=True)
\ No newline at end of file
diff --git a/explore.py b/explore.py
new file mode 100644
index 0000000..056a98e
--- /dev/null
+++ b/explore.py
@@ -0,0 +1,224 @@
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+
+import logging
+import pathlib, os
+import os, sys
+
+def intersect_res(results):
+ # results maps each sub-query id to a dict of {doc_id: score}.
+ # Keep only the documents that were returned for every sub-query.
+ intersect_keys = set(results[list(results.keys())[0]].keys())
+ for idx in range(len(list(results.keys()))-1):
+ intersect_keys = intersect_keys.intersection(set(results[list(results.keys())[idx + 1]].keys()))
+
+ # For the shared documents, combine the per-sub-query scores by multiplying the
+ # scores normalised to [0, 1] (score/100) and rescaling the product back by 100.
+ intersect_res = {}
+ for key in intersect_keys:
+ score = 1
+ for idx in range(len(list(results.keys()))):
+ score = score*results[list(results.keys())[idx]][key]/100
+ intersect_res[key] = score*100
+
+ return intersect_res
+
+
+
+def compare_and_not_query(retriever, corpus, queries, qrels):
+
+ results = retriever.retrieve(corpus, {"8": queries["8"]})
+
+ print("length of the results::", len(results))
+
+ print("results without decomposition::")
+
+ ndcg, _map, recall, precision = retriever.evaluate({"8": qrels["8"]}, results, retriever.k_values)
+
+ new_results = retriever.retrieve(corpus, {"8": ["lack of testing availability", "underreporting of true incidence of Covid-19"]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts = intersect_res(new_results)
+
+ print("length of the merged results::", len(new_merged_resuts))
+
+ print("results with decomposition::")
+
+ new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"8": qrels["8"]}, {"8": new_merged_resuts}, retriever.k_values)
+
+ new_results2 = retriever.retrieve(corpus, {"8": ["testing availability", "underreporting of true incidence of Covid-19"]}, query_negations={"8":[True,False]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts2 = intersect_res(new_results2)
+
+ print("length of the merged results::", len(new_merged_resuts2))
+
+ print("results with decomposition::")
+
+ new_ndcg2, new_map2, new_recall2, new_precision2 = retriever.evaluate({"8": qrels["8"]}, {"8": new_merged_resuts2}, retriever.k_values)
+
+
+def compare_query_weather(retriever, corpus, queries, qrels):
+ # retriever.top_k = 2500
+
+ results = retriever.retrieve(corpus, {"2": queries["2"], "3": queries["2"]})
+
+ print("length of the results::", len(results))
+
+ print("results without decomposition::")
+
+ ndcg, _map, recall, precision = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, results, retriever.k_values)
+
+ # retriever.top_k = 1000
+
+ new_results = retriever.retrieve(corpus, {"2": ["change of weather", "coronaviruses"]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts = intersect_res(new_results)
+
+ print("length of the merged results::", len(new_merged_resuts))
+
+ print("results with decomposition::")
+
+ new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, {"2": new_merged_resuts, "3":new_merged_resuts}, retriever.k_values)
+
+
+ new_results2 = retriever.retrieve(corpus, {"2": ["weather", "coronaviruses"]})
+
+ new_merged_resuts2 = intersect_res(new_results2)
+
+ print("length of the merged results::", len(new_merged_resuts2))
+
+ print("results with decomposition 2::")
+
+ new_ndcg2, new_map2, new_recall2, new_precision2 = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, {"2": new_merged_resuts2, "3":new_merged_resuts2}, retriever.k_values)
+
+def compare_query_social_distance(retriever, corpus, queries, qrels):
+ # retriever.top_k = 2500
+
+ results = retriever.retrieve(corpus, {"10": queries["10"]})
+
+ print("length of the results::", len(results))
+
+ print("results without decomposition::")
+
+ ndcg, _map, recall, precision = retriever.evaluate({"10": qrels["10"]}, results, retriever.k_values)
+
+ # retriever.top_k = 1000
+
+ new_results = retriever.retrieve(corpus, {"10": ["social distancing", "impact", "slowing", "spread", "COVID-19"]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts = intersect_res(new_results)
+
+ print("length of the merged results::", len(new_merged_resuts))
+
+ print("results with decomposition::")
+
+ new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"10": qrels["10"]}, {"10": new_merged_resuts}, retriever.k_values)
+
+def compare_ace_example(retriever, corpus, queries, qrels):
+ # retriever.top_k = 2500
+
+ results = retriever.retrieve(corpus, {"20": queries["20"]})
+
+ print("query::", queries["20"])
+
+ print("length of the results::", len(results))
+
+ print("results without decomposition::")
+
+ ndcg, _map, recall, precision = retriever.evaluate({"20": qrels["20"]}, results, retriever.k_values)
+
+ # retriever.top_k = 1000
+
+ new_results = retriever.retrieve(corpus, {"20": ["Angiotensin-converting enzyme inhibitors", "increased risk", "COVID-19"]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts = intersect_res(new_results)
+
+ print("length of the merged results::", len(new_merged_resuts))
+
+ print("results with decomposition::")
+
+ new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"20": qrels["20"]}, {"20": new_merged_resuts}, retriever.k_values)
+
+ new_results = retriever.retrieve(corpus, {"20": ["patients", "Angiotensin-converting enzyme inhibitors", "increased risk", "COVID-19"]})
+
+ #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+
+
+ new_merged_resuts = intersect_res(new_results)
+
+ print("length of the merged results 2::", len(new_merged_resuts))
+
+ print("results with decomposition 2::")
+
+ new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"20": qrels["20"]}, {"20": new_merged_resuts}, retriever.k_values)
+ # new_results2 = retriever.retrieve(corpus, {"2": ["weather", "coronaviruses"]})
+
+ # new_merged_resuts2 = intersect_res(new_results2)
+
+ # print("length of the merged results::", len(new_merged_resuts2))
+
+ # print("results with decomposition 2::")
+
+ # new_ndcg2, new_map2, new_recall2, new_precision2 = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, {"2": new_merged_resuts2, "3":new_merged_resuts2}, retriever.k_values)
+
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ level=logging.INFO,
+ handlers=[LoggingHandler()])
+
+# data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/")
+# dataset = "scifact"
+# url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+dataset = "trec-covid"
+url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+# out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "data")
+out_dir = "/data6/wuyinjun/beir/data/"
+data_path = util.download_and_unzip(url, out_dir)
+
+#### Provide the data_path where trec-covid has been downloaded and unzipped
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+
+#### Load the SBERT model and retrieve using cosine-similarity
+model = DRES(models.SentenceBERT("msmarco-distilbert-base-tas-b"), batch_size=16)
+retriever = EvaluateRetrieval(model, score_function="cos_sim") # or "dot" for dot-product
+
+
+# compare_query_social_distance(retriever, corpus, queries, qrels)
+
+compare_ace_example(retriever, corpus, queries, qrels)
+# compare_query_weather(retriever, corpus, queries, qrels)
+# compare_and_not_query(retriever, corpus, queries, qrels)
+
+# results = retriever.retrieve(corpus, {"2": queries["2"], "3": queries["2"]})
+
+# new_results = retriever.retrieve(corpus, {"51": "change of weather", "52":"coronaviruses"})
+
+# #### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
+# print("results without decomposition::")
+
+# ndcg, _map, recall, precision = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, results, retriever.k_values)
+
+# new_merged_resuts = intersect_res(new_results)
+
+# print("results with decomposition::")
+
+# new_ndcg, new_map, new_recall, new_precision = retriever.evaluate({"2": qrels["2"], "3":qrels["2"]}, {"2": new_merged_resuts, "3":new_merged_resuts}, retriever.k_values)
+
+
+print()
\ No newline at end of file
diff --git a/images/HF.png b/images/HF.png
new file mode 100644
index 0000000..96f37f1
Binary files /dev/null and b/images/HF.png differ
diff --git a/images/color_logo.png b/images/color_logo.png
new file mode 100644
index 0000000..bb41d85
Binary files /dev/null and b/images/color_logo.png differ
diff --git a/images/color_logo_transparent_cropped.png b/images/color_logo_transparent_cropped.png
new file mode 100644
index 0000000..c8f9820
Binary files /dev/null and b/images/color_logo_transparent_cropped.png differ
diff --git a/images/tu-darmstadt.png b/images/tu-darmstadt.png
new file mode 100644
index 0000000..04d2809
Binary files /dev/null and b/images/tu-darmstadt.png differ
diff --git a/images/ukp.png b/images/ukp.png
new file mode 100644
index 0000000..7c3a2b4
Binary files /dev/null and b/images/ukp.png differ
diff --git a/images/uwaterloo.png b/images/uwaterloo.png
new file mode 100644
index 0000000..405be05
Binary files /dev/null and b/images/uwaterloo.png differ
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..224a779
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[metadata]
+description-file = README.md
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..a9768a3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,39 @@
+from setuptools import setup, find_packages
+
+with open("README.md", mode="r", encoding="utf-8") as readme_file:
+ readme = readme_file.read()
+
+optional_packages = {
+ "tf" : ['tensorflow>=2.2.0', 'tensorflow-text', 'tensorflow-hub']
+}
+
+setup(
+ name="beir",
+ version="1.0.1",
+ author="Nandan Thakur",
+ author_email="nandant@gmail.com",
+ description="A Heterogeneous Benchmark for Information Retrieval",
+ long_description=readme,
+ long_description_content_type="text/markdown",
+ license="Apache License 2.0",
+ url="https://github.com/beir-cellar/beir",
+ download_url="https://github.com/beir-cellar/beir/archive/v1.0.1.zip",
+ packages=find_packages(),
+ python_requires='>=3.6',
+ install_requires=[
+ 'sentence-transformers',
+ 'pytrec_eval',
+ 'faiss_cpu',
+ 'elasticsearch==7.9.1',
+ 'datasets'
+ ],
+ extras_require = optional_packages,
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3.6",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
+ ],
+ keywords="Information Retrieval Transformer Networks BERT PyTorch IR NLP deep learning"
+)
\ No newline at end of file