llm_responses_utils.py
"""Per-model response helpers: query each test model (GPT, Claude, Llama-2, Mistral) behind a uniform text-in / text-out interface."""

from utils.llm_completion_utils import chatCompletion, claudeCompletion


def gpt_responses(args, text: str):
    # Wrap the prompt as a single user message and query the GPT test model
    # through the shared chatCompletion helper.
    user_message = {"role": "user", "content": text}
    messages = [user_message]
    model_output = chatCompletion(args.test_model,
                                  messages,
                                  args.temperature,
                                  args.retry_times,
                                  args.round_sleep,
                                  args.fail_sleep,
                                  args.gpt_api_key,
                                  args.gpt_base_url)
    return model_output


def claude_responses(args, text: str):
    # Query the Claude test model through the shared claudeCompletion helper.
    user_input = text
    model_output = claudeCompletion(
        args.test_model,
        args.max_tokens,
        args.temperature,
        user_input,
        args.retry_times,
        args.round_sleep,
        args.fail_sleep,
        args.claude_api_key,
        args.claude_base_url)
    return model_output


def llama2_responses(args, text: str):
    # Placeholder: Llama-2 querying is not implemented in this module.
    pass


def mistral_responses(args, model, tokenizer, text: str):
    # Query a locally loaded Mistral model via its chat template.
    user_input = [
        {"role": "user", "content": text}
    ]

    encodeds = tokenizer.apply_chat_template(user_input, return_tensors="pt")
    model_inputs = encodeds.to("cuda")
    model.to("cuda")

    generated_ids = model.generate(model_inputs,
                                   pad_token_id=tokenizer.eos_token_id,
                                   max_new_tokens=args.max_tokens,
                                   do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)

    # The assistant's reply follows the "[/INST] " marker in the decoded text;
    # strip the end-of-sequence token before returning it.
    parts = decoded[0].split("[/INST] ")
    content_after_inst = parts[1] if len(parts) > 1 else ""
    model_output = content_after_inst.replace("</s>", "")

    return model_output
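

# A minimal, illustrative usage sketch (not part of the original module): it
# assumes an `args`-style namespace carrying the model name, sampling settings,
# retry/sleep settings, and API credentials that the helpers above read. The
# attribute values below are hypothetical placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        test_model="gpt-3.5-turbo",              # hypothetical model name
        temperature=0.0,
        retry_times=3,
        round_sleep=1,
        fail_sleep=1,
        gpt_api_key="YOUR_OPENAI_API_KEY",       # placeholder credential
        gpt_base_url="https://api.openai.com/v1" # placeholder base URL
    )
    print(gpt_responses(demo_args, "Hello!"))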