Skip to content

Commit

Permalink
fix in lc + html/links parser
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeurerkellner committed Jun 14, 2024
1 parent f037cb9 commit d26aa88
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ test.py

# venv
.venv
.DS_Store
.DS_Store
.env.local
31 changes: 20 additions & 11 deletions invariant/integrations/langchain_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
"""

import json
import uuid
import termcolor
import pickle
import contextvars
import os
from typing import AsyncIterator, Dict, List, Tuple, Any, Optional

from invariant import parse, Monitor
from invariant import parse, Monitor, UnhandledError
from invariant.monitor import wrappers, ValidatedOperation, OperationCall, WrappingHandler, stack
from invariant.stdlib.invariant.errors import UpdateMessage, UpdateMessageHandler, PolicyViolation
from invariant.stdlib.invariant import ToolCall
Expand All @@ -24,7 +25,7 @@
from langchain_core.agents import AgentAction, AgentFinish, AgentStep, AgentActionMessageLog
from langchain_core.tools import BaseTool

def format_invariant_chat_messages(agent_input, intermediate_steps: list[AgentAction], next_step: AgentAction | AgentFinish):
def format_invariant_chat_messages(run_id: str, agent_input, intermediate_steps: list[AgentAction], next_step: AgentAction | AgentFinish):
from langchain_core.messages import AIMessage, ToolMessage, FunctionMessage

messages = []
Expand All @@ -45,7 +46,7 @@ def format_invariant_chat_messages(agent_input, intermediate_steps: list[AgentAc
def next_id():
nonlocal msg_id
msg_id += 1
return msg_id
return run_id + "_" + str(msg_id)

for step in intermediate_steps:
msg = step
Expand Down Expand Up @@ -195,21 +196,31 @@ def __str__(self):
class MonitoringAgentExecutor(AgentExecutor):
monitor: Monitor
verbose_policy: bool = False
raise_on_violation: bool = True
run_id: str = None

async def ainvoke(self, inputs: dict, **kwargs):
# choose UUID for this run
self.run_id = str(uuid.uuid4().hex)
return await super().ainvoke(inputs, **kwargs)

def invoke(self, inputs: dict, **kwargs):
raise NotImplementedError("MonitoringAgentExecutor does not support synchronous execution yet. Use 'ainvoke' instead.")

async def _atake_next_step(self, name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager = None):
with AgentState(inputs, intermediate_steps) as state:
# analysis current state
analysis_result = self.monitor.analyze(format_invariant_chat_messages(state.inputs, state.intermediate_steps, None), raise_unhandled=True)
analysis_result = self.monitor.analyze(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, None), raise_unhandled=True)
# apply the handlers (make sure side-effects apply to tool_call_msg)
analysis_result.execute_handlers()

if len(analysis_result.handled_errors) > 0:
self.print_chat(format_invariant_chat_messages(state.inputs, state.intermediate_steps, None), heading="== POLICY APPLIED == ")
self.print_chat(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, None), heading="== POLICY APPLIED == ")

result = await super()._atake_next_step(name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager)
result = MutableAgentActionTuple.from_result(result)

self.print_chat(format_invariant_chat_messages(state.inputs, state.intermediate_steps, result))
self.print_chat(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, result))

return result

Expand Down Expand Up @@ -249,7 +260,7 @@ def update_tool_input(tool_input):
# agent_action.message_log[0].additional_kwargs['function_call']['arguments'] = json.dumps(tool_input)

# compute current chat state
chat = format_invariant_chat_messages(agent_state.inputs, agent_state.intermediate_steps, agent_action)
chat = format_invariant_chat_messages(self.run_id, agent_state.inputs, agent_state.intermediate_steps, agent_action)
tool_call_msg = chat.pop(-1)
self.print_chat(chat + [tool_call_msg])

Expand All @@ -259,7 +270,7 @@ def update_tool_input(tool_input):
# apply the handlers (make sure side-effects apply to tool_call_msg)
analysis_result.execute_handlers()

chat = format_invariant_chat_messages(agent_state.inputs, agent_state.intermediate_steps, agent_action)
chat = format_invariant_chat_messages(self.run_id, agent_state.inputs, agent_state.intermediate_steps, agent_action)
tool_call_msg = chat.pop(-1)

# actual tool call is last fct in stack
Expand Down Expand Up @@ -289,9 +300,7 @@ async def actual_tool(tool_input: dict, **kwargs):
if len(analysis_result.handled_errors) - len(wrappers(analysis_result)) > 0:
self.print_chat(chat + [tool_call_msg], heading="== POLICY APPLIED == ")

result = await super(MonitoringAgentExecutor, self)._aperform_agent_action(patched_map, color_mapping, agent_action, run_manager)

return result
return await super(MonitoringAgentExecutor, self)._aperform_agent_action(patched_map, color_mapping, agent_action, run_manager)


class WrappedOneTimeTool(BaseTool):
Expand Down
4 changes: 4 additions & 0 deletions invariant/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def __init__(self, policy_root: PolicyRoot, policy_parameters: dict, raise_unhan
# whether to raise unhandled errors in `check()`
self.raise_unhandled = raise_unhandled or policy_parameters.pop("raise_unhandled", False)

def reset(self):
"""Resets the monitor to its initial state (incremental state is cleared)."""
self.rule_set = RuleSet.from_policy(self.policy_root, cached=self.cached)

@classmethod
def from_file(cls, path: str, **policy_parameters):
return cls(parse_file(path), policy_parameters)
Expand Down
5 changes: 5 additions & 0 deletions invariant/runtime/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ def select(self, selector, data="<root>"):
for key, value in data.__dict__.items():
result += self.select(type_name, value)
return result
elif type(data) is tuple:
result = []
for item in data:
result += self.select(type_name, item)
return result
else:
print("cannot sub-select type", type(data))
return []
Expand Down
4 changes: 2 additions & 2 deletions invariant/runtime/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ def instance_key(self, rule, model):
model_keys = []
for k,v in model.items():
if type(v) is dict and "key" in v:
model_keys.append((k, v["key"]))
model_keys.append((k.name, v["key"]))
else:
model_keys.append((k, id(v)))
model_keys.append((k.name, id(v)))
return (id(rule), *(vkey for k,vkey in sorted(model_keys, key=lambda x: x[0])))

def log_apply(self, rule, model):
Expand Down
10 changes: 9 additions & 1 deletion invariant/runtime/utils/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,18 @@ class PythonDetectorResult:
builtins: list[str] = Field(default_factory=list, description="List of built-in functions used.")
# whether code has syntax errors
syntax_error: bool = Field(default=False, description="Flag which is true if code has syntax errors.")

# function call identifier names
function_calls: set[str] = Field(default_factory=set, description="Set of function call targets as returned by 'ast.unparse(node.func).strip()'")

def add_import(self, module: str):
self.imports.append(module)

def add_builtin(self, builtin: str):
self.builtins.append(builtin)

def add_function_call(self, function: str):
self.function_calls.add(function)

def extend(self, other: "PythonDetectorResult"):
if type(other) != PythonDetectorResult:
raise ValueError("Expected PythonDetectorResult object")
Expand Down Expand Up @@ -76,6 +81,9 @@ def visit_Import(self, node):
def visit_ImportFrom(self, node):
self.res.add_import(node.module)

def visit_Call(self, node):
self.res.add_function_call(ast.unparse(node.func).strip())

class PythonCodeDetector(BaseDetector):
"""Detector which extracts entities from Python code.
Expand Down
48 changes: 35 additions & 13 deletions invariant/stdlib/invariant/parsers/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,20 @@ def handle_data(self, data):

