diff --git a/.gitignore b/.gitignore index 313c53c..b580cc8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ test.py # venv .venv -.DS_Store \ No newline at end of file +.DS_Store +.env.local \ No newline at end of file diff --git a/invariant/integrations/langchain_integration.py b/invariant/integrations/langchain_integration.py index 9af3d8d..46d7187 100644 --- a/invariant/integrations/langchain_integration.py +++ b/invariant/integrations/langchain_integration.py @@ -3,13 +3,14 @@ """ import json +import uuid import termcolor import pickle import contextvars import os from typing import AsyncIterator, Dict, List, Tuple, Any, Optional -from invariant import parse, Monitor +from invariant import parse, Monitor, UnhandledError from invariant.monitor import wrappers, ValidatedOperation, OperationCall, WrappingHandler, stack from invariant.stdlib.invariant.errors import UpdateMessage, UpdateMessageHandler, PolicyViolation from invariant.stdlib.invariant import ToolCall @@ -24,7 +25,7 @@ from langchain_core.agents import AgentAction, AgentFinish, AgentStep, AgentActionMessageLog from langchain_core.tools import BaseTool -def format_invariant_chat_messages(agent_input, intermediate_steps: list[AgentAction], next_step: AgentAction | AgentFinish): +def format_invariant_chat_messages(run_id: str, agent_input, intermediate_steps: list[AgentAction], next_step: AgentAction | AgentFinish): from langchain_core.messages import AIMessage, ToolMessage, FunctionMessage messages = [] @@ -45,7 +46,7 @@ def format_invariant_chat_messages(agent_input, intermediate_steps: list[AgentAc def next_id(): nonlocal msg_id msg_id += 1 - return msg_id + return run_id + "_" + str(msg_id) for step in intermediate_steps: msg = step @@ -195,21 +196,31 @@ def __str__(self): class MonitoringAgentExecutor(AgentExecutor): monitor: Monitor verbose_policy: bool = False + raise_on_violation: bool = True + run_id: str = None + + async def ainvoke(self, inputs: dict, **kwargs): + # choose UUID for this run + self.run_id = str(uuid.uuid4().hex) + return await super().ainvoke(inputs, **kwargs) + + def invoke(self, inputs: dict, **kwargs): + raise NotImplementedError("MonitoringAgentExecutor does not support synchronous execution yet. Use 'ainvoke' instead.") async def _atake_next_step(self, name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager = None): with AgentState(inputs, intermediate_steps) as state: # analysis current state - analysis_result = self.monitor.analyze(format_invariant_chat_messages(state.inputs, state.intermediate_steps, None), raise_unhandled=True) + analysis_result = self.monitor.analyze(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, None), raise_unhandled=True) # apply the handlers (make sure side-effects apply to tool_call_msg) analysis_result.execute_handlers() if len(analysis_result.handled_errors) > 0: - self.print_chat(format_invariant_chat_messages(state.inputs, state.intermediate_steps, None), heading="== POLICY APPLIED == ") + self.print_chat(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, None), heading="== POLICY APPLIED == ") result = await super()._atake_next_step(name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager) result = MutableAgentActionTuple.from_result(result) - self.print_chat(format_invariant_chat_messages(state.inputs, state.intermediate_steps, result)) + self.print_chat(format_invariant_chat_messages(self.run_id, state.inputs, state.intermediate_steps, result)) return result @@ -249,7 +260,7 @@ def update_tool_input(tool_input): # agent_action.message_log[0].additional_kwargs['function_call']['arguments'] = json.dumps(tool_input) # compute current chat state - chat = format_invariant_chat_messages(agent_state.inputs, agent_state.intermediate_steps, agent_action) + chat = format_invariant_chat_messages(self.run_id, agent_state.inputs, agent_state.intermediate_steps, agent_action) tool_call_msg = chat.pop(-1) self.print_chat(chat + [tool_call_msg]) @@ -259,7 +270,7 @@ def update_tool_input(tool_input): # apply the handlers (make sure side-effects apply to tool_call_msg) analysis_result.execute_handlers() - chat = format_invariant_chat_messages(agent_state.inputs, agent_state.intermediate_steps, agent_action) + chat = format_invariant_chat_messages(self.run_id, agent_state.inputs, agent_state.intermediate_steps, agent_action) tool_call_msg = chat.pop(-1) # actual tool call is last fct in stack @@ -289,9 +300,7 @@ async def actual_tool(tool_input: dict, **kwargs): if len(analysis_result.handled_errors) - len(wrappers(analysis_result)) > 0: self.print_chat(chat + [tool_call_msg], heading="== POLICY APPLIED == ") - result = await super(MonitoringAgentExecutor, self)._aperform_agent_action(patched_map, color_mapping, agent_action, run_manager) - - return result + return await super(MonitoringAgentExecutor, self)._aperform_agent_action(patched_map, color_mapping, agent_action, run_manager) class WrappedOneTimeTool(BaseTool): diff --git a/invariant/monitor.py b/invariant/monitor.py index 0225203..fe28e82 100644 --- a/invariant/monitor.py +++ b/invariant/monitor.py @@ -91,6 +91,10 @@ def __init__(self, policy_root: PolicyRoot, policy_parameters: dict, raise_unhan # whether to raise unhandled errors in `check()` self.raise_unhandled = raise_unhandled or policy_parameters.pop("raise_unhandled", False) + def reset(self): + """Resets the monitor to its initial state (incremental state is cleared).""" + self.rule_set = RuleSet.from_policy(self.policy_root, cached=self.cached) + @classmethod def from_file(cls, path: str, **policy_parameters): return cls(parse_file(path), policy_parameters) diff --git a/invariant/runtime/input.py b/invariant/runtime/input.py index 8c409ae..3cbb341 100644 --- a/invariant/runtime/input.py +++ b/invariant/runtime/input.py @@ -221,6 +221,11 @@ def select(self, selector, data=""): for key, value in data.__dict__.items(): result += self.select(type_name, value) return result + elif type(data) is tuple: + result = [] + for item in data: + result += self.select(type_name, item) + return result else: print("cannot sub-select type", type(data)) return [] diff --git a/invariant/runtime/rule.py b/invariant/runtime/rule.py index e2f1764..dcebb43 100644 --- a/invariant/runtime/rule.py +++ b/invariant/runtime/rule.py @@ -239,9 +239,9 @@ def instance_key(self, rule, model): model_keys = [] for k,v in model.items(): if type(v) is dict and "key" in v: - model_keys.append((k, v["key"])) + model_keys.append((k.name, v["key"])) else: - model_keys.append((k, id(v))) + model_keys.append((k.name, id(v))) return (id(rule), *(vkey for k,vkey in sorted(model_keys, key=lambda x: x[0]))) def log_apply(self, rule, model): diff --git a/invariant/runtime/utils/code.py b/invariant/runtime/utils/code.py index f8d54d1..e4c7886 100644 --- a/invariant/runtime/utils/code.py +++ b/invariant/runtime/utils/code.py @@ -32,13 +32,18 @@ class PythonDetectorResult: builtins: list[str] = Field(default_factory=list, description="List of built-in functions used.") # whether code has syntax errors syntax_error: bool = Field(default=False, description="Flag which is true if code has syntax errors.") - + # function call identifier names + function_calls: set[str] = Field(default_factory=set, description="Set of function call targets as returned by 'ast.unparse(node.func).strip()'") + def add_import(self, module: str): self.imports.append(module) def add_builtin(self, builtin: str): self.builtins.append(builtin) + def add_function_call(self, function: str): + self.function_calls.add(function) + def extend(self, other: "PythonDetectorResult"): if type(other) != PythonDetectorResult: raise ValueError("Expected PythonDetectorResult object") @@ -76,6 +81,9 @@ def visit_Import(self, node): def visit_ImportFrom(self, node): self.res.add_import(node.module) + def visit_Call(self, node): + self.res.add_function_call(ast.unparse(node.func).strip()) + class PythonCodeDetector(BaseDetector): """Detector which extracts entities from Python code. diff --git a/invariant/stdlib/invariant/parsers/html.py b/invariant/stdlib/invariant/parsers/html.py index 15ece43..7882db6 100644 --- a/invariant/stdlib/invariant/parsers/html.py +++ b/invariant/stdlib/invariant/parsers/html.py @@ -28,19 +28,20 @@ def handle_data(self, data): def parse(self, data: str) -> HiddenHTMLData: self.feed(data) - self.links = self.links.union(get_links_regex(data)) + self.links = self.links.union(HiddenDataParser.get_links_regex(data)) -def get_links_regex(data: str) -> list[str]: - """ - Extracts links from a string of HTML code. - - Returns: - - list[str]: A list of links. - """ + @staticmethod + def get_links_regex(data: str) -> list[str]: + """ + Extracts links from a string of HTML code. + + Returns: + - list[str]: A list of links. + """ - # link including path etc. - pattern = r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+" + r"(?:/[^ \n\"]+)*" - return list(set(re.findall(pattern, data))) + # link including path etc. + pattern = r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+" + r"(?:/[^ \n\"]+)*" + return list(set(re.findall(pattern, data))) def html_code(data: str | list | dict, **config: dict) -> HiddenHTMLData: @@ -57,12 +58,33 @@ def html_code(data: str | list | dict, **config: dict) -> HiddenHTMLData: for message in chat: if message is None: continue - if message["content"] is None: + if "content" in message and message["content"] is None: continue + content = message.get("content", str(message)) parser = HiddenDataParser() - parser.parse(message["content"]) + parser.parse(content) res.alt_texts.extend(parser.alt_texts) res.links.extend(list(parser.links)) + return res + +def links(data: str | list | dict, **config: dict) -> list[str]: + """ + Extracts links from a string of HTML code or text. + + Returns: + - list[str]: A list of links. + """ + + chat = data if isinstance(data, list) else ([{"content": data}] if type(data) == str else [data]) + + res = [] + for message in chat: + if message is None: + continue + if message["content"] is None: + continue + res.extend(HiddenDataParser.get_links_regex(message["content"])) + return res \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a183c35..9b54177 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev-dependencies = [ "transformers>=4.41.1", "torch>=2.3.0", "codeshield>=1.0.1", + "python-dotenv>=1.0.1", ] [project.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index 21cda5d..afe12a5 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -63,6 +63,7 @@ defusedxml==0.7.1 # via semgrep distro==1.9.0 # via openai +dotfiles==0.6.4 exceptiongroup==1.2.1 # via semgrep face==22.0.0 @@ -191,6 +192,7 @@ pydantic-core==2.18.4 pygments==2.18.0 # via rich pytest==8.2.2 +python-dotenv==1.0.1 pyyaml==6.0.1 # via codeshield # via huggingface-hub diff --git a/tests/test_html_parsing.py b/tests/test_html_parsing.py index 639bd5f..9070289 100644 --- a/tests/test_html_parsing.py +++ b/tests/test_html_parsing.py @@ -97,5 +97,27 @@ def test_links(self): errors = p.analyze(trace).errors assert len(errors) == 1, "Expected 1 error, got {}: {}".format(len(errors), errors) + def test_links_from_text(self): + p = Policy.from_string( + """ + from invariant.parsers.html import links + + raise PolicyViolation("found a non-trusted link", link=link) if: + (output: ToolOutput) + (link: str) in links(output) + # only allow links to trusted.com + not match("https://trusted.com/.+", link) + """ + ) + + trace = [ + user("Hello there"), + # injection in alt text broken html + tool("1", """# Getting started\n First, visit [this link](https://trusted.com/1), then [this one](https://untrusted.com/2)""") + ] + + errors = p.analyze(trace).errors + assert len(errors) == 1, "Expected 1 error, got {}: {}".format(len(errors), errors) + if __name__ == "__main__": unittest.main() \ No newline at end of file