forked from docker/genai-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbot.py
193 lines (162 loc) · 6.21 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# bot.py - A Streamlit application that provides a user interface for interacting
# with a language model (LLM) and a Neo4j graph database. Users can ask coding-related
# questions and receive answers generated by the LLM, optionally enhanced with a Retrieval-Augmented
# Generation (RAG) approach using vector embeddings and graph data. Users who do not find what
# they are looking for can also generate a draft support ticket from their latest question.
import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import (
create_vector_index,
)
from chains import (
load_embedding_model,
load_llm,
configure_llm_only_chain,
configure_qa_rag_chain,
generate_ticket,
)
# Load environment variables from a .env file
load_dotenv(".env")

# Retrieve Neo4j and Ollama configuration from environment variables.
# os.getenv returns None for any variable that is not set.
url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")

# The Neo4j URI is required: it is used for the graph connection below and as the
# RAG chain's embeddings store URL. os.environ only accepts str values, so without
# this check an unset NEO4J_URI would fail with an unhelpful TypeError on the
# remapping line instead of naming the missing variable.
if url is None:
    raise RuntimeError(
        "NEO4J_URI is not set; define it in the environment or in the .env file."
    )

# Remapping for Langchain Neo4j integration (presumably read under the
# NEO4J_URL name by that integration -- verify against the library in use).
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)

# Initialize the Neo4j graph database connection.
# If Neo4j is local, you can go to http://localhost:7474/ to browse the database.
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

