-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
198 lines (169 loc) · 8.41 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import ast
import datetime
import json
import logging
import os
import re
import secrets
from urllib.parse import quote

import openai
import redis
import requests
from dotenv import load_dotenv
from flask import Flask, request, jsonify
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
# Load .env before reading ANY configuration. Previously REDIS_HOST was read
# before this call, so a REDIS_HOST defined only in .env was silently ignored.
load_dotenv()
LOGGER = logging.getLogger(__name__)
redisHost = os.environ.get("REDIS_HOST")
# Port defaults to the Redis standard 6379 but can be overridden per deployment.
redisClient = redis.Redis(host=redisHost, port=int(os.environ.get("REDIS_PORT", 6379)), db=0)
openai.api_key = os.getenv('OPENAI_API_KEY')
# Base URL of the product scraper; the URL-encoded search string is appended to it.
scraper_api = os.getenv('SCRAPER_BASE_URL')
# Base URL of the Diagon Alley backend; Route paths are appended to it.
DIAGON_ALLEY_BASE_URL = os.environ.get("DA_BASE_URL")
class Product:
    """A purchased item (name, price, color) taken from the user's order history."""

    def __init__(self, name, price, color):
        self.name, self.price, self.color = name, price, color

    @staticmethod
    def multi_product_to_string(products):
        """Render products as numbered lines for use inside a prompt.

        Each product becomes ``"Product <n>: <repr>"`` followed by a newline,
        numbered from 1. An empty iterable yields ``""``.
        """
        numbered_lines = (
            "Product {}: {}\n".format(position, item)
            for position, item in enumerate(products, start=1)
        )
        return "".join(numbered_lines)

    def __repr__(self):
        return f"Name: {self.name}, Price: {self.price}, Color: {self.color}"
class Route:
    """An upstream Diagon Alley endpoint: a path relative to the base URL plus an HTTP verb."""

    def __init__(self, route, method):
        self.route = route
        self.method = method

    def get_details(self):
        """Return ``{"url": ..., "method": ...}`` describing this endpoint.

        The URL is the module-level ``DIAGON_ALLEY_BASE_URL`` with
        ``self.route`` appended by plain string concatenation.
        """
        full_url = DIAGON_ALLEY_BASE_URL + self.route
        return dict(url=full_url, method=self.method)
class DiagonAlleyClient:
    """HTTP client for the Diagon Alley backend, acting on behalf of one user.

    The caller's ``Authorization`` header value is forwarded verbatim on
    every request.
    """

    ORDER_HISTORY = Route("/order/all", "GET")
    USER_PROFILE = Route("/auth/user/me", "GET")

    def __init__(self, bearer_token, timeout=10):
        """
        Args:
            bearer_token: Value of the incoming ``Authorization`` header,
                forwarded unchanged to Diagon Alley.
            timeout: Seconds ``requests`` waits for each upstream call.
                ``None`` waits indefinitely (the previous behavior); the
                default bounds it so a stalled backend cannot hang a worker.
        """
        self.bearer_token = bearer_token
        self.timeout = timeout

    def _request_creator(self, route: Route, body=None):
        """Perform the request described by ``route`` and return the response.

        ``body`` is sent as JSON for POST routes and ignored for GET.

        Raises:
            ValueError: if the route's method is neither GET nor POST
                (a subclass of ``Exception``, which was raised before).
        """
        details = route.get_details()
        headers = {"Authorization": self.bearer_token}
        method = details["method"]
        if method == "GET":
            return requests.get(details["url"], headers=headers, timeout=self.timeout)
        if method == "POST":
            return requests.post(details["url"], headers=headers, json=body, timeout=self.timeout)
        raise ValueError("Method not supported")

    def _get_order_history(self):
        """Return the decoded order-history JSON, or ``[]`` if the request fails.

        Order history only enriches the prompt, so a failed lookup degrades to
        "no history" instead of parsing the error body as if it were orders.
        """
        response = self._request_creator(self.ORDER_HISTORY)
        if response.status_code != 200:
            LOGGER.error("Error in getting order history")
            return []
        return response.json()

    def user_product_history(self):
        """Return every product from every past order as ``Product`` objects.

        Expects each order to carry a ``"products"`` list whose items have
        ``"name"``, ``"price"`` and ``"color"`` keys (as indexed below).
        """
        return [
            Product(item["name"], item["price"], item["color"])
            for order in self._get_order_history()
            for item in order["products"]
        ]

    def get_user_persona(self):
        """Return a short persona string, e.g. ``"<gender> of age <age>"``.

        Raises:
            requests.HTTPError: if the profile lookup returns a 4xx/5xx status.
                Without a profile there is no persona to build, so this fails
                loudly instead of with an opaque KeyError on the error body.
        """
        response = self._request_creator(self.USER_PROFILE)
        if response.status_code != 200:
            LOGGER.error("Error in getting user profile")
            response.raise_for_status()
        profile = response.json()
        return f"{profile['gender']} of age {profile['age']}"
