AT-RAG: An Adaptive RAG Model Enhancing Query Efficiency with Topic Filtering and Iterative Reasoning
AT-RAG (Adaptive Retrieval-Augmented Generation) is a novel RAG model developed to address the challenges of complex multi-hop queries, which are often problematic for large language models (LLMs) like GPT-4. By incorporating topic filtering and iterative reasoning, AT-RAG significantly improves both retrieval efficiency and reasoning accuracy in question answering (QA).
AT-RAG leverages BERTopic for dynamic topic modeling, which assigns relevant topics to each incoming query, thereby boosting retrieval accuracy and computational efficiency. This model is adept at handling different QA tasks, including both general inquiries and complex, domain-specific scenarios, such as medical QA, by managing intricate multi-step queries effectively.
The figure below provides an overview of the AT-RAG model architecture, showcasing the integration of topic filtering and iterative reasoning for enhanced query efficiency and accuracy:
Across multiple datasets, AT-RAG's average overall score shows a significant improvement over state-of-the-art models such as Adaptive RAG. The graph below compares performance scores, with standard deviations for each model:
For more details, please refer to our paper: *AT-RAG: An Adaptive RAG Model Enhancing Query Efficiency with Topic Filtering and Iterative Reasoning*.
- Topic Filtering: Uses BERTopic to dynamically assign relevant topics to each query, improving retrieval accuracy.
- Iterative Reasoning: Employs multi-step reasoning to answer complex, multi-hop queries.
- Efficiency & Precision: Reduces retrieval time while maintaining high precision, making it suitable for both general and specialized tasks.
- Versatile Use Cases: Demonstrated effectiveness in both standard QA benchmarks and medical QA case studies.
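
To make the flow concrete, here is a minimal sketch of the adaptive loop in Python. This is an illustration of the idea described above, not the repository's implementation; `assign_topic`, `retrieve`, and `reason_step` are hypothetical helpers standing in for the topic-filtering, retrieval, and reasoning components.

```python
# Minimal sketch of the AT-RAG loop (illustrative only).
# assign_topic, retrieve, and reason_step are hypothetical helpers for the
# topic-filtering, retrieval, and chain-of-thought components described above.

def at_rag_answer(question: str, max_iter: int = 5) -> str:
    topic = assign_topic(question)  # topic filtering (e.g., via BERTopic)
    context, thoughts = [], []
    answer = ""
    for _ in range(max_iter):  # iterative multi-step reasoning
        docs = retrieve(question, topic=topic, history=thoughts)
        context.extend(docs)
        thought, answer, done = reason_step(question, context, thoughts)
        thoughts.append(thought)
        if done:  # stop once the reasoning chain yields a confident answer
            return answer
    return answer  # fall back to the last intermediate answer
```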
You can download multi-hop datasets (MuSiQue, HotpotQA, and 2WikiMultiHopQA) from StonyBrookNLP/ircot.
```bash
# Download the preprocessed datasets for the test set.
bash ./download/processed_data.sh
```
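Once the script completes, the preprocessed splits should appear under `processed_data/<dataset>/` (for example, `processed_data/2wikimultihopqa/test_subsampled.jsonl`), which is the layout the examples below assume.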
This section describes how to perform topic modeling using the `BERTopicTrainer` class, which is based on the BERTopic algorithm. In this example, we load a dataset, train a topic model, and extract topics and their probabilities for new documents.
```python
# BERTopicTrainer is provided by this repository; import it from the
# corresponding module before running this example.

# Define dataset parameters
dataset = "2wikimultihopqa"
subsample = "test_subsampled"
dataset_path = f"../processed_data/{dataset}/{subsample}.jsonl"

# Create an instance of BERTopicTrainer and reduce the model to 20 topics
trainer = BERTopicTrainer(dataset_path=dataset_path, nr_topics=20)

# Load data and train the topic model
documents = trainer.load_data()
topics, probabilities = trainer.train_topic_model(documents)

# Load the trained model and get topics and probabilities for new documents
trainer.load_topic_model()
new_documents = ["This is an example of a new document about AI and technology."]
new_topics, new_probabilities = trainer.get_topics_with_probabilities(new_documents)

# Display topics and their probabilities for the new document
print("Topics for new documents:", new_topics)
print("Probability vectors for new documents:", new_probabilities)

# Get and display topic information
topic_info = trainer.get_topic_info()
print("Topic Information:\n", topic_info)
```
- `dataset`: Name of the dataset for topic modeling (e.g., `2wikimultihopqa`).
- `subsample`: Subsample of the dataset (e.g., `test_subsampled`).
- `dataset_path`: Path to the dataset in `.jsonl` format.
- `load_data()`: Loads the data from the specified dataset path.
- `train_topic_model(documents)`: Trains the BERTopic model using the provided documents.
- `load_topic_model()`: Loads a previously trained topic model.
- `get_topics_with_probabilities(new_documents)`: Retrieves topics and their probabilities for new documents.
- `get_topic_info()`: Retrieves information about the topics (e.g., topic labels and top words).
- Topics for New Documents: Lists topics assigned to each new document.
- Probability Vectors: Provides probability distribution over topics for each new document.
- Topic Information: Displays detailed information about each topic, including top contributing words.
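
For orientation, a minimal trainer along these lines can be built directly on the `bertopic` package. The sketch below is an assumption about what such a wrapper might look like, not the repository's actual class; in particular, the `question_text` field and the `topic_model` save path are illustrative guesses.

```python
import json
from bertopic import BERTopic  # pip install bertopic

class BERTopicTrainer:
    """Thin wrapper around BERTopic (illustrative sketch only)."""

    def __init__(self, dataset_path: str, nr_topics: int = 20):
        self.dataset_path = dataset_path
        self.nr_topics = nr_topics
        self.model = None

    def load_data(self) -> list[str]:
        # Assumes each JSONL record carries the question under "question_text".
        with open(self.dataset_path) as f:
            return [json.loads(line)["question_text"] for line in f]

    def train_topic_model(self, documents: list[str]):
        self.model = BERTopic(nr_topics=self.nr_topics, calculate_probabilities=True)
        topics, probabilities = self.model.fit_transform(documents)
        self.model.save("topic_model")  # persist for later reuse
        return topics, probabilities

    def load_topic_model(self):
        self.model = BERTopic.load("topic_model")

    def get_topics_with_probabilities(self, new_documents: list[str]):
        return self.model.transform(new_documents)

    def get_topic_info(self):
        return self.model.get_topic_info()
```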
This section demonstrates how to use the `Ingestor` class to ingest a dataset and retrieve the top-`n` answers for a given question.
```python
import os

# Ingestor is provided by this repository; import it from the
# corresponding module before running this example.

if __name__ == "__main__":
    dataset = "2wikimultihopqa"
    subsample = "test_subsampled"
    top_n = 10  # Number of top answers to retrieve

    # Read the OpenAI API key from the environment
    openai_api_key = os.environ["OPENAI_API_KEY"]

    # Create an instance of Ingestor with pre-trained models
    ingestor = Ingestor(
        dataset_path=f"../processed_data/{dataset}/{subsample}.jsonl",
        persist_directory=f"../vectorDB/{dataset}",
        openai_api_key=openai_api_key,
    )

    # Create a vector database
    vectordb = ingestor.create_vectordb()

    # Example: ask a question and retrieve the top answers
    question = "Who is the father-in-law of Queen Hyojeong?"
    results = ingestor.query_question(question, top_n=top_n)

    # Display the results
    print(results)
```
- `dataset`: Name of the dataset for question answering (e.g., `2wikimultihopqa`).
- `subsample`: Subsample of the dataset (e.g., `test_subsampled`).
- `top_n`: Number of top answers to retrieve for the question.

- `create_vectordb()`: Creates a vector database from the ingested dataset.
- `query_question(question, top_n)`: Retrieves the top answers for a given question.
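
As a rough idea of what `create_vectordb()` and `query_question()` could involve, here is a sketch using LangChain's Chroma integration with OpenAI embeddings. The repository may use a different stack; the sample passages and paths are purely illustrative.

```python
from langchain_community.vectorstores import Chroma  # pip install langchain-community chromadb
from langchain_openai import OpenAIEmbeddings        # pip install langchain-openai

# Hypothetical internals of create_vectordb(): embed the ingested
# passages and persist them to disk so later runs can reuse the index.
passages = [
    "Queen Hyojeong was the wife of King Heonjong of Joseon.",
    "King Heonjong of Joseon was the son of Crown Prince Hyomyeong.",
]  # in the real pipeline these would come from the dataset's .jsonl file

embeddings = OpenAIEmbeddings()  # reads OPENAI_API_KEY from the environment
vectordb = Chroma.from_texts(
    texts=passages,
    embedding=embeddings,
    persist_directory="../vectorDB/2wikimultihopqa",
)

# query_question() could then reduce to a plain similarity search:
docs = vectordb.similarity_search("Who is the father-in-law of Queen Hyojeong?", k=2)
print([d.page_content for d in docs])
```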
This section demonstrates how to use the `TopicCoTSelfRAG` class to process documents, query questions, and retrieve generated answers using a pre-trained model and vector database. The pipeline combines topic modeling with Chain-of-Thought (CoT) reasoning and Retrieval-Augmented Generation (RAG) to answer questions.
```python
import pandas as pd

# TopicCoTSelfRAG is provided by this repository; import it from the
# corresponding module before running this example.

if __name__ == "__main__":
    dataset = "2wikimultihopqa"
    subsample = "test_subsampled"
    model = "AT_RAG"
    top_n = 10    # Number of topics to retrieve
    max_iter = 5  # Maximum iterations of the answering pipeline

    # Create an instance of TopicCoTSelfRAG
    pipeline = TopicCoTSelfRAG(
        vectorDB_path=f"../vectorDB/{dataset}",
        dataset_path=f"../processed_data/{dataset}/{subsample}.jsonl",
        nr_topics=top_n,
        max_iter=max_iter,
    )

    # Load evaluation data
    dict_results = pipeline.ingestor.load_evaluation_data()
    dict_results["generated_answer"] = []

    # Generate an answer for each question
    for i in range(len(dict_results["question_id"])):
        question = dict_results["question_text"][i]
        _ = pipeline.run_pipeline(question=question)
        result = pipeline.last_answer.strip()
        dict_results["generated_answer"].append(result)
        print(f"{question} -> {dict_results['ground_truth'][i]} -> {result}")

    # Save results to CSV
    pd.DataFrame(dict_results).to_csv(
        f"../results/results_{dataset}_{subsample}_{model}.csv", index=False
    )
```
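The resulting CSV has one row per question: the columns loaded from the evaluation data (such as `question_id`, `question_text`, and `ground_truth`) plus the appended `generated_answer` column.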
- `dataset`: Dataset for document processing and question answering (e.g., `2wikimultihopqa`).
- `subsample`: Subsample of the dataset (e.g., `test_subsampled`).
- `model`: Model used for generating answers (e.g., `AT_RAG`).
- `top_n`: Number of topics to retrieve during topic modeling.
- `max_iter`: Maximum iterations of the answering pipeline.

- `load_evaluation_data()`: Loads evaluation data containing questions and ground truth.
- `run_pipeline(question)`: Runs the pipeline for a given question, including topic modeling and answer generation.
- `last_answer`: Returns the last generated answer.
This section demonstrates how to use the `Evaluation` class to evaluate the results of a dataset using ROUGE scores and the LLMaaJ (Large Language Model as a Judge) framework with GPT models.
```python
# Evaluation is provided by this repository; import it from the
# corresponding module before running this example.

if __name__ == "__main__":
    dataset = "2wikimultihopqa"
    subsample = "test_subsampled"
    rag = "AT_RAG"

    # Create an instance of Evaluation
    evaluation = Evaluation(dataset=dataset, subsample=subsample, rag=rag)

    # Perform ROUGE evaluation
    evaluation.rouge()

    # Perform LLM-based evaluation using GPT-4o-mini
    evaluation.LLMaaJ(model="gpt-4o-mini")
```
- `dataset`: Name of the dataset used for evaluation (e.g., `2wikimultihopqa`).
- `subsample`: Subsample of the dataset (e.g., `test_subsampled`).
- `rag`: RAG model used (e.g., `AT_RAG`).
- `model`: Model for LLM-based evaluation (e.g., `gpt-4o-mini`).

- `rouge()`: Evaluates the dataset using ROUGE metrics.
- `LLMaaJ(model)`: Uses an LLM to evaluate the quality of generated answers.
- Final results are saved in the `results` directory.
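
The internals of `Evaluation` are not shown here, but as a hedged sketch, both metrics can be approximated with the `rouge-score` package and the OpenAI client, roughly as follows; the prompt wording is an assumption.

```python
from rouge_score import rouge_scorer  # pip install rouge-score
from openai import OpenAI             # pip install openai

prediction = "Crown Prince Hyomyeong"
reference = "Crown Prince Hyomyeong"

# ROUGE: n-gram overlap between the generated answer and the ground truth
scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
print(scorer.score(reference, prediction))

# LLM-as-a-Judge: ask a GPT model to grade the generated answer
client = OpenAI()  # reads OPENAI_API_KEY from the environment
judgment = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{
        "role": "user",
        "content": (
            f"Ground truth: {reference}\n"
            f"Generated answer: {prediction}\n"
            "Does the generated answer match the ground truth? Answer yes or no."
        ),
    }],
)
print(judgment.choices[0].message.content)
```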
If you use this repository to generate published results or integrate it into other software, please cite our paper:
```bibtex
@article{rezaei2024rag,
  title={AT-RAG: An Adaptive RAG Model Enhancing Query Efficiency with Topic Filtering and Iterative Reasoning},
  author={Rezaei, Mohammad Reza and Hafezi, Maziar and Satpathy, Amit and Hodge, Lovell and Pourjafari, Ebrahim},
  journal={arXiv preprint arXiv:2410.12886},
  year={2024}
}
```