-
Notifications
You must be signed in to change notification settings - Fork 38
/
openai_model.py
83 lines (70 loc) · 2.25 KB
/
openai_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Use to get completions from an OpenAI completions endpoint. This version
of the script only works with Azure OpenAI service, since OpenAI no longer
hosts their code completion models.
"""
from typing import List
from multipl_e.completions import partial_arg_parser, make_main
import json
import openai
import openai.error
import os
import time
from typing import List
# Request target shared with completions(); main() sets exactly one of these
# from --engine / --model. They are initialized here because a module-level
# `global` statement defines nothing, which would leave the names unbound
# (NameError) if completions() ran before main().
engine = None
model = None
def completions(
    prompts: List[str], max_tokens: int, temperature: float, top_p: float, stop
) -> List[str]:
    """Return one completion string per prompt from the OpenAI Completions API.

    Prompts are sent one at a time. Each request targets the module-level
    ``engine`` when it is set, otherwise the module-level ``model``. Both are
    set by ``main()``, which requires exactly one of them.

    Args:
        prompts: Prompts to complete, in order.
        max_tokens: Maximum number of tokens to generate per prompt.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter, forwarded as ``top_p``.
        stop: Stop sequence(s), forwarded unchanged as ``stop``.

    Returns:
        The text of the first choice for each prompt, in prompt order.
    """
    results = []
    for prompt in prompts:
        kwargs = {
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stop": stop,
        }
        if engine is not None:
            kwargs["engine"] = engine
        elif model is not None:
            kwargs["model"] = model
        while True:
            try:
                response = openai.Completion.create(**kwargs)
            except openai.error.RateLimitError:
                # Retry indefinitely with a fixed 5-second back-off.
                print("Rate limited...")
                time.sleep(5)
            else:
                break
        # Read the text from the API response. Indexing the `results`
        # accumulator here raised TypeError on every call.
        results.append(response["choices"][0]["text"])
        # Fixed pause between requests, presumably to avoid rate limiting.
        time.sleep(0.5)
    return results
def main():
    """Parse command-line options, configure the OpenAI client, and run.

    Exactly one of ``--engine`` or ``--model`` must be given. The chosen
    value is stored in the module-level ``engine``/``model`` globals that
    ``completions`` reads. With ``--azure``, the client uses the ``azure``
    API type, the base URL from ``OPENAI_API_BASE``, and API version
    ``2022-12-01``. The API key always comes from ``OPENAI_API_KEY``.
    The run name is ``--name-override`` if given (and non-empty), otherwise
    the engine, otherwise the model.

    Raises:
        ValueError: if neither or both of ``--engine`` and ``--model`` are set.
    """
    global engine, model
    parser = partial_arg_parser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--engine", type=str)
    parser.add_argument("--name-override", type=str)
    parser.add_argument("--azure", action="store_true")
    args = parser.parse_args()

    has_engine = args.engine is not None
    has_model = args.model is not None
    if not (has_engine or has_model):
        raise ValueError("Must specify either engine or model.")
    if has_engine and has_model:
        raise ValueError("Must specify either engine or model, not both.")
    engine, model = args.engine, args.model

    if args.azure:
        openai.api_type = "azure"
        openai.api_base = os.getenv("OPENAI_API_BASE")
        openai.api_version = "2022-12-01"
    openai.api_key = os.getenv("OPENAI_API_KEY")

    if args.name_override:
        run_name = args.name_override
    else:
        run_name = args.engine if has_engine else args.model
    make_main(args, run_name, completions)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()