# Load the embedding model; the returned `dimension` is used to create the
# vector index in Neo4j.
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
create_vector_index(neo4j_graph, dimension)
# Define a callback handler for the language model's token generation
class StreamHandler(BaseCallbackHandler):
    """Render LLM output incrementally as tokens arrive.

    Every new token is appended to the accumulated text, and the full text is
    re-rendered as Markdown into the supplied Streamlit container.
    """

    def __init__(self, container, initial_text=""):
        # Text accumulated so far, seeded with `initial_text`.
        self.text = initial_text
        # Element that is redrawn on every token (e.g. the st.empty()
        # placeholder created in chat_input).
        self.container = container

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append `token` to the running text and redraw the container."""
        self.text = self.text + token
        self.container.markdown(self.text)
# Load the language model shared by both answer chains.
llm = load_llm(llm_name, config={"ollama_base_url": ollama_base_url}, logger=logger)

# Configure the chains for LLM only and RAG.
# The RAG chain additionally receives the embedding model and the Neo4j
# connection details (presumably for vector-store retrieval -- see chains.py).
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm,
    embeddings,
    embeddings_store_url=url,
    username=username,
    password=password,
)
# Streamlit UI
# Custom CSS injected into the page:
# - pins the "Select RAG mode" radio (matched by its aria-label) at a fixed
#   position near the bottom of the page,
# - adjusts the bottom offset of the chat input container,
# - enlarges the ticket "Description" text area in the sidebar.
# A plain string is used (not an f-string) because nothing is interpolated,
# so the CSS braces need no {{ }} escaping.
styl = """
<style>
/* not great support for :has yet (hello FireFox), but using it for now */
.element-container:has([aria-label="Select RAG mode"]) {
position: fixed;
bottom: 33px;
background: white;
z-index: 101;
}
.stChatFloatingInputContainer {
bottom: 20px;
}
/* Generate ticket text area */
textarea[aria-label="Description"] {
height: 200px;
}
</style>
"""
# unsafe_allow_html is required for the <style> tag to be applied.
st.markdown(styl, unsafe_allow_html=True)
# Function to handle chat input from the user
def chat_input():
    """Read one chat message, stream the selected chain's answer, and record it.

    Relies on module-level state: `name` (the current RAG mode label) and
    `output_function` (the chain selected for that mode). The session-state
    lists "user_input", "generated" and "rag_mode" must already exist; they
    are created by display_chat(), which runs before this function.
    """
    user_input = st.chat_input("What coding issue can I help you resolve today?")
    if not user_input:
        return

    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        st.caption(f"RAG: {name}")
        # Tokens are streamed into this placeholder while the chain runs.
        stream_handler = StreamHandler(st.empty())
        # Each question is sent with an empty chat history.
        answer = output_function(
            {"question": user_input, "chat_history": []}, callbacks=[stream_handler]
        )["answer"]

    # Keep the three lists index-aligned: entry i of each describes exchange i.
    st.session_state["user_input"].append(user_input)
    st.session_state["generated"].append(answer)
    st.session_state["rag_mode"].append(name)
# Function to display the chat history
def display_chat():
    """Initialize chat session state and render the most recent exchanges.

    Session state holds three index-aligned lists: "user_input" (questions),
    "generated" (answers) and "rag_mode" (the RAG mode label used for each
    answer). When there is history, the last three exchanges are rendered,
    followed by an expander offering to draft a support ticket.
    """
    # Session state: create each history list on the first run.
    for key in ("generated", "user_input", "rag_mode"):
        if key not in st.session_state:
            st.session_state[key] = []

    answers = st.session_state["generated"]
    if not answers:
        # Nothing to show yet. The ticket button is also withheld, because the
        # ticket draft is built from the latest question (user_input[-1]).
        return

    size = len(answers)
    # Display only the last three exchanges
    for i in range(max(size - 3, 0), size):
        with st.chat_message("user"):
            st.write(st.session_state["user_input"][i])
        with st.chat_message("assistant"):
            st.caption(f"RAG: {st.session_state['rag_mode'][i]}")
            st.write(answers[i])

    # Expander for generating a support ticket
    with st.expander("Not finding what you're looking for?"):
        st.write(
            "Automatically generate a draft for an internal ticket to our support team."
        )
        st.button(
            "Generate ticket",
            type="primary",
            key="show_ticket",
            on_click=open_sidebar,
        )
    # Trailing blank element (presumably spacing so the fixed-position mode
    # selector does not cover the last message -- confirm).
    with st.container():
        st.write(" ")
# Function to select the RAG mode
def mode_select() -> str:
    """Render the horizontal RAG-mode radio and return the chosen label.

    The label is either "Disabled" or "Enabled". The widget's label text
    "Select RAG mode" is also targeted by the page CSS via aria-label.
    """
    return st.radio("Select RAG mode", ["Disabled", "Enabled"], horizontal=True)
# Set the RAG mode based on user selection.
# Maps each mode label to the chain that answers questions in that mode.
# "LLM only" and "Vector + Graph" are also accepted, although mode_select()
# currently offers only "Disabled" and "Enabled".
_CHAIN_BY_MODE = {
    "Disabled": llm_chain,
    "LLM only": llm_chain,
    "Enabled": rag_chain,
    "Vector + Graph": rag_chain,
}
name = mode_select()
try:
    output_function = _CHAIN_BY_MODE[name]
except KeyError:
    # Fail here with a clear message instead of leaving `output_function`
    # undefined and hitting a NameError later in chat_input().
    raise ValueError(f"Unknown RAG mode: {name!r}") from None
# Functions to handle sidebar for generating a support ticket
def open_sidebar():
    """Button callback: flag the ticket-draft sidebar as open."""
    st.session_state["open_sidebar"] = True


def close_sidebar():
    """Button callback: flag the ticket-draft sidebar as closed."""
    st.session_state["open_sidebar"] = False
# Initialize sidebar state
if "open_sidebar" not in st.session_state:
    st.session_state.open_sidebar = False

# If sidebar is open, generate and display the ticket draft.
# Streamlit reruns this whole script on every widget interaction, including
# edits to the draft fields below. Regenerating on each rerun would cost an
# extra LLM call every time and could discard the user's edits, since the
# widgets' default values may change. The draft is therefore cached in session
# state and only regenerated when the latest question changes.
if st.session_state.open_sidebar and st.session_state.get("user_input"):
    latest_question = st.session_state["user_input"][-1]
    draft = st.session_state.get("ticket_draft")
    if draft is None or draft[0] != latest_question:
        new_title, new_question = generate_ticket(
            neo4j_graph=neo4j_graph,
            llm_chain=llm_chain,
            input_question=latest_question,
        )
        draft = (latest_question, new_title, new_question)
        st.session_state["ticket_draft"] = draft
    _, new_title, new_question = draft

    with st.sidebar:
        st.title("Ticket draft")
        st.write("Auto generated draft ticket")
        st.text_input("Title", new_title)
        st.text_area("Description", new_question)
        # Currently this only closes the sidebar; no submission happens here.
        st.button(
            "Submit to support team",
            type="primary",
            key="submit_ticket",
            on_click=close_sidebar,
        )

# Display the chat and handle user input
display_chat()
chat_input()