Skip to content

Commit

Permalink
[feat] enable function caching with decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
mbalunovic committed Jul 6, 2024
1 parent 6491287 commit a557982
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 9 deletions.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions invariant/runtime/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
invariant agent analyzer.
"""

def nocache(func):
def cache(func):
"""
Decorator to mark a function as non-cacheable.
Expand All @@ -14,5 +14,5 @@ def nocache(func):
during the evaluation of a policy rule, even for partial variable
assignemnts that are not part of the final result.
"""
setattr(func, "__invariant_nocache__", True)
setattr(func, "__invariant_cache__", True)
return func
4 changes: 2 additions & 2 deletions invariant/runtime/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def contains(self, function, args, kwargs):
return self.call_key(function, args, kwargs) in self.cache

def call(self, function, args, **kwargs):
# check if function is marked as @nocache (see ./functions.py module)
if True: # hasattr(function, "__invariant_nocache__"):
# if function is not marked with @cache we just call it directly (see ./functions.py module)
if not hasattr(function, "__invariant_cache__"):
return function(*args, **kwargs)
if not self.contains(function, args, kwargs):
self.cache[self.call_key(function, args, kwargs)] = function(*args, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions invariant/stdlib/invariant/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from invariant.stdlib.invariant.errors import *
from invariant.stdlib.invariant.message import *
from invariant.runtime.utils.base import DetectorResult
from invariant.runtime.functions import nocache
import builtins as py_builtins

# Utilities
Expand Down Expand Up @@ -36,7 +35,6 @@ def sum(*args, **kwargs):

# Utilities

@nocache
def print(*args, **kwargs):
"""
Prints the given arguments just like with Python's built-in print function.
Expand Down
6 changes: 4 additions & 2 deletions invariant/stdlib/invariant/detectors/code.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from invariant.runtime.utils.code import *
from invariant.runtime.functions import cache

PYTHON_ANALYZER = None
CODE_SHIELD_DETECTOR = None
SEMGREP_DETECTOR = None

@cache
def python_code(data: str | list | dict, **config: dict) -> PythonDetectorResult:
"""Predicate used to extract entities from Python code."""

Expand All @@ -23,7 +25,7 @@ def python_code(data: str | list | dict, **config: dict) -> PythonDetectorResult
res.extend(PYTHON_ANALYZER.detect(message.content, **config))
return res


@cache
def code_shield(data: str | list | dict, **config: dict) -> list[CodeIssue]:
"""Predicate used to run CodeShield on code."""

Expand All @@ -43,7 +45,7 @@ def code_shield(data: str | list | dict, **config: dict) -> list[CodeIssue]:
res.extend(CODE_SHIELD_DETECTOR.detect_all(message.content, **config))
return res


@cache
def semgrep(data: str | list | dict, **config: dict) -> list[CodeIssue]:
"""Predicate used to run Semgrep on code."""

Expand Down
2 changes: 2 additions & 0 deletions invariant/stdlib/invariant/detectors/moderation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from invariant.runtime.utils.moderation import ModerationAnalyzer
from invariant.runtime.functions import cache

MODERATION_ANALYZER = None

@cache
def moderated(data: str | list | dict, **config: dict) -> bool:
"""Predicate which evaluates to true if the given data should be moderated.
Expand Down
2 changes: 2 additions & 0 deletions invariant/stdlib/invariant/detectors/pii.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import dataclass
from invariant.stdlib.invariant.nodes import LLM
from invariant.runtime.functions import cache

PII_ANALYZER = None

@dataclass
class PIIException(Exception):
llm_call: LLM

@cache
def pii(data: str | list, **config):
"""Predicate which detects PII in the given data.
Expand Down
4 changes: 3 additions & 1 deletion invariant/stdlib/invariant/detectors/prompt_injection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from invariant.runtime.utils.prompt_injections import PromptInjectionAnalyzer, UnicodeDetector
from invariant.runtime.functions import cache

PROMPT_INJECTION_ANALYZER = None
UNICODE_ANALYZER = None

@cache
def prompt_injection(data: str | list | dict, **config: dict) -> bool:
"""Predicate used for detecting prompt injections in the given data.
Expand All @@ -27,7 +29,7 @@ def prompt_injection(data: str | list | dict, **config: dict) -> bool:
return True
return False


@cache
def unicode(data: str | list | dict, **config: dict) -> bool:
"""Predicate used for detecting disallowed types of unicode characters in the given data."""
assert data is not None, "cannot call unicode(...) on None"
Expand Down
2 changes: 2 additions & 0 deletions invariant/stdlib/invariant/detectors/secrets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from invariant.runtime.utils.secrets import SecretsAnalyzer
from invariant.runtime.functions import cache

SECRETS_ANALYZER = None

@cache
def secrets(data: str | list | dict, **config: dict) -> list[str]:
"""Predicate which evaluates to true if the given data should be moderated.
Expand Down

0 comments on commit a557982

Please sign in to comment.