Skip to content

Commit

Permalink
update: meta_evaluation module
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Jan 2, 2025
1 parent 14a92df commit 0e6261e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,5 @@ artifacts/
evaluation_results/
.cache_dir/
checkpoints/
.DS_Store
.DS_Store
**.json
3 changes: 3 additions & 0 deletions safeguards/meta_evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .evaluation_classification import EvaluationClassifier

__all__ = ["EvaluationClassifier"]
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from typing import Optional

import weave
from rich.progress import track

from .trace_utils import serialize_input_output_objects


class EvaluationTraceParser:
class EvaluationClassifier:

def __init__(self, project: str, call_id: str) -> None:
self.base_call = weave.init(project).get_call(call_id=call_id)
Expand All @@ -13,17 +16,28 @@ def __init__(self, project: str, call_id: str) -> None:
def _get_call_name_from_op_name(self, op_name: str) -> str:
return op_name.split("/")[-1].split(":")[0]

def register_predict_and_score_calls(self):
def register_predict_and_score_calls(
self,
max_predict_and_score_calls: Optional[int] = None,
save_filepath: Optional[str] = None,
):
count_traces_parsed = 0
for predict_and_score_call in track(
self.base_call.children(), description="Parsing predict and score calls"
self.base_call.children(),
description="Parsing predict and score calls",
total=max_predict_and_score_calls - 1,
):
if "Evaluation.summarize" in predict_and_score_call._op_name:
break
elif "Evaluation.predict_and_score" in predict_and_score_call._op_name:
self.predict_and_score_calls.append(
self.parse_call(predict_and_score_call)
)
break
count_traces_parsed += 1
if count_traces_parsed == max_predict_and_score_calls:
break
if len(self.predict_and_score_calls) > 0 and save_filepath is not None:
self.save_calls(save_filepath)

def parse_call(self, child_call) -> dict:
call_dict = {
Expand All @@ -33,3 +47,7 @@ def parse_call(self, child_call) -> dict:
"child_calls": [self.parse_call(child) for child in child_call.children()],
}
return call_dict

def save_calls(self, filepath: str):
with open(filepath, "w") as file:
json.dump(self.predict_and_score_calls, file, indent=4)

0 comments on commit 0e6261e

Please sign in to comment.