main.py
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
from llama_cpp import Llama
from dotenv import load_dotenv
import random
import asyncio
import discord
import torch
import time
import yaml
import gc
import os
load_dotenv()
# Discord bot prerequisites
TOKEN = os.getenv("BOT_TOKEN")
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
client = discord.Client(intents=intents)
# models
llm = None
diff = None
# globals
recent_users = {}
user_cache = {}  # member lookups cached per user id (used by get_user)
# templates
system_prompts_filename = "templates/system_prompts_showcase.yaml"
with open("templates/prompt_templates.yaml", "r") as file:
prompt_templates = yaml.safe_load(file)
with open(system_prompts_filename, "r") as file:
system_prompts = yaml.safe_load(file)
prompt_template = prompt_templates["template"]
message_template = prompt_templates["message"]
# Generation parameters
stops = ["<|eot_id|>", "<|end_of_text|>"]
generation_kwargs = {
    "max_tokens": 450,
    "temperature": 0.9,
    "top_k": 50,
    "top_p": 0.9,
    "stop": stops,
}
n_ctx = 8192
max_message_tokens = n_ctx - generation_kwargs["max_tokens"] - 1
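# With n_ctx = 8192 and max_tokens = 450, this leaves
# 8192 - 450 - 1 = 7741 tokens of budget for chat history
# (construct_prompt() later tightens this once the system prompt is known).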
# Prints the time spent between consecutive checkpoints
def timing(start_time, checkpoints):
    for key in checkpoints:
        print(f"{key} - {checkpoints[key] - start_time} seconds")
        start_time = checkpoints[key]
# Frees up memory from the GPU and RAM
def free_memory():
    global llm, diff
    llm = None
    diff = None
    gc.collect()
    torch.cuda.empty_cache()
# Loads the LLM (dropping the diffusion model first, so only one model holds VRAM)
def load_llm():
    free_memory()
    global llm
    if llm is None:
        llm = Llama.from_pretrained(
            "Orenguteng/Llama-3-8B-Lexi-Uncensored-GGUF",
            filename="*Q4_K_M.gguf",
            n_gpu_layers=-1,
            n_ctx=8192,
            verbose=True,
            flash_attn=True,
        )
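# Note: Llama.from_pretrained pulls the GGUF from the Hugging Face Hub on first
# run; the "*Q4_K_M.gguf" glob selects the 4-bit Q4_K_M quantization file.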
# Loads the Stable Diffusion pipeline
def load_diff():
    free_memory()
    global diff
    if diff is None:
        diff = StableDiffusionPipeline.from_pretrained(
            "Lykon/DreamShaper",
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False,
        )
        diff.scheduler = UniPCMultistepScheduler.from_config(diff.scheduler.config)
        diff = diff.to("cuda")
# Returns the response from the model
def llm_response(prompt):
    global llm
    if llm is None:
        load_llm()
    response = llm(prompt, **generation_kwargs)
    return response["choices"][0]["text"]
# Generates an image for the prompt and saves it to image.png
def diff_response(prompt):
    global diff
    if diff is None:
        load_diff()
    # Fresh random 32-bit seed for every generation
    generator = torch.manual_seed(random.randint(0, 2 ** 32 - 1))
    image = diff(prompt, generator=generator, num_inference_steps=15).images[0]
    image.save("image.png")
    return image
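# Example: diff_response("a castle at dusk") writes the result to image.png
# and returns the PIL image (the prompt here is just an illustration).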
async def get_user(uid, message):
    # Convert uid to int
    uid = int(uid)
    # If the user is in the cache, return it
    if uid in user_cache:
        return user_cache[uid]
    # Otherwise, fetch the user and add it to the cache
    user = await message.channel.guild.fetch_member(uid)
    user_cache[uid] = user
    return user
async def parse_message_content(message):
    # ping format: <@661591351581999187> (or <@!...> for nickname mentions)
    content = message.content
    while "<@" in content:
        start = content.find("<@")
        end = content.find(">", start)  # search from the mention, not from index 0
        ping = content[start:end + 1]
        uid = ping[2:-1].lstrip("!")
        user = await get_user(uid, message)
        username = "user"
        nickname = "nick"
        if user is not None:
            username = user.name
            nickname = user.nick if user.nick is not None else username
        fping = f"@{username} ({nickname})"
        content = content[:start] + fping + content[end + 1:]
    return content
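# Example: "hi <@123>" where member 123 is named "alice" with nick "Al"
# is rewritten to "hi @alice (Al)" (names here are hypothetical).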
def count_tokens(text):
    global llm
    if llm is None:  # guard: tokenizing requires the LLM to be loaded
        load_llm()
    tokens = llm.tokenize(text.encode())
    return len(tokens)
async def handle_messages(message, history_length=500):
    channel = message.channel
    token_count = 0
    messages = []
    # channel.history yields newest messages first; collect until the token
    # budget (or a "!split" marker) is hit, then reverse into chronological order
    async for message in channel.history(limit=history_length):
        message_content = await parse_message_content(message)
        author = message.author
        if author == client.user:
            templated_message = message_template.format(user="assistant (Sparky)", user_message=message_content)
        else:
            name = f"{author.name} ({author.nick if author.nick is not None else author.name})"
            templated_message = message_template.format(user=name, user_message=message_content)
        token_count += count_tokens(templated_message)
        if token_count > max_message_tokens:
            print("token count: ", token_count)
            break
        if message.content == "!split":
            break
        messages.append(templated_message)
    messages.reverse()
    print("message count: ", len(messages))
    messages = "".join(messages)
    print("max token count: ", max_message_tokens)
    print("token count: ", token_count, "\n\n")
    return messages
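# The exact shape of message_template comes from templates/prompt_templates.yaml;
# assuming a hypothetical "{user}: {user_message}\n" template, the joined history
# would read like "alice (Al): hi\nassistant (Sparky): hello!\n".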
def construct_prompt(messages):
    global max_message_tokens, system_prompts
    # Re-read the system prompts every call so YAML edits apply without a restart
    with open(system_prompts_filename, "r") as file:
        system_prompts = yaml.safe_load(file)
    # "config" is a space-separated list of which subprompts to stitch together
    config = system_prompts["config"].split(" ")
    system_prompt = ""
    for subprompt in config:
        system_prompt += system_prompts[subprompt] + "\n"
    # Recompute the history budget, reserving a 500-token safety margin
    max_message_tokens = n_ctx - generation_kwargs["max_tokens"] - 1 - count_tokens(system_prompt) - 500
    return prompt_template.format(system_prompt=system_prompt, messages=messages)
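# Example budget (with a hypothetical 300-token system prompt):
# 8192 - 450 - 1 - 300 - 500 = 6941 tokens left for chat history.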
# Function calls are embedded in the middle of the LLM response.
# The whole call is wrapped in square brackets:
# [function_name arg1 arg2 arg3 ...]
async def handle_functions(response, message):
    # find the square brackets
    if "[" not in response or "]" not in response:
        return None
    start = response.find("[")
    end = response.find("]")
    call = response[start + 1:end]
    function = call.split(" ")[0].lower()
    args = call.split(" ")[1:]
    print("function: ", function)
    print("args: ", args)
    match function:
        case "img":
            diff_response(" ".join(args))
            await message.channel.send(file=discord.File("image.png"))
            # diff_response may have evicted the LLM; reload it for the next message
            if llm is None:
                load_llm()
        case _:
            return None
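# Example: a response like "Sure! [img a red fox in the snow]" triggers
# diff_response("a red fox in the snow") and posts image.png to the channel.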
async def text_pipeline(message):
    async with message.channel.typing():
        start_time = time.time()
        checkpoints = {}
        if llm is None:
            load_llm()
        checkpoints["load_llm"] = time.time()
        messages = await handle_messages(message)
        checkpoints["messages"] = time.time()
        prompt = construct_prompt(messages)
        print("prompt: \n", prompt)
        checkpoints["prompt"] = time.time()
        response = llm_response(prompt)
        checkpoints["response"] = time.time()
        print("response: \n\n", response)
        await message.channel.send(response)
        checkpoints["send"] = time.time()
        await handle_functions(response, message)
        checkpoints["functions"] = time.time()
        timing(start_time, checkpoints)
        return response
async def manual_image_pipeline(message):
    async with message.channel.typing():
        start_time = time.time()
        checkpoints = {}
        if diff is None:
            load_diff()
        checkpoints["load_diff"] = time.time()
        # Drop the "!img" prefix; the rest of the message is the prompt
        prompt = " ".join(message.content.split(" ")[1:])
        checkpoints["prompt"] = time.time()
        print("prompt: ", prompt)
        diff_response(prompt)
        checkpoints["response"] = time.time()
        await message.channel.send(file=discord.File("image.png"))
        checkpoints["send"] = time.time()
        # Swap the LLM back in so the next text message starts quickly
        load_llm()
        checkpoints["free and load_llm"] = time.time()
        timing(start_time, checkpoints)
# Routes a message to a pipeline based on its prefix (and channel)
def handle_prefix(message):
    prefix = message.content.split(" ")[0].lower()
    if prefix == "!split":
        return "split"
    if message.channel.name == "natural-chat" and message.author != client.user:
        return "normal"
    if prefix == "!s":
        return "normal"
    elif prefix == "!img":
        return "img"
    elif prefix == "!freemem" and message.author.name == "gapi505":
        return "freemem"
    elif prefix == "!load" and message.author.name == "gapi505":
        return "load"
    else:
        return None
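# Example: "!img a sunset over mountains" returns "img"; any non-bot message
# in the #natural-chat channel returns "normal" even without a prefix.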
# def spam_prevention(message):
#     global recent_users
#     # recent_users: dict mapping user -> (last_time, count)
#     if message.author not in recent_users:
#         recent_users[message.author] = (time.time(), 1)
#         return False
#     else:
#         time_diff = time.time() - recent_users[message.author][0]
#         if time_diff < 5:
#             recent_users[message.author] = (time.time(), recent_users[message.author][1] + 1)
#             return True
#         else:
#             recent_users[message.author] = (time.time(), 1)
#             return False
# Discord bot events
@client.event
async def on_ready():
    print(f"We have logged in as {client.user}")
@client.event
async def on_message(message):
    # if message.author == client.user:
    #     bot_message = True
    handled = handle_prefix(message)
    if handled == "normal":  # Normal text generation
        await text_pipeline(message)
    elif handled == "img":  # Manual image generation
        await manual_image_pipeline(message)
    elif handled == "freemem":  # Free memory
        free_memory()
    elif handled == "load":  # Load the model
        load_llm()


client.run(TOKEN)