evaluate.py

import re
from datetime import datetime

import click
import yaml

from dataset_interfaces.factory import DatasetFactory
from dataset_interfaces.interface import TestExample
from reporting.results import TestResult
from utils.files import gather_testdef_files, make_result_path, make_config_path
from utils.ui import colour_print, ask_yesno


def reconstruct_history(result: TestResult) -> list[dict[str, str | datetime]]:
    """Rebuild the conversation history from a result's task log.

    Every task-log line is expected to begin with a "Test (<timestamp>):" or
    "Agent (<timestamp>):" prefix, where <timestamp> is an ISO-formatted datetime.
    """
    start_pattern = r"^((Test|Agent) (\(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}\))?:)"
    history = list()
    for line in result.task_log:
        m = re.match(start_pattern, line)
        history.append(
            dict(
                role="user" if m.group(2) == "Test" else "assistant",
                content=line.removeprefix(m.group(0) + " "),
                timestamp=datetime.fromisoformat(m.group(3).strip("()")),
            )
        )
    return history
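
# Illustrative example (hypothetical log line; format inferred from the regex above):
#   "Test (2024-01-01 12:00:00.000000): Please remember the number 42."
# would be reconstructed as
#   {"role": "user", "content": "Please remember the number 42.",
#    "timestamp": datetime(2024, 1, 1, 12, 0)}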


def reconstruct_messages_timestamps(history: list[dict[str, str | datetime]], script: list[str]) -> list[datetime]:
    # Return the timestamps of the history messages whose content appears verbatim in the script.
    script_lines = set(script)
    return [msg["timestamp"] for msg in history if msg["content"] in script_lines]
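
# Illustrative example (hypothetical values):
#   history = [{"role": "user", "content": "Hello", "timestamp": datetime(2024, 1, 1, 12, 0)}]
#   script = ["Hello"]
#   reconstruct_messages_timestamps(history, script) -> [datetime(2024, 1, 1, 12, 0)]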


def extract_questions(example: TestExample) -> list[str]:
    # Keep only the script lines that are flagged as questions.
    return [line for line, is_q in zip(example.script, example.is_question) if is_q]
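
# Illustrative example (hypothetical script): if example.script is
#   ["Please remember the number 42.", "What number did I ask you to remember?"]
# and example.is_question is [False, True], extract_questions returns only the second line.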


@click.command("evaluate")
@click.argument("run_name", type=str)
@click.argument("agent_name", type=str)
@click.argument("dataset_name", type=str, required=False, default="*")
@click.option("-y", required=False, is_flag=True, default=False, help="Automatically assent to questions")
def main(run_name: str, agent_name: str, dataset_name: str, y: bool):
    _main(run_name, agent_name, dataset_name, y)


def _main(run_name: str, agent_name: str, dataset_name: str, y: bool):
    # Load the run configuration and the test definitions to be re-evaluated.
    examples = []
    with open(make_config_path(run_name), "rb") as file:
        yaml_configuration = yaml.safe_load(file)
    for path in gather_testdef_files(run_name, dataset_name=dataset_name):
        dataset = DatasetFactory.create_dataset_for_example(yaml_configuration, path)
        examples.append(TestExample.load(dataset, path))
    results = list()
    for example in examples:
        result_path = make_result_path(run_name, agent_name, example.dataset_name, example.example_id, 0)
        assert result_path.exists(), f"Can't re-evaluate without an existing result file: {result_path}"
        colour_print("yellow", f"Evaluating {result_path}")
        result = TestResult.from_file(result_path)
        if not example.uses_callback:
            if example.is_temporal:
                # Get question from task log instead.
                questions = [result.task_log[-2].split(":")[1]]
            else:
                questions = extract_questions(example)
            result.score, result.max_score, result.reasoning = example.evaluation_fn(
                questions,
                result.actual_responses,
                example.expected_responses,
            )
        else:
            # Callback-based tests are re-scored through their continual evaluation callback.
            callback = example.dataset_generator.continual_evaluation_callback
            scheduler = None
            result.score, result.max_score, result.reasons, deregister = callback(scheduler, example, result.full_log)
            if not deregister:
                colour_print("red", "WARNING: The following result did not deregister the callback.")
                print(result)
        results.append(result)
    colour_print("green", "All tests have been re-evaluated.")
    if not y and not ask_yesno(
        info="Please inspect the evaluations carefully.",
        question="Do you wish to overwrite the result files?",
        default_yes=False,
    ):
        return
    for result in results:
        result.save()
    colour_print("green", "Test results have been overwritten.")


if __name__ == "__main__":
    main()
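
# Example invocation (hypothetical run and agent names), assuming the module is run
# directly rather than through a parent CLI group; the dataset name defaults to "*"
# and -y skips the confirmation prompt:
#   python evaluate.py my_run my_agent "*" -y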