def parse(self, data: str) -> HiddenHTMLData:
self.feed(data)
self.links = self.links.union(get_links_regex(data))
self.links = self.links.union(HiddenDataParser.get_links_regex(data))

def get_links_regex(data: str) -> list[str]:
"""
Extracts links from a string of HTML code.
Returns:
- list[str]: A list of links.
"""
@staticmethod
def get_links_regex(data: str) -> list[str]:
"""
Extracts links from a string of HTML code.
Returns:
- list[str]: A list of links.
"""

# link including path etc.
pattern = r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+" + r"(?:/[^ \n\"]+)*"
return list(set(re.findall(pattern, data)))
# link including path etc.
pattern = r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+" + r"(?:/[^ \n\"]+)*"
return list(set(re.findall(pattern, data)))


def html_code(data: str | list | dict, **config: dict) -> HiddenHTMLData:
Expand All @@ -57,12 +58,33 @@ def html_code(data: str | list | dict, **config: dict) -> HiddenHTMLData:
for message in chat:
if message is None:
continue
if message["content"] is None:
if "content" in message and message["content"] is None:
continue
content = message.get("content", str(message))
parser = HiddenDataParser()
parser.parse(message["content"])
parser.parse(content)

res.alt_texts.extend(parser.alt_texts)
res.links.extend(list(parser.links))

return res

def links(data: str | list | dict, **config: dict) -> list[str]:
"""
Extracts links from a string of HTML code or text.
Returns:
- list[str]: A list of links.
"""

chat = data if isinstance(data, list) else ([{"content": data}] if type(data) == str else [data])

res = []
for message in chat:
if message is None:
continue
if message["content"] is None:
continue
res.extend(HiddenDataParser.get_links_regex(message["content"]))

return res
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dev-dependencies = [
"transformers>=4.41.1",
"torch>=2.3.0",
"codeshield>=1.0.1",
"python-dotenv>=1.0.1",
]

[project.scripts]
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ defusedxml==0.7.1
# via semgrep
distro==1.9.0
# via openai
dotfiles==0.6.4
exceptiongroup==1.2.1
# via semgrep
face==22.0.0
Expand Down Expand Up @@ -191,6 +192,7 @@ pydantic-core==2.18.4
pygments==2.18.0
# via rich
pytest==8.2.2
python-dotenv==1.0.1
pyyaml==6.0.1
# via codeshield
# via huggingface-hub
Expand Down
22 changes: 22 additions & 0 deletions tests/test_html_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,27 @@ def test_links(self):
errors = p.analyze(trace).errors
assert len(errors) == 1, "Expected 1 error, got {}: {}".format(len(errors), errors)

def test_links_from_text(self):
p = Policy.from_string(
"""
from invariant.parsers.html import links
raise PolicyViolation("found a non-trusted link", link=link) if:
(output: ToolOutput)
(link: str) in links(output)
# only allow links to trusted.com
not match("https://trusted.com/.+", link)
"""
)

trace = [
user("Hello there"),
# injection in alt text broken html
tool("1", """# Getting started\n First, visit [this link](https://trusted.com/1), then [this one](https://untrusted.com/2)""")
]

errors = p.analyze(trace).errors
assert len(errors) == 1, "Expected 1 error, got {}: {}".format(len(errors), errors)

if __name__ == "__main__":
unittest.main()

0 comments on commit d26aa88

Please sign in to comment.