def _build_conversation_init(persona, products_bought):
    """Return the system messages that seed a new outfit-recommendation chat.

    Args:
        persona: Short description of the user (from ``get_user_persona``).
        products_bought: List of ``Product``; history prompts are added only
            when it is non-empty.
    """
    conversation = [
        {"role": "system", "content": "You are an outfit recommender. You converse with the user, take in their suggestions and choices, ask for details, take their previous order history into account, and generate small search strings for them to search fashion websites"},
        {"role": "system", "content": "Suggest clothes for a {}".format(persona)},
    ]
    if products_bought:
        conversation.extend([
            {"role": "system", "content": "You are going to be provided with the user's previously ordered products. This will help you to understand them more"},
            {"role": "system", "content": "The user has bought the following products in the past: {}".format(Product.multi_product_to_string(products_bought))},
            {"role": "system", "content": "You can use the name, color and price to estimate the kind of user preference. You can still ask these questions to the user, but this might influence your search string"},
        ])
    conversation.extend([
        {"role": "system", "content": "You have to ask users questions to get their preferences around colour, their budget, occasion"},
        {"role": "system", "content": "Get these details from users unless they tell you that they don't have a preference and then generate a search string"},
        {"role": "system", "content": "Provide the search string only. The format of your reply should be: 'search_string = the search string'. Do not provide any other language. The search string should have the exact cloth article. For example, 'red shirt' or 'blue jeans'"},
        {"role": "system", "content": "The gender provided earlier is very important. Include it in the search string as well"},
    ])
    return conversation


@app.route("/init", methods=['GET'])
def init_conversation():
    """Start a conversation for the authenticated user and return its id.

    Builds the system prompt from the user's Diagon Alley persona and order
    history and stores it in Redis for 24 hours. Returns
    ``{"conversation_id": <int>}``, or ``{"error": ...}`` when no
    ``Authorization`` header is present.
    """
    bearer_token = request.headers.get('Authorization')
    if not bearer_token:
        return jsonify({"error": "No bearer token found"})
    diagon_alley = DiagonAlleyClient(bearer_token)
    products_bought = diagon_alley.user_product_history()
    persona = diagon_alley.get_user_persona()
    conversation_init = _build_conversation_init(persona, products_bought)
    # The prompt embeds personal data (persona, order history): debug log only, not stdout.
    LOGGER.debug("Initial conversation: %s", conversation_init)
    # JSON is also a valid Python literal for this all-string payload, so
    # readers that still eval()/literal_eval() stored entries keep working.
    payload = json.dumps(conversation_init).encode('utf-8')
    # The old id (a seconds timestamp) collided for concurrent /init calls and
    # was trivially guessable. A random id below 2**53 stays an exact JSON
    # number for JavaScript clients; nx=True guarantees we never overwrite
    # another user's conversation.
    for _ in range(5):
        conversation_id = secrets.randbelow(2 ** 53)
        if redisClient.set(conversation_id, payload, ex=86400, nx=True):
            return jsonify({"conversation_id": conversation_id})
    raise RuntimeError("Could not allocate a unique conversation id")
