-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathagent.py
172 lines (125 loc) · 5.03 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
from pathlib import Path
from typing import List
import chainlit as cl
from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.agents import initialize_agent, AgentExecutor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import CSVLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import AsyncOpenAI
from modules.database.database import PostgresDB
# Environment setup, text-splitting parameters and the shared OpenAI clients
# used by the tools defined below.

# Load the .env file first so its values are in the environment before the
# OpenAI clients below are constructed.
load_dotenv()

# Chunking parameters for the CSV documents indexed at start-up.
chunk_size = 512
chunk_overlap = 50

# Directory scanned for *.csv files when the module is imported.
CSV_STORAGE_PATH = "./data"

embeddings_model = OpenAIEmbeddings()
openai_client = AsyncOpenAI()
def process_pdfs(pdf_storage_path: str):
    """Load, split and index every ``*.csv`` file found in a directory.

    Despite its name this function reads CSV files (through ``CSVLoader``),
    not PDFs; the name is kept so existing callers keep working.

    Each CSV file is loaded, split with ``RecursiveCharacterTextSplitter``
    using the module-level ``chunk_size`` / ``chunk_overlap`` settings, and
    written into a Chroma vector store through LangChain's indexing API
    (``index`` + ``SQLRecordManager``). Indexing statistics are printed.

    Args:
        pdf_storage_path: Directory searched (non-recursively) for ``*.csv``
            files.

    Returns:
        The populated ``Chroma`` vector store.
    """
    csv_directory = Path(pdf_storage_path)
    # Use the module-level settings rather than re-hard-coding the overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = []  # type: List[Document]
    for csv_path in csv_directory.glob("*.csv"):
        loader = CSVLoader(file_path=str(csv_path))
        docs += text_splitter.split_documents(loader.load())

    # Create the store empty and let index() be its only writer. Seeding it
    # with Chroma.from_documents() before calling index() stored every chunk
    # twice (once under a random id, once under index()'s content-hash id)
    # and embedded every chunk twice.
    documents_search = Chroma(embedding_function=embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()
    # The Chroma store above has no persist_directory, so it starts empty on
    # every run, while the record manager is persisted in SQLite. Forget the
    # records from earlier runs; otherwise index() would consider those chunks
    # already stored, skip them, and leave the fresh store empty.
    record_manager.delete_keys(record_manager.list_keys())

    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")
    return documents_search
# Build the vector store once, at import time; research_database() queries it.
doc_search = process_pdfs(CSV_STORAGE_PATH)

# --- Execute SQL tool: function, input schema and tool wrapper ---
def execute_sql(query: str) -> str:
    """Run *query* against the database whose URL is in the ``DB_URL`` env var.

    The query is expected to be a single clean SQL line, with no backticks
    and no line breaks. The result is whatever
    ``PostgresDB.run_sql_to_markdown`` returns for it (presumably a Markdown
    table; confirm in ``modules.database``).
    """
    # NOTE(review): the SQL text is produced by the LLM and executed verbatim;
    # DB_URL should point at a read-only database role.
    # NOTE(review): a new PostgresDB is connected on every call and nothing
    # here closes it; check whether PostgresDB needs explicit cleanup.
    database = PostgresDB()
    database.connect_with_url(os.getenv("DB_URL"))
    return database.run_sql_to_markdown(query)
# Input schema for the "Execute SQL" tool: a single SQL string.
# NOTE: documented with comments rather than a class docstring, since pydantic
# may surface a class docstring as the schema description sent to the tool.
class ExecuteSqlToolInput(BaseModel):
    # SQL text passed through to execute_sql().
    query: str = Field(
        description="A clean SQL query in one line to be executed agains the database")
# Expose execute_sql() to the agent under the tool name "Execute SQL".
execute_sql_tool = StructuredTool(
    name="Execute SQL",
    description="useful for when you need to execute SQL queries against the database. Always use a clause LIMIT 10",
    func=execute_sql,
    args_schema=ExecuteSqlToolInput,
)

# --- Research database tool: function, input schema and tool wrapper ---
def research_database(user_request: str) -> str:
    """Return the indexed table-definition chunks most similar to *user_request*.

    Retrieves the 30 closest chunks from the module-level ``doc_search``
    vector store, prints each one to stdout (numbered from 1), and returns
    their page contents joined by blank lines.
    """
    retriever = doc_search.as_retriever(search_kwargs={"k": 30})
    matches = retriever.invoke(user_request)
    for position, match in enumerate(matches, start=1):
        print(f"{position}. {match.page_content}")
    return "\n\n".join(match.page_content for match in matches)
# Input schema for the "Search db info" tool: the text to look up among the
# indexed table definitions.
# NOTE: documented with comments rather than a class docstring, since pydantic
# may surface a class docstring as the schema description sent to the tool.
class ResearchDatabaseToolInput(BaseModel):
    # Query text passed to research_database() for similarity search.
    # NOTE(review): the description asks for a "LIMIT 10" clause, which looks
    # copied from the SQL tool and does not apply to a similarity-search query
    # ("clase" is also a typo). Left unchanged because it is part of the schema.
    user_request: str = Field(
        description="The user query to search against the table definitions for matches. Always use a clase of LIMIT 10")
# Expose research_database() to the agent under the tool name "Search db info".
research_database_tool = StructuredTool(
    name="Search db info",
    description="Search for database information so you can have context for building SQL queries.",
    func=research_database,
    args_schema=ResearchDatabaseToolInput,
)
@cl.on_chat_start
def start():
    """Create the chat agent for a new Chainlit session.

    Builds a GPT-4 (temperature 0) agent that can use the "Execute SQL" and
    "Search db info" tools, and stores it in the user session under the key
    ``"agent"``, where ``main`` retrieves it for every message.
    """
    tools = [execute_sql_tool, research_database_tool]
    llm = ChatOpenAI(model="gpt-4", temperature=0)
    # Custom instructions go through agent_kwargs["prefix"]: initialize_agent()
    # has no ``prompt`` parameter, and one passed as a keyword is forwarded to
    # AgentExecutor and silently ignored. LangChain appends the tool list (with
    # the real tool names), the format instructions and the ``{input}``
    # placeholder after this prefix, so the prefix itself must contain no
    # curly braces.
    prefix = (
        "You are a world class data scientist. Your job is to listen to the "
        "user query and, based on it, use one of your tools to do the job. "
        "Usually you would start by searching the database information for "
        "the table definitions needed to build the SQL query the user wants.\n"
        "Think carefully before choosing a tool. If you don't know what the "
        "user wants or you don't understand, ask for clarification.\n"
        "Remember, if you don't know the answer, don't make anything up. "
        "Always ask for feedback.\n"
        "One last detail: always run the queries with LIMIT 10.\n\n"
        "You have access to the following tools:"
    )
    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent_kwargs={"prefix": prefix},
    )
    cl.user_session.set("agent", agent)
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message with the agent stored for this session.

    The agent is run on the message text with Chainlit's async LangChain
    callback handler attached, and its answer is sent back as a new message.
    """
    executor = cl.user_session.get("agent")  # type: AgentExecutor
    handler = cl.AsyncLangchainCallbackHandler()
    answer = await executor.arun(message.content, callbacks=[handler])
    reply = cl.Message(content=answer)
    await reply.send()