-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMAIN_UI
83 lines (65 loc) · 3.01 KB
/
MAIN_UI
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
## MAIN UI ##
import streamlit as st
from streamlit import write, empty
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from ingest_data import embed_doc
from query_data import _template, CONDENSE_QUESTION_PROMPT, QA_PROMPT, get_chain
import pickle
import os
# Streamlit UI: page chrome, tuning sliders, and the document uploader.
st.set_page_config(page_title="ChatDoc", page_icon=":robot:")
st.header("ChatDoc")

# Hyperparameter sliders. `temperature` is later handed to get_chain();
# `chunk_size` / `chunk_overlap` are later handed to embed_doc().
temperature = st.slider(
    "Set Creativity:",
    min_value=0.01,
    max_value=2.0,
    value=0.9,
    step=0.05,
    format="%0.2f",
    key=None,
)
# NOTE(review): max_tokens is not read anywhere else in the visible script --
# confirm whether it is meant to be forwarded to the chain.
max_tokens = st.slider(
    "Set Max Words:",
    min_value=10,
    max_value=1000,
    value=500,
    step=5,
    format="%f",
    key=None,
)
chunk_size = st.slider(
    "Set Chunk_size:",
    min_value=100,
    max_value=5000,
    value=500,
    step=10,
    format="%f",
    key=None,
)
chunk_overlap = st.slider(
    "Set Chunk_overlap:",
    min_value=0,
    max_value=500,
    value=20,
    step=10,
    format="%f",
    key=None,
)

# Single-file uploader with no restriction on file type (type=None).
uploaded_file = st.file_uploader(
    "Upload a Doc you want to Chat with!",
    type=None,
    accept_multiple_files=False,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    disabled=False,
    label_visibility="visible",
)
# --- Persist a newly uploaded document and build its embeddings -------------
# The upload directory may not exist on a fresh checkout; os.listdir() and
# open() below would raise FileNotFoundError without it.
os.makedirs('data', exist_ok=True)

if uploaded_file is not None:
    # Keep only the final path component so a crafted upload name such as
    # "../../x" cannot be written outside data/.
    doc_name = os.path.basename(uploaded_file.name)
    if doc_name and doc_name not in os.listdir('data'):
        # write file to dir
        with open(os.path.join('data', doc_name), 'wb') as f:
            # getvalue() returns the full buffer regardless of the current
            # read position of the uploaded file object.
            f.write(uploaded_file.getvalue())
        st.write("Doc uploaded :)")
        with st.spinner('Fetching Doc Vectors...'):
            # presumably embed_doc() (re)writes vectorstore.pkl -- verify in
            # ingest_data.
            embed_doc(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

# --- Load the vector store and build the chain -------------------------------
# Without vectorstore.pkl neither `vectorstore` nor `chain` can be defined, and
# the chat section further down would fail with NameError on the first
# question. Stop this script run early with a hint instead.
if not os.path.isfile('vectorstore.pkl'):
    st.info('Upload a Doc to start chatting.')
    st.stop()

# NOTE(security): pickle.load() can execute arbitrary code contained in the
# file. This is only acceptable while vectorstore.pkl is produced locally by
# this app; never load a pickle obtained from an untrusted source.
with open('vectorstore.pkl', 'rb') as f:
    vectorstore = pickle.load(f)
print('Loading vectorstore...')
chain = get_chain(vectorstore, temperature=temperature)
# Make sure both conversation-history lists exist in the session state:
# 'generated' collects chain outputs, 'past' collects the user's inputs.
for history_key in ('generated', 'past'):
    if history_key not in st.session_state:
        st.session_state[history_key] = []

# Empty container that get_text() fills with the chat text box.
placeholder = st.empty()
def get_text():
    """Render the chat text box inside ``placeholder`` and return its value.

    The widget is registered under the session-state key ``'input'`` and
    starts out as an empty string.
    """
    return placeholder.text_input('you: ', value='', key='input')
# --- Chat: read the question, query the chain, render the results -----------
user_input = get_text()
print(st.session_state.input)
print(user_input)

docs = []
if user_input:
    # NOTE(review): Streamlit re-executes the script on every widget
    # interaction while the text box keeps its value, so this branch (and the
    # two history appends below) runs again on each rerun with the same
    # question. Confirm whether repeated entries are acceptable.

    # search chunks for similar vectors
    docs = vectorstore.similarity_search(user_input)
    # run langchain
    output = chain.run(
        input=user_input, vectorstore=vectorstore,
        context=docs[:10], chat_history=[], question=user_input,
        QA_PROMPT=QA_PROMPT, CONDENSE_QUESTION_PROMPT=CONDENSE_QUESTION_PROMPT,
        template=_template,
    )
    st.session_state.past.append(user_input)
    print(st.session_state.past)
    st.session_state.generated.append(output)
print(len(docs))

if st.checkbox('Show source documents:'):
    # Slice instead of indexing 0..4: `docs` is empty when no question has
    # been asked and may hold fewer than five hits, either of which made the
    # fixed range(5) loop raise IndexError.
    for doc in docs[:5]:
        st.markdown(f':red[{doc.page_content}]')

# Render the conversation newest-first, each answer followed by the question
# that produced it. The two lists are appended to together above, so pairing
# them position by position keeps answers and questions aligned.
# st.write() does not define a `key` parameter, so none is passed.
for answer, question in zip(
    reversed(st.session_state['generated']),
    reversed(st.session_state['past']),
):
    st.write(answer)
    st.write(question)