From a557982c2b4b3903620edfe6b39ac3b03157b57c Mon Sep 17 00:00:00 2001 From: Mislav Balunovic Date: Sat, 6 Jul 2024 15:13:17 +0200 Subject: [PATCH] [feat] enable function caching with decorator --- invariant/examples/{ => error_handling}/lc_example.py | 0 invariant/examples/{ => error_handling}/tool_example.py | 0 invariant/runtime/functions.py | 4 ++-- invariant/runtime/rule.py | 4 ++-- invariant/stdlib/invariant/builtins.py | 2 -- invariant/stdlib/invariant/detectors/code.py | 6 ++++-- invariant/stdlib/invariant/detectors/moderation.py | 2 ++ invariant/stdlib/invariant/detectors/pii.py | 2 ++ invariant/stdlib/invariant/detectors/prompt_injection.py | 4 +++- invariant/stdlib/invariant/detectors/secrets.py | 2 ++ 10 files changed, 17 insertions(+), 9 deletions(-) rename invariant/examples/{ => error_handling}/lc_example.py (100%) rename invariant/examples/{ => error_handling}/tool_example.py (100%) diff --git a/invariant/examples/lc_example.py b/invariant/examples/error_handling/lc_example.py similarity index 100% rename from invariant/examples/lc_example.py rename to invariant/examples/error_handling/lc_example.py diff --git a/invariant/examples/tool_example.py b/invariant/examples/error_handling/tool_example.py similarity index 100% rename from invariant/examples/tool_example.py rename to invariant/examples/error_handling/tool_example.py diff --git a/invariant/runtime/functions.py b/invariant/runtime/functions.py index fb4fa1d..1a96343 100644 --- a/invariant/runtime/functions.py +++ b/invariant/runtime/functions.py @@ -4,7 +4,7 @@ invariant agent analyzer. """ -def nocache(func): +def cache(func): """ Decorator to mark a function as non-cacheable. @@ -14,5 +14,5 @@ def nocache(func): during the evaluation of a policy rule, even for partial variable assignemnts that are not part of the final result. """ - setattr(func, "__invariant_nocache__", True) + setattr(func, "__invariant_cache__", True) return func \ No newline at end of file diff --git a/invariant/runtime/rule.py b/invariant/runtime/rule.py index 8a5f01b..4e28a00 100644 --- a/invariant/runtime/rule.py +++ b/invariant/runtime/rule.py @@ -195,8 +195,8 @@ def contains(self, function, args, kwargs): return self.call_key(function, args, kwargs) in self.cache def call(self, function, args, **kwargs): - # check if function is marked as @nocache (see ./functions.py module) - if True: # hasattr(function, "__invariant_nocache__"): + # if function is not marked with @cache we just call it directly (see ./functions.py module) + if not hasattr(function, "__invariant_cache__"): return function(*args, **kwargs) if not self.contains(function, args, kwargs): self.cache[self.call_key(function, args, kwargs)] = function(*args, **kwargs) diff --git a/invariant/stdlib/invariant/builtins.py b/invariant/stdlib/invariant/builtins.py index 02cf0b2..c183c8e 100644 --- a/invariant/stdlib/invariant/builtins.py +++ b/invariant/stdlib/invariant/builtins.py @@ -2,7 +2,6 @@ from invariant.stdlib.invariant.errors import * from invariant.stdlib.invariant.message import * from invariant.runtime.utils.base import DetectorResult -from invariant.runtime.functions import nocache import builtins as py_builtins # Utilities @@ -36,7 +35,6 @@ def sum(*args, **kwargs): # Utilities -@nocache def print(*args, **kwargs): """ Prints the given arguments just like with Python's built-in print function. diff --git a/invariant/stdlib/invariant/detectors/code.py b/invariant/stdlib/invariant/detectors/code.py index 9d1499f..6868999 100644 --- a/invariant/stdlib/invariant/detectors/code.py +++ b/invariant/stdlib/invariant/detectors/code.py @@ -1,9 +1,11 @@ from invariant.runtime.utils.code import * +from invariant.runtime.functions import cache PYTHON_ANALYZER = None CODE_SHIELD_DETECTOR = None SEMGREP_DETECTOR = None +@cache def python_code(data: str | list | dict, **config: dict) -> PythonDetectorResult: """Predicate used to extract entities from Python code.""" @@ -23,7 +25,7 @@ def python_code(data: str | list | dict, **config: dict) -> PythonDetectorResult res.extend(PYTHON_ANALYZER.detect(message.content, **config)) return res - +@cache def code_shield(data: str | list | dict, **config: dict) -> list[CodeIssue]: """Predicate used to run CodeShield on code.""" @@ -43,7 +45,7 @@ def code_shield(data: str | list | dict, **config: dict) -> list[CodeIssue]: res.extend(CODE_SHIELD_DETECTOR.detect_all(message.content, **config)) return res - +@cache def semgrep(data: str | list | dict, **config: dict) -> list[CodeIssue]: """Predicate used to run Semgrep on code.""" diff --git a/invariant/stdlib/invariant/detectors/moderation.py b/invariant/stdlib/invariant/detectors/moderation.py index 270f6ae..41c4d24 100644 --- a/invariant/stdlib/invariant/detectors/moderation.py +++ b/invariant/stdlib/invariant/detectors/moderation.py @@ -1,7 +1,9 @@ from invariant.runtime.utils.moderation import ModerationAnalyzer +from invariant.runtime.functions import cache MODERATION_ANALYZER = None +@cache def moderated(data: str | list | dict, **config: dict) -> bool: """Predicate which evaluates to true if the given data should be moderated. diff --git a/invariant/stdlib/invariant/detectors/pii.py b/invariant/stdlib/invariant/detectors/pii.py index 00246ec..7ab0175 100644 --- a/invariant/stdlib/invariant/detectors/pii.py +++ b/invariant/stdlib/invariant/detectors/pii.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from invariant.stdlib.invariant.nodes import LLM +from invariant.runtime.functions import cache PII_ANALYZER = None @@ -7,6 +8,7 @@ class PIIException(Exception): llm_call: LLM +@cache def pii(data: str | list, **config): """Predicate which detects PII in the given data. diff --git a/invariant/stdlib/invariant/detectors/prompt_injection.py b/invariant/stdlib/invariant/detectors/prompt_injection.py index 2ee89a6..5a9208d 100644 --- a/invariant/stdlib/invariant/detectors/prompt_injection.py +++ b/invariant/stdlib/invariant/detectors/prompt_injection.py @@ -1,8 +1,10 @@ from invariant.runtime.utils.prompt_injections import PromptInjectionAnalyzer, UnicodeDetector +from invariant.runtime.functions import cache PROMPT_INJECTION_ANALYZER = None UNICODE_ANALYZER = None +@cache def prompt_injection(data: str | list | dict, **config: dict) -> bool: """Predicate used for detecting prompt injections in the given data. @@ -27,7 +29,7 @@ def prompt_injection(data: str | list | dict, **config: dict) -> bool: return True return False - +@cache def unicode(data: str | list | dict, **config: dict) -> bool: """Predicate used for detecting disallowed types of unicode characters in the given data.""" assert data is not None, "cannot call unicode(...) on None" diff --git a/invariant/stdlib/invariant/detectors/secrets.py b/invariant/stdlib/invariant/detectors/secrets.py index 0b02dfc..08ceac6 100644 --- a/invariant/stdlib/invariant/detectors/secrets.py +++ b/invariant/stdlib/invariant/detectors/secrets.py @@ -1,7 +1,9 @@ from invariant.runtime.utils.secrets import SecretsAnalyzer +from invariant.runtime.functions import cache SECRETS_ANALYZER = None +@cache def secrets(data: str | list | dict, **config: dict) -> list[str]: """Predicate which evaluates to true if the given data should be moderated.