-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
313 lines (245 loc) · 12.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import sys
import os
from llm import *
from autogen import ConversableAgent
from autogen import AssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from prompts import Prompts
from typing import Dict, Any
from webagent import summarize_online_and_review_data
# from webAgent import main as web_agent_main
#from dataframe_analyzer import is_data_relevant, is_sample_size_sufficient
warnings.filterwarnings("ignore", category=DeprecationWarning)
from datetime import datetime
import os
import sqlite3
import logging
import visualization_executor
# Set up logging: append pipeline events to pipeline.log at INFO level
# (module-wide config applied at import time).
logging.basicConfig(filename="pipeline.log", level=logging.INFO)
# Working directory captured at import time; presumably the project root — TODO confirm.
project_dir = os.getcwd()
def query_database(db_path: str, query: str) -> pd.DataFrame:
    """
    Execute an SQL query on an SQLite database and return the results.

    Args:
        db_path: Path to the SQLite database file.
        query: SQL statement to execute (expected to be a SELECT).

    Returns:
        A pandas DataFrame containing the query results.
    """
    conn = sqlite3.connect(db_path)
    try:
        # try/finally ensures the connection is closed even when the
        # query raises (the original leaked the connection on error).
        return pd.read_sql_query(query, conn)
    finally:
        conn.close()
def fetch_and_process_data(csv_path, db_path='data/sqlite_db/data.db'):
    """
    Load a CSV of sales data, clean it, and persist it to SQLite.

    Args:
        csv_path: Path to the source CSV file.
        db_path: Path to the target SQLite database. Defaults to the
            project's standard location so existing callers are unaffected.
    """
    df = pd.read_csv(csv_path)
    logging.info(f"Loaded {csv_path} with {df.shape[0]} rows.")
    # Minimal cleaning: drop any row containing missing values.
    df.dropna(inplace=True)
    logging.info("Data cleaned. Dropped NA values.")
    # Persist to SQLite, replacing any previous 'transactions' table;
    # try/finally closes the connection even if the write fails.
    conn = sqlite3.connect(db_path)
    try:
        df.to_sql('transactions', conn, if_exists='replace', index=False)
    finally:
        conn.close()
    logging.info("Data saved to SQLite database.")
def is_data_relevant(data: Any) -> bool:
    """
    Placeholder relevance check for query results.

    Fixes a real bug: the original `bool(data)` raises ValueError for a
    pandas DataFrame (its truth value is ambiguous), so the check crashed
    on its primary intended input type.

    Args:
        data: Query result, typically a pandas DataFrame; may be None.

    Returns:
        True if data is present and non-empty, False otherwise.
    """
    if data is None:
        return False
    # DataFrame-like objects expose .empty; bool(df) would raise.
    if hasattr(data, "empty"):
        return not data.empty
    return bool(data)
def is_sample_size_sufficient(data: Any, min_rows: int = 10) -> bool:
    """
    Placeholder check that the result set is large enough to analyze.

    Args:
        data: Sized query result (e.g. a DataFrame or list); may be None.
        min_rows: Minimum number of rows required. Defaults to 10,
            matching the original hard-coded threshold.

    Returns:
        True if data has at least min_rows rows, False otherwise.
    """
    # Guard against None, which would make len() raise TypeError.
    if data is None:
        return False
    return len(data) >= min_rows
# def evaluate_query_results(data: Any, original_prompt: str, sql_generator_agent: ConversableAgent, llm: LLM) -> None:
# """
# Evaluates query results and determines next actions based on data quality.
# - If data is relevant and useful, calls the data analyst function with the data and original prompt.
# - If data is not relevant or the sample size is too small, prompts Agent 1 for a new SQL query.
# """
# if is_data_relevant(data) and is_sample_size_sufficient(data):
# # Data meets quality criteria - send to analyst
# logging.info("Data is relevant and sufficient. Proceeding to analysis.")
# analyze_data(data, original_prompt)
# else:
# # Data does not meet quality criteria - request new SQL query from Agent 1
# logging.warning("Data is not relevant or insufficient. Reinvoking Agent 1 for a new SQL query.")
# new_prompt = f"""The previous SQL query did not yield relevant or sufficient results.
# Original request: "{original_prompt}"
# Please generate a new SQL query with refined conditions to improve relevance or increase the sample size."""
# # Generate a new SQL query using Agent 1
# new_query = llm.generate_sql_query(new_prompt)
# logging.info(f"Generated new SQL query: {new_query}")
# # Execute the new SQL query using Agent 2
# csv_file_path = "datasets/samples.csv"
# new_data = llm.execute_sql_query(new_query, csv_file_path)
# # Re-evaluate the new results
# evaluate_query_results(new_data, original_prompt, sql_generator_agent, llm)
def analyze_data(data: Any, prompt: str) -> None:
    """
    Hand validated query results off to the data analyst step.

    Args:
        data: Query results that passed the relevance/size checks.
        prompt: The original user prompt the data is meant to answer.
    """
    logging.info("Sending data to data analyst for further analysis.")
    for line in (
        "Data is relevant and sufficient. Sending to data analyst.",
        f"Prompt: {prompt}",
        f"Data: {data}",
    ):
        print(line)
