-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_reka.py
76 lines (67 loc) · 2.92 KB
/
run_reka.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from reka.client import Reka
from reka import ChatMessage
import os, base64, argparse, json
from tqdm import tqdm
def video_to_url(video_path):
    """Encode a video file as a base64 ``data:`` URL.

    The whole file is read into memory in binary mode, base64-encoded, and
    prefixed with a ``data:video/mp4;base64,`` header. The MIME type is fixed
    to ``video/mp4`` whatever the file's actual format is.

    Args:
        video_path: Path of the video file to encode.

    Returns:
        str: The ``data:video/mp4;base64,<payload>`` URL.
    """
    with open(video_path, 'rb') as handle:
        raw_bytes = handle.read()
    # b64encode returns bytes; the URL must be a str.
    payload = base64.b64encode(raw_bytes).decode('utf-8')
    return f'data:video/mp4;base64,{payload}'
def get_response(client, model_version, video_path, question):
    """Ask the Reka chat API a single question about a single video.

    Sends one ``user`` message containing the video (inlined as a base64 data
    URL via ``video_to_url``) followed by the text ``question``.

    Note that the video file is re-read and re-encoded on every call.

    Args:
        client: A ``Reka`` client; only ``client.chat.create`` is used.
        model_version: Model name passed through as ``model=``.
        video_path: Path of the video file to attach.
        question: Prompt text sent after the video.

    Returns:
        ``response.responses[0].message.content`` of the API reply
        (presumably the model's text answer; verify against the Reka SDK).
    """
    user_content = [
        {"type": "video_url", "video_url": video_to_url(video_path)},
        {"type": "text", "text": question},
    ]
    user_message = ChatMessage(content=user_content, role="user")
    reply = client.chat.create(messages=[user_message], model=model_version)
    # Only the first returned response is used.
    first_choice = reply.responses[0]
    return first_choice.message.content
# Instruction appended to every question, keyed by --task_type, to steer the
# model toward the expected answer format.
answer_prompt = {
    "multi-choice": "\nPlease directly give the best option:",
    "yes_no": "\nPlease answer yes or no:",
}
# Caption matching is also answered by picking an option, so it reuses the
# multiple-choice instruction verbatim.
answer_prompt["caption_matching"] = answer_prompt["multi-choice"]
# Captioning questions already contain the "Generated Caption:" cue, so
# nothing extra is appended.
answer_prompt["captioning"] = ""
def _load_json(path):
    """Load and return the JSON document stored at ``path`` (decoded as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _save_json_atomic(obj, path):
    """Write ``obj`` to ``path`` as JSON (indent=4) without risking a truncated file.

    The JSON is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``, so an interruption mid-write leaves the previous
    contents of ``path`` intact. A half-written predictions file would
    otherwise make the resume step's ``json.load`` fail on the next run.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4)
    os.replace(tmp_path, path)


if __name__ == '__main__':
    # Query a Reka model on every question in questions/<task_type>.json and
    # store its answers in <output_path>/<model_version>/<task_type>.json.
    # Videos that already have predictions in that file are skipped, and the
    # file is checkpointed after each video, so an interrupted run can simply
    # be restarted.
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', default='videos')
    parser.add_argument('--output_path', default='predictions')
    parser.add_argument('--model_version', default='reka-flash')
    parser.add_argument('--task_type', default='multi-choice',
                        choices=['multi-choice', 'captioning', 'caption_matching', 'yes_no'])
    args = parser.parse_args()

    # Loading questions. Shape as consumed by the loop below:
    # {video_id: {dimension: [{"question": ..., "answer": ...}, ...]}}
    question_path = f"questions/{args.task_type}.json"
    input_datas = _load_json(question_path)

    output_path = f"{args.output_path}/{args.model_version}"
    os.makedirs(output_path, exist_ok=True)
    pred_file = f"{output_path}/{args.task_type}.json"

    # Loading existing predictions so finished videos are not re-queried.
    predictions = _load_json(pred_file) if os.path.isfile(pred_file) else {}

    # Setup REKA API and client. The key comes from the REKA_API_KEY
    # environment variable; if it is unset, None is passed to the client.
    reka_api_key = os.environ.get('REKA_API_KEY')
    client = Reka(api_key=reka_api_key)

    # The same suffix applies to every question of this task; look it up once.
    prompt_suffix = answer_prompt[args.task_type]

    for vid, data in tqdm(input_datas.items()):
        if vid in predictions:
            continue
        # tqdm.write prints without corrupting the progress bar (unlike print).
        tqdm.write(vid)
        video_path = os.path.join(args.video_path, f'{vid}.mp4')

        # Build this video's entry locally so a partially answered video is
        # never recorded as finished.
        video_preds = {}
        for dim, questions in data.items():
            dim_preds = []
            for question in questions:
                inp = question['question'] + prompt_suffix
                video_llm_pred = get_response(client, args.model_version, video_path, inp)
                dim_preds.append({'question': question['question'],
                                  'answer': question['answer'],
                                  'prediction': video_llm_pred})
            video_preds[dim] = dim_preds

        # Commit and checkpoint immediately: a crash or API error later in the
        # run then loses at most the video that was in progress.
        predictions[vid] = video_preds
        _save_json_atomic(predictions, pred_file)