From 5bb7b2cdb3a0a92401dbdc6dbb62168b4ee334a6 Mon Sep 17 00:00:00 2001 From: Baivab Sarkar <109382325+ThisIs-Developer@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:46:03 +0530 Subject: [PATCH] Update model.py --- model.py | 91 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/model.py b/model.py index 4ae8b70..f9f9259 100644 --- a/model.py +++ b/model.py @@ -7,50 +7,73 @@ from langchain.llms import CTransformers from langchain.chains import ConversationalRetrievalChain -st.title("Conversational Retrieval System") +def add_vertical_space(spaces=1): + for _ in range(spaces): + st.sidebar.markdown("---") -DB_FAISS_PATH = "vectorstore/db_faiss" -TEMP_DIR = "temp" +def main(): + st.set_page_config(page_title="Llama-2-GGML CSV Chatbot") + st.title("Llama-2-GGML CSV Chatbot") -# Create temp directory if it doesn't exist -if not os.path.exists(TEMP_DIR): - os.makedirs(TEMP_DIR) + st.sidebar.title("About") + st.sidebar.markdown(''' + The Llama-2-GGML CSV Chatbot uses the **Llama-2-7B-Chat-GGML** model. + + ### 🔄Bot evolving, stay tuned! + + ## Useful Links 🔗 + + - **Model:** [Llama-2-7B-Chat-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/tree/main) 📚 + - **GitHub:** [ThisIs-Developer/Llama-2-GGML-CSV-Chatbot](https://github.com/ThisIs-Developer/Llama-2-GGML-CSV-Chatbot) 💬 + ''') -# Sidebar for uploading CSV file -uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv']) + DB_FAISS_PATH = "vectorstore/db_faiss" + TEMP_DIR = "temp" -if uploaded_file is not None: - file_path = os.path.join(TEMP_DIR, uploaded_file.name) - with open(file_path, "wb") as f: - f.write(uploaded_file.getvalue()) + if not os.path.exists(TEMP_DIR): + os.makedirs(TEMP_DIR) - st.write(f"Uploaded file: {uploaded_file.name}") - st.write("Processing CSV file...") + uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv']) - loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={'delimiter': ','}) - data = loader.load() + add_vertical_space(1) + st.sidebar.write('Made by [@ThisIs-Developer](https://huggingface.co/ThisIs-Developer)') - text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20) - text_chunks = text_splitter.split_documents(data) + if uploaded_file is not None: + file_path = os.path.join(TEMP_DIR, uploaded_file.name) + with open(file_path, "wb") as f: + f.write(uploaded_file.getvalue()) - st.write(f"Total text chunks: {len(text_chunks)}") + st.write(f"Uploaded file: {uploaded_file.name}") + st.write("Processing CSV file...") - embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2') - docsearch = FAISS.from_documents(text_chunks, embeddings) - docsearch.save_local(DB_FAISS_PATH) + loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={'delimiter': ','}) + data = loader.load() - llm = CTransformers(model="models/llama-2-7b-chat.ggmlv3.q4_0.bin", - model_type="llama", - max_new_tokens=512, - temperature=0.1) + text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20) + text_chunks = text_splitter.split_documents(data) - qa = ConversationalRetrievalChain.from_llm(llm, retriever=docsearch.as_retriever()) + st.write(f"Total text chunks: {len(text_chunks)}") - st.write("Enter your query:") - query = st.text_input("Input Prompt:") - if query: - chat_history = [] - result = qa({"question": query, "chat_history": chat_history}) - st.write("Response:", result['answer']) + embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2') + docsearch = FAISS.from_documents(text_chunks, embeddings) + docsearch.save_local(DB_FAISS_PATH) - os.remove(file_path) # Remove the temporary file after processing + llm = CTransformers(model="models/llama-2-7b-chat.ggmlv3.q4_0.bin", + model_type="llama", + max_new_tokens=512, + temperature=0.1) + + qa = ConversationalRetrievalChain.from_llm(llm, retriever=docsearch.as_retriever()) + + st.write("Enter your query:") + query = st.text_input("Input Prompt:") + if query: + with st.spinner("Processing your question..."): + chat_history = [] + result = qa({"question": query, "chat_history": chat_history}) + st.write("Response:", result['answer']) + + os.remove(file_path) + +if __name__ == "__main__": + main()