def main(user_query: str) -> Dict[str, Any]:
    """
    Run the full agentic analysis pipeline for a single user query.

    Steps: configure LLM agents -> load the car-sales CSV into SQLite ->
    EDA agent summarizes a data sample -> loop generating/executing SQL
    until the evaluation agent approves (or 3 attempts) -> data analyst
    agent interprets results -> visualization agent generates plotting
    code, which is executed by visualization_executor.

    Args:
        user_query: Natural-language question about the car sales data.

    Returns:
        Dict with keys "analysis" (analyst agent reply) and
        "visualizations" (output of the executed visualization code).
    """
    # NOTE(review): hard-coded placeholder API key — should come from an
    # environment variable or config, not source code.
    api_key = 'JONAH'
    llm = LLM()
    llm_config = {"config_list": [{"model": "gpt-4o-mini", "api_key": api_key}]}
    # Evaluation agent: judges whether queried data suffices to answer the prompt.
    evaluation_agent = ConversableAgent("data_eval_agent",
                                        system_message=Prompts.evaluation,
                                        llm_config=llm_config)
    evaluation_agent.register_for_llm(name="data_eval_agent", description="Evaluates data and gauges whether it is good enough or not for use in answering prompt.")
    evaluation_agent.register_for_execution(name="data_eval_agent")
    # agent 0a in diagram: exploratory data analysis over the SQLite database.
    db_eda_agent = ConversableAgent("db_eda_agent",
                                    system_message=Prompts.database_EDA_agent_prompt,
                                    llm_config=llm_config)
    db_eda_agent.register_for_llm(name="db_eda_agent", description="Performs exploratory data analysis on the database and return useful information.")
    db_eda_agent.register_for_execution(name="db_eda_agent")
    # agent 0b in diagram: suggests dealership metrics (currently unused below —
    # the metric/RAG flow is commented out).
    metric_agent = ConversableAgent("metric_agent",
                                    system_message=Prompts.metric_agent_prompt,
                                    llm_config=llm_config)
    metric_agent.register_for_llm(name="metric_agent", description="Combines EDA with dealership-specific frameworks to provide metrics to look at")
    metric_agent.register_for_execution(name="metric_agent")
    # RAG Agent
    # Get all file paths in the 'Frameworks' directory; to setup RAG
    frameworks_dir = 'Frameworks'
    docs_paths = [os.path.join(frameworks_dir, file) for file in os.listdir(frameworks_dir) if os.path.isfile(os.path.join(frameworks_dir, file))]
    # retrieve_config = {"config_list": [{"model": "gpt-4o-mini", "api_key": api_key}]}
    # # agent 0b with rag
    # ragproxyagent = RetrieveUserProxyAgent(
    #     # name="ragproxyagent",
    #     system_message=Prompts.RAG_agent_prompt,
    #     # retrieve_config = {
    #     #     "task":"qa",
    #     #     "docs_path":docs_paths
    #     # })
    #     name="RAG Proxy Agent",
    #     is_termination_msg=None,  # Assuming termination_msg is defined elsewhere
    #     # system_message="Running RAG agent who has extra content retrieval power for solving difficult problems.",
    #     human_input_mode="NEVER",
    #     max_consecutive_auto_reply=3,
    #     retrieve_config={
    #         "task": "qa",
    #         "model" : "gpt-4o-mini",
    #         "docs_path": docs_paths
    #     },
    #     code_execution_config=False
    # )
    # ragproxyagent.register_for_llm(name="ragproxyagent")
    # ragproxyagent.register_for_execution(name="ragproxyagent")
    # agent 3 in diagram: interprets SQL results plus web sentiment.
    data_analyst_agent = ConversableAgent("data_analyst_agent",
                                          system_message=Prompts.data_analyst_agent_prompt,
                                          llm_config=llm_config)
    data_analyst_agent.register_for_llm(name="data_analyst_agent", description="Performs data analysis on the SQL query results and web sentiment data and return useful information.")
    # data_analyst_agent.register_for_execution(name="data_analyst_agent")
    # agent 4 in diagram: emits Python plotting code to be executed locally.
    visualization_agent = ConversableAgent("visualization_agent",
                                           system_message=Prompts.visualization_agent_prompt,
                                           llm_config=llm_config)
    visualization_agent.register_for_llm(name="visualization_agent", description="Performs data visualization on the SQL query results and web sentiment data and return useful information.")
    visualization_agent.register_for_execution(name="visualization_agent")
    # Agentic Workflow: ingest the CSV into SQLite, then pull a small sample
    # to ground the EDA agent's first reply.
    csv_path = 'data/raw/carsalesdata.csv'
    db_path = 'data/sqlite_db/data.db'
    fetch_and_process_data(csv_path)
    data_sample = query_database(db_path, "SELECT * FROM transactions LIMIT 20;").to_string(index=False)
    print(f"EDA Agent will analyze {db_path} in context of {user_query}")
    eda_response = db_eda_agent.generate_reply(
        messages=[
            {"role": "user", "content": f"Analyze the database at {db_path} based on this query: {user_query}. Here is a sample of the data:\n{data_sample}\n"}
        ]
    )
    print("Agent Response:\n", eda_response)
    # Using EDA response, use RAG Agent to find metrics that could be interesting to look at
    print(f"The metric agent is now analyzing the previous EDA response")
    # metric_response = ragproxyagent.generate_reply(
    #     messages=[
    #         {"role": "user", "content": f"Analyze the response given here, and pay attention to the schema of the database: {eda_response}. Using your knowledge from documents passed in, return 7-8 metrics that would be worthwhile to look more into. Provide explanations for each one."}
    #     ]
    # )
    # Function to read all .txt files in the frameworks folder and return their contents as strings
    # def read_frameworks_as_strings(directory_path):
    #     frameworks_content = {}
    #     for filename in os.listdir(directory_path):
    #         if filename.endswith('.txt'):
    #             file_path = os.path.join(directory_path, filename)
    #             with open(file_path, 'r') as file:
    #                 content = file.read()
    #             frameworks_content[filename] = content
    #     return frameworks_content
    # # Path to the frameworks folder
    # frameworks_folder_path = 'Frameworks'
    # # Read the .txt files and store their contents
    # frameworks_data = read_frameworks_as_strings(frameworks_folder_path)
    # # Convert the contents into f-strings for easy parsing
    # frameworks_fstrings = {name: f"{content}" for name, content in frameworks_data.items()}
    # metric_response = metric_agent.generate_reply(
    #     messages=[
    #         {"role": "user", "content": f"Based on this query: {eda_response}, and these sources: {frameworks_fstrings}, analyze potential metrics that would be useful for {user_query}"}
    #     ]
    # )
    # print("Metric/RAG Response", metric_response)
    # use counter avoid inf loop
    count = 0
    while True:
        count += 1
        # Generate SQL query
        sql_query = llm.generate_sql_query(eda_response, user_query)
        print(f"SQL Query Generated: \n {sql_query}")
        sql_result = llm.execute_sql_query(sql_query, csv_path)
        # NOTE(review): if execute_sql_query never returns a DataFrame, this
        # loop spins forever — the count==3 escape hatch is only reached
        # inside the isinstance branch. Consider hoisting the count check.
        if isinstance(sql_result, pd.DataFrame):
            print(f"SQL Query executed successfully: {sql_result}")
            result_satisfactory = evaluation_agent.generate_reply(messages=[
                {"role": "user", "content": f"Does the SQL data output: {sql_result} contain enough data to help us answer the user's query: {user_query}? Be lenient :)\n"}
            ])
            # NOTE(review): the substring tests below assume generate_reply
            # returns a str; autogen may return a dict — confirm.
            if "continue" in result_satisfactory or count == 3:
                print("Data retrieved from the database is satisfactory. Advancing to analysis.")
                break
            elif "redo" in result_satisfactory:
                print("Data received from querying database is not relevant or enough to answer prompt. Regenerating a new query.")
                continue
    #Add call to web scraper agent
    # Web sentiment collection is not wired up yet; downstream agents receive
    # an empty list.
    web_sentiments = []
    #(TODO: adjust analysis/visualization agents to take correct types, assuming df and array right now.)
    #Logic for data analyst agent
    # web_summary = summarize_online_and_review_data(user_query)
    # print(web_summary)
    analysis_response = data_analyst_agent.generate_reply(
        messages=[
            {
                "role": "user",
                "content": f"Analyze the SQL query results and web sentiments based on this query: {user_query}. "
                           f"SQL Results: {sql_result.to_string(index=False)}\n"
                           f"Web Sentiments: {web_sentiments}\n"
            }
        ]
    )
    print("Data Analyst Agent Response:\n", analysis_response)
    #Logic for visualization agent
    print(f"Visualization Agent now generating code...")
    visualization_code = visualization_agent.generate_reply(
        messages=[
            {
                "role": "user",
                "content": f"Generate visualization code for the SQL query results and web sentiments based on this query: {user_query}. "
                           f"SQL Results: {sql_result.to_string(index=False)}\n"
                           f"Web Sentiments: {web_sentiments}\n"
            }
        ]
    )
    # Execute the generated visualization code
    visualizations = visualization_executor.execute_visualization_code(visualization_code, sql_result, web_sentiments)
    # Return both analysis and visualizations
    return {
        "analysis": analysis_response,
        "visualizations": visualizations
    }
if __name__ == "__main__":
    # Explicit argument check instead of `assert`, which is silently
    # stripped when Python runs with -O and would let main() crash on
    # a missing argv entry.
    if len(sys.argv) < 2:
        raise SystemExit("Please ensure you include a query for some car when executing main.")
    main(sys.argv[1])