-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanswer.py
36 lines (31 loc) · 1.23 KB
/
answer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
import re
# Four-letter MBTI code: (E|I)(N|S)(T|F)(J|P). The ASCII-uppercase lookarounds
# reject codes embedded in a longer run of capitals, while still matching a
# code directly adjacent to non-Latin text such as Chinese (which `\b` would
# treat as word characters and therefore refuse to split from the code).
_MBTI_PATTERN = re.compile(r'(?<![A-Z])[EI][NS][TF][JP](?![A-Z])')


def _extract_mbti(text):
    """Return the MBTI type contained in a model prediction string.

    Prefers the first valid four-letter MBTI code found in *text*. If none is
    present, falls back to the first four uppercase characters of *text*
    joined together, which may be shorter than four characters or not a
    valid type.
    """
    match = _MBTI_PATTERN.search(text)
    if match:
        return match.group(0)
    # Fallback: legacy heuristic, kept so outputs without a clean code still
    # produce the same value as before.
    return ''.join([char for char in text if char.isupper()][:4])


def _load_predictions(predictions_file):
    """Read a JSONL predictions file and return one MBTI string per record."""
    predictions = []
    with open(predictions_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines (e.g. an extra newline at end of file).
                continue
            pred_data = json.loads(line)
            predictions.append(_extract_mbti(pred_data["predict"]))
    return predictions


def process_files(dialogue_file, predictions_file, output_file):
    """Pair each dialogue with its predicted MBTI type and save the result.

    Dialogues and predictions are matched by position: the i-th non-blank
    record of *predictions_file* belongs to the i-th element of
    *dialogue_file*.

    Args:
        dialogue_file: Path to a UTF-8 JSON file holding an array of objects;
            each object's ``'output'`` value is written as ``"Speaker"``.
        predictions_file: Path to a UTF-8 JSONL file; each record's
            ``"predict"`` string is reduced to an MBTI type. Blank lines are
            ignored.
        output_file: Path of the JSON file to write (overwritten), containing
            an array of ``{"Speaker": ..., "MBTI": ...}`` objects.

    Raises:
        ValueError: If the number of predictions differs from the number of
            dialogues, because positional pairing would then be misaligned.
            Nothing is written in that case.
    """
    with open(dialogue_file, 'r', encoding='utf-8') as file:
        dialogues = json.load(file)

    predictions = _load_predictions(predictions_file)

    # Validate before opening the output so a mismatch never leaves a
    # truncated or misaligned results file behind.
    if len(predictions) != len(dialogues):
        raise ValueError(
            f"Found {len(dialogues)} dialogues but {len(predictions)} "
            "predictions; they are paired by position and must match in count."
        )

    results = [
        {"Speaker": dialogue['output'], "MBTI": mbti_type}
        for dialogue, mbti_type in zip(dialogues, predictions)
    ]

    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(results, file, indent=4, ensure_ascii=False)
# Run the merge on the fixed input files under /root/autodl-tmp and write the
# combined results to results.json in the same directory.
dialogue_file, predictions_file, output_file = (
    '/root/autodl-tmp/dialogues.json',
    '/root/autodl-tmp/generated_predictions.jsonl',
    '/root/autodl-tmp/results.json',
)
process_files(
    dialogue_file=dialogue_file,
    predictions_file=predictions_file,
    output_file=output_file,
)