Skip to content

Commit

Permalink
Merge pull request #6 from invariantlabs-ai/inputs
Browse files Browse the repository at this point in the history
Parse user inputs to one of the trace event types supported by Invariant
  • Loading branch information
mbalunovic authored Jul 6, 2024
2 parents 37047e8 + ca2f0b6 commit 6491287
Show file tree
Hide file tree
Showing 22 changed files with 316 additions and 383 deletions.
5 changes: 4 additions & 1 deletion invariant/examples/openai_agent_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def openai_agent():
"something_else": something_else,
} # only one function in this example, but you can have multiple

messages.append(response_message.to_dict())
parsed_response = response_message.to_dict()
for tc in parsed_response["tool_calls"]:
tc["function"]["arguments"] = json.loads(tc["function"]["arguments"])
messages.append(response_message)

# monitor for security violations
monitor.check(messages)
Expand Down
10 changes: 5 additions & 5 deletions invariant/integrations/langchain_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def format_invariant_chat_messages(run_id: str, agent_input, intermediate_steps:
for msg in agent_input.get("chat_history", []):
messages.append({
"role": msg["role"],
"content": msg["content"],
"content": str(msg["content"]),
})

if "input" in agent_input:
messages.append({
"role": "user",
"content": agent_input["input"],
"content": str(agent_input["input"]),
})

msg_id = 0
Expand All @@ -55,7 +55,7 @@ def next_id():
if isinstance(action, AgentActionMessageLog):
messages.append({
"role": "assistant",
"content": action.message_log[0].content if len(action.message_log) > 0 else None,
"content": str(action.message_log[0].content) if len(action.message_log) > 0 else None,
"tool_calls": [
{
"id": "1",
Expand All @@ -71,7 +71,7 @@ def next_id():
})
messages.append({
"role": "tool",
"content": observation,
"content": str(observation),
"tool_call_id": "1",
"agent_output": step,
"key": "observation_" + str(next_id())
Expand Down Expand Up @@ -127,7 +127,7 @@ def next_id():
})
messages.append({
"role": "tool",
"content": tool_output,
"content": str(tool_output),
**({"tool_call_id": tool_call.tool_call_id} if hasattr(tool_call, "tool_call_id") else {}),
"agent_output": ns,
"key": "observation_" + str(next_id())
Expand Down
9 changes: 5 additions & 4 deletions invariant/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class Policy:
Use `analyze` to apply the policy to an application state and to obtain a list of violations.
"""
policy_root: PolicyRoot
rule_set: RuleSet
cached: bool

def __init__(self, policy_root: PolicyRoot, cached=False):
"""Creates a new policy with the given policy source.
Expand Down Expand Up @@ -113,10 +116,8 @@ def add_error_to_result(self, error, analysis_result):
"""Implements how errors are added to an analysis result (e.g. as handled or non-handled errors)."""
analysis_result.errors.append(error)

def analyze(self, input: Input | dict, raise_unhandled=False, **policy_parameters):
# prepare input
if type(input) is dict or type(input) is list:
input = Input(input, copy=not self.cached)
def analyze(self, input: list[dict], raise_unhandled=False, **policy_parameters):
input = Input(input)

# prepare policy parameters
if "data" in policy_parameters:
Expand Down
14 changes: 7 additions & 7 deletions invariant/runtime/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,13 @@ def visit_MemberAccess(self, node: MemberAccess):

if hasattr(obj, node.member):
return getattr(obj, node.member)
elif type(obj) is dict:
try:
return obj[node.member]
except KeyError:
raise KeyError(f"Object {obj} has no key {node.member}")
else:
raise KeyError(f"Object {obj} has no member {node.member}")

try:
if type(obj) is str:
obj = json.loads(obj)
return obj[node.member]
except Exception:
raise KeyError(f"Object {obj} has no key {node.member}")

def visit_KeyAccess(self, node: KeyAccess):
obj = self.visit(node.expr)
Expand Down
Loading

0 comments on commit 6491287

Please sign in to comment.