-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathllm_model.py
108 lines (88 loc) · 4.15 KB
/
llm_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
import time
import torch
import cml.metrics_v1 as metrics
import cml.models_v1 as models
# Quantization
# Weights are loaded in 4-bit "Normal Float 4" (NF4) format, so each weight in
# the model takes up 4 bits of memory; computation runs in float16
# (bnb_4bit_compute_dtype) and double quantization is switched on.
compute_dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load the causal language model with the quantization settings above;
# device_map="auto" hands device placement over to transformers/accelerate.
model_name = "NousResearch/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)
# Args helper
def opt_args_value(args, arg_name, default):
    """
    Return args[arg_name] if the request supplied it, otherwise default.

    Used by api_wrapper to read the optional inference parameters of each
    call to the model.

    args - mapping of request arguments (presumably the dict decoded from the
           API request JSON - TODO confirm against the CML caller).
    arg_name - key to look up.
    default - value returned when arg_name is absent.

    A key that is present is returned unchanged, even when its value is None.
    """
    return args.get(arg_name, default)
# Load the tokenizer that pairs with `model_name`, then reuse its
# end-of-sequence token as the padding token.
# NOTE(review): trust_remote_code=True permits running code shipped in the model
# repository, while the model load itself does not request it - confirm it is needed.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
# Make the call to the LLM
def generate(prompt, max_new_tokens=50, temperature=0, repetition_penalty=1.0, num_beams=1, top_p=1.0, top_k=0):
    """
    Make a request to the LLM with the given parameters (or their defaults) and
    return the decoded output text.

    prompt - input text passed to the tokenizer.
    max_new_tokens - upper bound on the number of generated tokens (tokens, not words).
    temperature - a.k.a. "response creativity". A value <= 0 selects deterministic
                  decoding (greedy, or beam search when num_beams > 1); a value > 0
                  enables sampling, and higher values give more random responses.
    repetition_penalty - penalizes tokens that already appear in the sequence (1.0 = no penalty).
    num_beams - number of beams for beam search (1 = no beam search).
    top_p - nucleus sampling: keep just enough of the most likely tokens for their
            combined probability to reach top_p (1.0 = keep all). Only used when sampling.
    top_k - keep only the top_k highest-probability tokens (0 = no top-k filtering).
            Only used when sampling.

    Returns the decoded text of the first output sequence with special tokens removed.
    The model is decoder-only, so that sequence starts with the prompt tokens and the
    returned text includes the prompt.
    """
    # Pick the decoding mode explicitly. Without do_sample, generate() falls back to
    # the model's generation_config: if sampling is off there, temperature/top_p/top_k
    # are silently ignored; if it is on, temperature == 0 is rejected with a ValueError.
    do_sample = temperature > 0
    sampling_args = (
        {"temperature": float(temperature), "top_p": top_p, "top_k": top_k}
        if do_sample
        else {}
    )
    # Place the input tensors on the same device as the model instead of leaving them on CPU.
    batch = tokenizer(prompt, return_tensors='pt').to(model.device)
    with torch.autocast(device_type="cuda"):
        output_tokens = model.generate(**batch,
                                       max_new_tokens=int(max_new_tokens),  # must be an integer count
                                       repetition_penalty=repetition_penalty,
                                       num_beams=num_beams,
                                       do_sample=do_sample,
                                       **sampling_args)
    output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    # Log the response along with parameters
    print("Prompt: %s" % (prompt))
    print("max_new_tokens: %s; temperature: %s; repetition_penalty: %s; num_beams: %s; top_p: %s; top_k: %s" % (max_new_tokens, temperature, repetition_penalty, num_beams, top_p, top_k))
    print("Full Response: %s" % (output))
    return output
@models.cml_model(metrics=True)
def api_wrapper(args):
    """
    Process an incoming API request and return a JSON-serialisable dict.

    args - request dictionary. "prompt" is required; the optional inference
           options and their defaults are temperature (0), max_new_tokens (50),
           top_p (1.0), top_k (0), repetition_penalty (1.0) and num_beams (1).

    Returns {"response": <text from generate()>, "response_time_s": <elapsed
    seconds rounded to 1 decimal>}. The prompt, the response and the unrounded
    elapsed time are also recorded with metrics.track_metric.

    Raises KeyError when "prompt" is missing, and ValueError/TypeError when an
    option value cannot be converted to a number.
    """
    # perf_counter is monotonic, so the measured duration is not skewed by
    # system clock adjustments the way time.time() can be.
    start = time.perf_counter()
    # Pick up args from model api
    prompt = args["prompt"]
    # Pick up or set defaults for inference options
    # TODO: More intelligent control of max_new_tokens
    temperature = float(opt_args_value(args, "temperature", 0))
    # max_new_tokens is a token count, so it must be an integer; converting via
    # float() first still accepts inputs such as "50" or 50.0.
    max_new_tokens = int(float(opt_args_value(args, "max_new_tokens", 50)))
    top_p = float(opt_args_value(args, "top_p", 1.0))
    top_k = int(opt_args_value(args, "top_k", 0))
    repetition_penalty = float(opt_args_value(args, "repetition_penalty", 1.0))
    num_beams = int(opt_args_value(args, "num_beams", 1))
    # Generate response from the LLM
    response = generate(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        top_p=top_p,
        top_k=top_k,
    )
    # Calculate elapsed time
    response_time = time.perf_counter() - start
    # Track model outputs over time
    metrics.track_metric("prompt", prompt)
    metrics.track_metric("response", response)
    metrics.track_metric("response_time_s", response_time)
    return {"response": response, "response_time_s": round(response_time, 1)}