# main.py (forked from CalebJKim/DataPilot)
import logging
import os
import sqlite3
import sys
import warnings
from typing import Any

import pandas as pd

import visualization_executor
from autogen import ConversableAgent
# from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent  # needed if the RAG proxy agent below is re-enabled
from llm import LLM
from prompts import Prompts
# from dataframe_analyzer import is_data_relevant, is_sample_size_sufficient  # replaced by the local placeholders below

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging
logging.basicConfig(filename="pipeline.log", level=logging.INFO)
project_dir = os.getcwd()
def query_database(db_path: str, query: str) -> pd.DataFrame:
    """
    Executes an SQL query on the given SQLite database and returns the results as a DataFrame.
    """
    conn = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query(query, conn)
    finally:
        conn.close()  # close the connection even if the query raises
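# Example usage (illustrative; the path and table name match the pipeline below):
#   sample = query_database("data/sqlite_db/data.db", "SELECT * FROM transactions LIMIT 5;")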
def fetch_and_process_data(csv_path):
    """
    Loads the CSV at csv_path, applies a basic cleaning step, and stores the result in SQLite.
    """
    df = pd.read_csv(csv_path)
    logging.info(f"Loaded {csv_path} with {df.shape[0]} rows.")
    # Simulate data cleaning
    df.dropna(inplace=True)  # Example cleaning step
    logging.info("Data cleaned. Dropped NA values.")
    # Save to SQLite database
    conn = sqlite3.connect('data/sqlite_db/data.db')
    df.to_sql('transactions', conn, if_exists='replace', index=False)
    conn.close()
    logging.info("Data saved to SQLite database.")
def is_data_relevant(data: Any) -> bool:
    """
    Placeholder function to check if the data is relevant.
    """
    # bool(data) on a DataFrame raises "truth value of a DataFrame is ambiguous"; check explicitly.
    return data is not None and not data.empty
def is_sample_size_sufficient(data: Any) -> bool:
"""
Placeholder function to check if the sample size is sufficient.
"""
return len(data) >= 10
def evaluate_query_results(data: Any, original_prompt: str, sql_generator_agent: ConversableAgent, llm: LLM, retries_left: int = 3) -> None:
    """
    Evaluates query results and determines next actions based on data quality.
    - If the data is relevant and useful, calls the data analyst function with the data and original prompt.
    - If the data is not relevant or the sample size is too small, prompts Agent 1 for a new SQL query,
      retrying at most `retries_left` times so the recursion is bounded.
    """
    if is_data_relevant(data) and is_sample_size_sufficient(data):
        # Data meets quality criteria - send to analyst
        logging.info("Data is relevant and sufficient. Proceeding to analysis.")
        analyze_data(data, original_prompt)
    elif retries_left <= 0:
        logging.error("Retry budget exhausted; no relevant, sufficient data was obtained.")
    else:
        # Data does not meet quality criteria - request a new SQL query from Agent 1
        logging.warning("Data is not relevant or insufficient. Reinvoking Agent 1 for a new SQL query.")
        new_prompt = f"""The previous SQL query did not yield relevant or sufficient results.
Original request: "{original_prompt}"
Please generate a new SQL query with refined conditions to improve relevance or increase the sample size."""
        # Generate a new SQL query using Agent 1
        new_query = llm.generate_sql_query(new_prompt)
        logging.info(f"Generated new SQL query: {new_query}")
        # Execute the new SQL query using Agent 2
        csv_file_path = "datasets/samples.csv"
        new_data = llm.execute_sql_query(new_query, csv_file_path)
        # Re-evaluate the new results
        evaluate_query_results(new_data, original_prompt, sql_generator_agent, llm, retries_left - 1)
def analyze_data(data: Any, prompt: str) -> None:
    """
    Hands off validated data and the original prompt to the data analyst.
    """
logging.info("Sending data to data analyst for further analysis.")
print("Data is relevant and sufficient. Sending to data analyst.")
print(f"Prompt: {prompt}")
print(f"Data: {data}")
def main(user_query):
api_key = os.environ.get("OPENAI_API_KEY")
#llm = LLM(model_name="gemini-1.5-pro-latest")
llm_config = {"config_list": [{"model": "gpt-4o-mini", "api_key": api_key}]}
# agent 0a in diagram
db_eda_agent = ConversableAgent("db_eda_agent",
system_message=Prompts.database_EDA_agent_prompt,
llm_config=llm_config)
    db_eda_agent.register_for_llm(name="db_eda_agent", description="Performs exploratory data analysis on the database and returns useful information.")
db_eda_agent.register_for_execution(name="db_eda_agent")
# agent 0b in diagram
metric_agent = ConversableAgent("metric_agent",
system_message=Prompts.metric_agent_prompt,
llm_config=llm_config)
    metric_agent.register_for_llm(name="metric_agent", description="Combines EDA with dealership-specific frameworks to suggest metrics worth examining.")
metric_agent.register_for_execution(name="metric_agent")
# RAG Agent
    # Get all file paths in the 'Frameworks' directory to set up RAG
frameworks_dir = 'Frameworks'
docs_paths = [os.path.join(frameworks_dir, file) for file in os.listdir(frameworks_dir) if os.path.isfile(os.path.join(frameworks_dir, file))]
    # retrieve_config = {"config_list": [{"model": "gpt-4o-mini", "api_key": api_key}]}
    # # agent 0b with RAG: a retrieval-augmented proxy over the Frameworks docs
    # ragproxyagent = RetrieveUserProxyAgent(
    #     name="ragproxyagent",
    #     system_message=Prompts.RAG_agent_prompt,
    #     is_termination_msg=None,
    #     human_input_mode="NEVER",
    #     max_consecutive_auto_reply=3,
    #     retrieve_config={
    #         "task": "qa",
    #         "model": "gpt-4o-mini",
    #         "docs_path": docs_paths,
    #     },
    #     code_execution_config=False,
    # )
    # ragproxyagent.register_for_llm(name="ragproxyagent")
    # ragproxyagent.register_for_execution(name="ragproxyagent")
# agent 3 in diagram
data_analyst_agent = ConversableAgent("data_analyst_agent",
system_message=Prompts.data_analyst_agent_prompt,
llm_config=llm_config)
    data_analyst_agent.register_for_llm(name="data_analyst_agent", description="Performs data analysis on the SQL query results and web sentiment data and returns useful information.")
# data_analyst_agent.register_for_execution(name="data_analyst_agent")
# agent 4 in diagram
visualization_agent = ConversableAgent("visualization_agent",
system_message=Prompts.visualization_agent_prompt,
llm_config=llm_config)
    visualization_agent.register_for_llm(name="visualization_agent", description="Performs data visualization on the SQL query results and web sentiment data and returns useful information.")
visualization_agent.register_for_execution(name="visualization_agent")
# Agentic Workflow
csv_path = 'data/raw/carsalesdata.csv'
db_path = 'data/sqlite_db/data.db'
fetch_and_process_data(csv_path)
data_sample = query_database(db_path, "SELECT * FROM transactions LIMIT 20;").to_string(index=False)
print(f"EDA Agent will analyze {db_path} in context of {user_query}")
eda_response = db_eda_agent.generate_reply(
messages=[
{"role": "user", "content": f"Analyze the database at {db_path} based on this query: {user_query}. Here is a sample of the data:\n{data_sample}\n"}
]
)
print("Agent Response:\n", eda_response)
# Using EDA response, use RAG Agent to find metrics that could be interesting to look at
print(f"The metric agent is now analyzing the previous EDA response")
# metric_response = ragproxyagent.generate_reply(
# messages=[
# {"role": "user", "content": f"Analyze the response given here, and pay attention to the schema of the database: {eda_response}. Using your knowledge from documents passed in, return 7-8 metrics that would be worthwhile to look more into. Provide explanations for each one."}
# ]
# )
# Function to read all .txt files in the frameworks folder and return their contents as strings
def read_frameworks_as_strings(directory_path):
frameworks_content = {}
for filename in os.listdir(directory_path):
if filename.endswith('.txt'):
file_path = os.path.join(directory_path, filename)
with open(file_path, 'r') as file:
content = file.read()
frameworks_content[filename] = content
return frameworks_content
# Path to the frameworks folder
frameworks_folder_path = 'Frameworks'
# Read the .txt files and store their contents
frameworks_data = read_frameworks_as_strings(frameworks_folder_path)
    # The file contents are already plain strings; no conversion is needed.
    frameworks_fstrings = frameworks_data
metric_response = metric_agent.generate_reply(
messages=[
{"role": "user", "content": f"Based on this query: {eda_response}, and these sources: {frameworks_fstrings}, analyze potential metrics that would be useful for {user_query}"}
]
)
print("Metric/RAG Response", metric_response)
    # TODO: Add logic for SQL generation / execution based on the EDA response.
    sql_result = pd.DataFrame()  # placeholder until the SQL step is implemented
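    # A minimal sketch of the missing step, left commented out until the SQL agents
    # are wired in. It reuses helpers defined in this file; the `llm` instance is an
    # assumption (see the commented-out construction near the top of main()):
    #   llm = LLM(model_name="gemini-1.5-pro-latest")
    #   sql_query = llm.generate_sql_query(f"{user_query}\nContext from EDA: {eda_response}")
    #   sql_result = query_database(db_path, sql_query)
    #   evaluate_query_results(sql_result, user_query, db_eda_agent, llm)  # db_eda_agent stands in for a dedicated SQL agent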
    # TODO: Add call to web scraper agent.
    web_sentiments = []  # placeholder until the scraper is implemented
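    # A hypothetical sketch of that call; no web scraper module exists in this file,
    # so both the import and fetch_sentiments(query) -> list[str] are assumptions:
    #   import web_scraper
    #   web_sentiments = web_scraper.fetch_sentiments(user_query)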
    # TODO: Adjust the analysis/visualization agents to take the correct input types; a DataFrame and a list are assumed for now.
    # Logic for the data analyst agent
analysis_response = data_analyst_agent.generate_reply(
messages=[
{
"role": "user",
"content": f"Analyze the SQL query results and web sentiments based on this query: {user_query}. "
f"SQL Results: {sql_result.to_string(index=False)}\n"
f"Web Sentiments: {web_sentiments}\n"
}
]
)
print("Data Analyst Agent Response:\n", analysis_response)
    # Logic for the visualization agent
visualization_code = visualization_agent.generate_reply(
messages=[
{
"role": "user",
"content": f"Generate visualization code for the SQL query results and web sentiments based on this query: {user_query}. "
f"SQL Results: {sql_result.to_string(index=False)}\n"
f"Web Sentiments: {web_sentiments}\n"
}
]
)
# Execute the generated visualization code
visualizations = visualization_executor.execute_visualization_code(visualization_code, sql_result, web_sentiments)
# Return both analysis and visualizations
return {
"analysis": analysis_response,
"visualizations": visualizations
}
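# Example (illustrative): a caller importing this module could consume the returned
# dict directly; the query string below is made up:
#   result = main("Which vehicle models had the highest Q4 sales?")
#   print(result["analysis"])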
if __name__ == "__main__":
    assert len(sys.argv) > 1, "Please include a query about the car sales data when executing main."
main(sys.argv[1])