From 276ffb0beb27e5d7302d12fe06101eb1494b39bf Mon Sep 17 00:00:00 2001 From: Jeffrey Ip Date: Tue, 28 Jan 2025 00:25:06 -0800 Subject: [PATCH] reformat --- deepeval/integrations/langchain/callback.py | 40 ++++++++++++++----- deepeval/integrations/llama_index/callback.py | 2 +- deepeval/tracing/tracer.py | 2 +- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/deepeval/integrations/langchain/callback.py b/deepeval/integrations/langchain/callback.py index 9d43e2a00..c808e2d49 100644 --- a/deepeval/integrations/langchain/callback.py +++ b/deepeval/integrations/langchain/callback.py @@ -46,16 +46,19 @@ from langsmith import utils as ls_utils import warnings -warnings.filterwarnings("ignore", category=ls_utils.LangSmithMissingAPIKeyWarning) +warnings.filterwarnings( + "ignore", category=ls_utils.LangSmithMissingAPIKeyWarning +) logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) + class LangChainCallbackHandler(BaseTracer): def __init__(self, auto_eval=False, *args, **kwargs) -> None: self.auto_eval = auto_eval self.event_map: Dict[str, BaseTrace] = {} self.track_params: Dict[str, Dict] = {} - self.event_map_lock = threading.Lock() + self.event_map_lock = threading.Lock() super().__init__(*args, **kwargs) def _start_trace(self, run: Run) -> None: @@ -67,10 +70,19 @@ def _start_trace(self, run: Run) -> None: self.track_params[run_id_string] = {} self.run_map[str(run.id)] = run with self.event_map_lock: - root_parent_id = self.event_map[parent_id_string].rootParentId if parent_id else run_id_string - event_type = self.convert_event_type_to_deepeval_trace_type(run.run_type) + root_parent_id = ( + self.event_map[parent_id_string].rootParentId + if parent_id + else run_id_string + ) + event_type = self.convert_event_type_to_deepeval_trace_type( + run.run_type + ) trace_instance = self.create_trace_instance( - event_type=event_type, name=run.name, parent_id=parent_id_string, root_parent_id=root_parent_id + event_type=event_type, + name=run.name, + parent_id=parent_id_string, + root_parent_id=root_parent_id, ) # Update event_map with trace instance self.event_map[run_id_string] = trace_instance @@ -115,8 +127,12 @@ def _end_trace(self, run: Run) -> None: track_params = self.track_params.get(run.id, {}) dict_representation = dataclass_to_dict(trace_instance) if trace_instance.type == LangChainTraceType.CHAIN: - track_params["input"] = trace_instance.chainAttributes.input - track_params["response"] = trace_instance.chainAttributes.output + track_params["input"] = ( + trace_instance.chainAttributes.input + ) + track_params["response"] = ( + trace_instance.chainAttributes.output + ) monitor( event_name=trace_instance.name, model=track_params.get("model") or "NA", @@ -131,7 +147,7 @@ def _end_trace(self, run: Run) -> None: parent_trace = current_trace_stack[-1] parent_trace.traces.append(trace_instance) trace_manager.set_trace_stack(current_trace_stack) - + # Delete trace_instance from the event_map once processed del self.event_map[str(run.id)] @@ -140,7 +156,11 @@ def _end_trace(self, run: Run) -> None: ############################################ def create_trace_instance( - self, event_type: LangChainTraceType | str, name: str, parent_id: Optional[str], root_parent_id: str + self, + event_type: LangChainTraceType | str, + name: str, + parent_id: Optional[str], + root_parent_id: str, ) -> TraceData: trace_kwargs = { "traceProvider": TraceProvider.LANGCHAIN, @@ -152,7 +172,7 @@ def create_trace_instance( "inputPayload": None, "outputPayload": None, "parentId": parent_id, - "rootParentId": root_parent_id + "rootParentId": root_parent_id, } if event_type == LangChainTraceType.CHAIN: trace_kwargs["chainAttributes"] = None diff --git a/deepeval/integrations/llama_index/callback.py b/deepeval/integrations/llama_index/callback.py index 023592e86..5389d14b2 100644 --- a/deepeval/integrations/llama_index/callback.py +++ b/deepeval/integrations/llama_index/callback.py @@ -193,7 +193,7 @@ def create_trace_instance( "inputPayload": None, "outputPayload": None, "parentId": None, - "rootParentId": None + "rootParentId": None, } if "exception" in processed_payload: diff --git a/deepeval/tracing/tracer.py b/deepeval/tracing/tracer.py index 4a4581eef..9895d08fb 100644 --- a/deepeval/tracing/tracer.py +++ b/deepeval/tracing/tracer.py @@ -487,7 +487,7 @@ def create_trace_instance( "inputPayload": None, "outputPayload": None, "parentId": None, - "rootParentId": None + "rootParentId": None, } if trace_provider == TraceProvider.DEFAULT: if trace_type == TraceType.AGENT: