# api_utils.py
import os

import torch
from openai import OpenAI, APIError

# Load constants from environment variables.
MODEL = os.getenv("OPENAI_MODEL", "gpt-4o")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=OPENAI_API_KEY)
class ChatHistory:
    """Accumulates a chat transcript and queries the OpenAI chat API."""

    def __init__(self, system_message, logger):
        self.messages = [{"role": "system", "content": system_message}]
        self.logger = logger

    def add_prompt(self, prompt):
        self.messages.append({"role": "user", "content": prompt})
        self.logger.info(f"Prompt: {prompt}")

    def add_response(self, response):
        self.messages.append({"role": "assistant", "content": response})
        self.logger.info(f"Response: {response}")

    def generate_response(self, prompt, example_num=0):
        try:
            self.add_prompt(prompt)
            completion = client.chat.completions.create(
                model=MODEL,
                messages=self.messages,
                seed=example_num,
            )
            response = completion.choices[0].message.content.strip()
            self.add_response(response)
            return response
        except APIError as e:
            self.logger.error(f"OpenAI API error: {e}")
            return None
        except Exception as e:
            self.logger.error(f"Unexpected error: {e}")
            return None
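
# Hedged usage sketch: a minimal way to exercise ChatHistory. The logger name
# and prompt below are illustrative assumptions, not part of this module, and
# OPENAI_API_KEY must be set in the environment for the API call to succeed.
def _demo_chat_history():
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("api_utils_demo")  # hypothetical logger name
    history = ChatHistory("You are a helpful assistant.", logger)
    answer = history.generate_response("Say hello in one sentence.", example_num=1)
    if answer is not None:
        print(answer)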
class Hf_ChatHistory:
    """Generates text from a local Hugging Face causal LM.

    Note: unlike ChatHistory, this class keeps no transcript; the
    system_message argument is currently unused.
    """

    def __init__(self, system_message, logger, model, tokenizer):
        # self.messages = [{"role": "system", "content": system_message}]
        self.model = model
        self.tokenizer = tokenizer
        self.logger = logger

    def generate_response(self, prompt):
        # Tokenize the input
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
        # Generate text (enable sampling and set other parameters)
        output_sequences = self.model.generate(
            input_ids=input_ids,
            do_sample=True,  # Enable sampling
            max_new_tokens=2500,  # Limit the length of the output
            temperature=0.7,  # Control randomness
            top_k=50,  # Limit vocabulary for generation
            top_p=0.95,  # Nucleus sampling
            pad_token_id=self.tokenizer.pad_token_id,  # Set pad token ID
            attention_mask=torch.ones_like(input_ids),  # All input tokens are real (no padding)
        )
        # Decode the generated output
        generated_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        return generated_text
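
# Hedged usage sketch: driving Hf_ChatHistory with a small open checkpoint.
# "Qwen/Qwen2.5-0.5B" is an illustrative choice, not one this module
# prescribes; any causal LM whose context window exceeds the hard-coded
# max_new_tokens=2500 would do. The pad-token fallback below is a common
# convention for models that ship without one.
def _demo_hf_chat_history():
    import logging

    from transformers import AutoModelForCausalLM, AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("hf_demo")  # hypothetical logger name
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as pad token
    chat = Hf_ChatHistory("You are a helpful assistant.", logger, model, tokenizer)
    print(chat.generate_response("Once upon a time"))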