Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
penguine-ip committed Jan 28, 2025
1 parent 9076f1f commit 276ffb0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
40 changes: 30 additions & 10 deletions deepeval/integrations/langchain/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,19 @@
from langsmith import utils as ls_utils
import warnings

warnings.filterwarnings("ignore", category=ls_utils.LangSmithMissingAPIKeyWarning)
warnings.filterwarnings(
"ignore", category=ls_utils.LangSmithMissingAPIKeyWarning
)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class LangChainCallbackHandler(BaseTracer):
def __init__(self, auto_eval=False, *args, **kwargs) -> None:
self.auto_eval = auto_eval
self.event_map: Dict[str, BaseTrace] = {}
self.track_params: Dict[str, Dict] = {}
self.event_map_lock = threading.Lock()
self.event_map_lock = threading.Lock()
super().__init__(*args, **kwargs)

def _start_trace(self, run: Run) -> None:
Expand All @@ -67,10 +70,19 @@ def _start_trace(self, run: Run) -> None:
self.track_params[run_id_string] = {}
self.run_map[str(run.id)] = run
with self.event_map_lock:
root_parent_id = self.event_map[parent_id_string].rootParentId if parent_id else run_id_string
event_type = self.convert_event_type_to_deepeval_trace_type(run.run_type)
root_parent_id = (
self.event_map[parent_id_string].rootParentId
if parent_id
else run_id_string
)
event_type = self.convert_event_type_to_deepeval_trace_type(
run.run_type
)
trace_instance = self.create_trace_instance(
event_type=event_type, name=run.name, parent_id=parent_id_string, root_parent_id=root_parent_id
event_type=event_type,
name=run.name,
parent_id=parent_id_string,
root_parent_id=root_parent_id,
)
# Update event_map with trace instance
self.event_map[run_id_string] = trace_instance
Expand Down Expand Up @@ -115,8 +127,12 @@ def _end_trace(self, run: Run) -> None:
track_params = self.track_params.get(run.id, {})
dict_representation = dataclass_to_dict(trace_instance)
if trace_instance.type == LangChainTraceType.CHAIN:
track_params["input"] = trace_instance.chainAttributes.input
track_params["response"] = trace_instance.chainAttributes.output
track_params["input"] = (
trace_instance.chainAttributes.input
)
track_params["response"] = (
trace_instance.chainAttributes.output
)
monitor(
event_name=trace_instance.name,
model=track_params.get("model") or "NA",
Expand All @@ -131,7 +147,7 @@ def _end_trace(self, run: Run) -> None:
parent_trace = current_trace_stack[-1]
parent_trace.traces.append(trace_instance)
trace_manager.set_trace_stack(current_trace_stack)

# Delete trace_instance from the event_map once processed
del self.event_map[str(run.id)]

Expand All @@ -140,7 +156,11 @@ def _end_trace(self, run: Run) -> None:
############################################

def create_trace_instance(
self, event_type: LangChainTraceType | str, name: str, parent_id: Optional[str], root_parent_id: str
self,
event_type: LangChainTraceType | str,
name: str,
parent_id: Optional[str],
root_parent_id: str,
) -> TraceData:
trace_kwargs = {
"traceProvider": TraceProvider.LANGCHAIN,
Expand All @@ -152,7 +172,7 @@ def create_trace_instance(
"inputPayload": None,
"outputPayload": None,
"parentId": parent_id,
"rootParentId": root_parent_id
"rootParentId": root_parent_id,
}
if event_type == LangChainTraceType.CHAIN:
trace_kwargs["chainAttributes"] = None
Expand Down
2 changes: 1 addition & 1 deletion deepeval/integrations/llama_index/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def create_trace_instance(
"inputPayload": None,
"outputPayload": None,
"parentId": None,
"rootParentId": None
"rootParentId": None,
}

if "exception" in processed_payload:
Expand Down
2 changes: 1 addition & 1 deletion deepeval/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def create_trace_instance(
"inputPayload": None,
"outputPayload": None,
"parentId": None,
"rootParentId": None
"rootParentId": None,
}
if trace_provider == TraceProvider.DEFAULT:
if trace_type == TraceType.AGENT:
Expand Down

0 comments on commit 276ffb0

Please sign in to comment.