def _load_conversation(raw):
    """Decode a conversation stored in Redis into a list of message dicts.

    New entries are JSON. Older entries were written with ``str(list)`` and
    are parsed with ``ast.literal_eval``, which accepts only Python literals.
    This replaces the previous ``eval``, which would have executed whatever
    expression was stored under the key.
    """
    text = raw.decode('utf-8')
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return ast.literal_eval(text)


def _chat(conversation):
    """Send ``conversation`` to the chat model and return the reply text."""
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
        max_tokens=100,
    )
    return completion['choices'][0]['message']['content']


def _search_products(search_string, limit=5, timeout=60):
    """Query the scraper for ``search_string``.

    Returns up to ``limit`` results (``[]`` when the scraper reports none),
    or ``None`` if the scraper answers with a non-200 status.
    """
    cleaned = search_string.replace('"', "").replace("'", "")
    # quote() encodes spaces as %20 (as before). It also escapes '&', '#',
    # '?' and other characters that would otherwise corrupt the query URL.
    search_url = scraper_api + quote(cleaned)
    response = requests.get(search_url, timeout=timeout)
    if response.status_code != 200:
        LOGGER.error("Error in getting search results")
        return None
    search_results = response.json()
    if search_results.get("result"):
        return search_results["result"][:limit]
    return []


@app.route('/talk/<conversationID>', methods=['POST'])
def get_bot_response(conversationID):
    """Advance a stored conversation by one turn.

    The JSON body's ``"conversation"`` list of messages is appended to the
    stored history before the model is queried. If the model emits
    ``search_string = ...``, scraper results are returned as
    ``{"bot_reply_type": "search_results", ...}``. Otherwise the reply is
    persisted and returned as ``{"bot_reply_type": "text", ...}``. Failures
    are reported as ``{"error": ...}``.
    """
    if not conversationID:
        return jsonify({"error": "No conversation ID found"})
    stored = redisClient.get(conversationID)
    if not stored:
        return jsonify({"error": "Conversation not found"})
    try:
        conversation = _load_conversation(stored)
        data = request.get_json()
        # Messages the client sends for this turn.
        conversation.extend(data['conversation'])
        conversation.append(
            {"role": "system", "content": "Provide the search string only or ask further questions. Its important to create some conversation. Ask the color, occassion etc. The format of your reply should be: 'search_string = the search string'. Do not suggest the attire in the chat itself. Just give me the search string which I will then put on a shopping site. Also keep the user's gender and age in mind"},
        )
        bot_reply = _chat(conversation)
        LOGGER.info("Response created")
        # A reply that talks about the "search string" in prose (with a space)
        # instead of emitting the "search_string = ..." line gets one re-prompt.
        # (The old extra "use this search string" check was subsumed by this one.)
        if "search string" in bot_reply:
            conversation.append({"role": "system", "content": "adhere to the format!!!"})
            bot_reply = _chat(conversation)
        match = re.search(r"search_string = (.*)", bot_reply)
        if match:
            LOGGER.info("Final search term to be returned")
            results = _search_products(match.group(1))
            if results is None:
                return jsonify({"error": "Error in getting search results"})
            return jsonify({"bot_reply_type": "search_results", "search_results": results})
        LOGGER.info("Continue conversation")
        # Model output belongs to the "assistant" role. Storing it as "system"
        # turned the bot's own questions into instructions on later turns.
        conversation.append({"role": "assistant", "content": bot_reply})
        LOGGER.debug("Conversation: %s", conversation)
        redisClient.set(conversationID, json.dumps(conversation).encode('utf-8'), ex=86400)
        return jsonify({"bot_reply_type": "text", "bot_reply": bot_reply})
    except Exception as e:
        # Top-level request boundary: keep the JSON error contract but record the traceback.
        LOGGER.exception("Error while generating bot response")
        return jsonify({"error": str(e)})
if __name__ == '__main__':
    # The Werkzeug interactive debugger can run arbitrary Python from the
    # browser (guarded only by a PIN). It must never be on by default while
    # bound to 0.0.0.0, so it is opt-in via FLASK_DEBUG=1.
    app.run(
        debug=os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes"),
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "6000")),
    )