Skip to content

Commit

Permalink
Merge branch 'main' into workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
mbalunovic committed Aug 6, 2024
2 parents 70504c2 + 47e5c10 commit ae55c0d
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 106 deletions.
4 changes: 2 additions & 2 deletions invariant/examples/code_agents/swe_agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@
"outputs": [],
"source": [
"policy = Policy.from_string(\"\"\"\n",
" from invariant.detectors import pii, code_shield, CodeIssue, secrets\n",
" from invariant.detectors import pii, semgrep, CodeIssue, secrets\n",
" \n",
" raise PolicyViolation(\"found unsafe code: \", issue) if:\n",
" (call1: ToolCall) -> (call2: ToolCall)\n",
" call1.function.name == \"edit\"\n",
" (issue: CodeIssue) in code_shield(call1.function.arguments[\"code\"])\n",
" (issue: CodeIssue) in semgrep(call1.function.arguments[\"code\"])\n",
" call2.function.name == \"python\"\n",
" \"pickle\" in issue.description\n",
" \"\"\")"
Expand Down
5 changes: 0 additions & 5 deletions invariant/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,6 @@ def find_all() -> list["Extra"]:
"spacy": ExtrasImport("spacy", "spacy", ">=3.7.5")
})

"""Extra for features that rely on the Llama `codeshield` library."""
codeshield_extra = Extra("Code Scanning with Llama CodeShield", "Enables the use of CodeShield", {
"codeshield.cs": ExtrasImport("codeshield.cs", "codeshield", ">=1.0.1")
})

"""Extra for features that rely on the `semgrep` library."""
semgrep_extra = Extra("Code Scanning with Semgrep", "Enables the use of Semgrep for code scanning", {
"semgrep": ExtrasImport("semgrep", "semgrep", ">=1.78.0")
Expand Down
67 changes: 49 additions & 18 deletions invariant/language/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from invariant.language.typing import typing
from invariant.runtime.patterns import VALUE_MATCHERS

"""
Lark EBNF grammar for the Invariant Policy Language.
"""
parser = lark.Lark(r"""
%import common.NUMBER
Expand All @@ -27,10 +30,11 @@
def_stmt: "def" func_signature INDENT NEWLINE? statement* DEDENT
decl_stmt: ( ID | func_signature ) ":=" expr | ( ID | func_signature ) INDENT NEWLINE? expr (NEWLINE expr)* DEDENT
expr: ID | binary_expr | "(" expr ("," expr)* ")" | block_expr
expr: ID | assignment_expr | "(" expr ("," expr)* ")" | block_expr
block_expr: INDENT expr (NEWLINE expr)* DEDENT
assignment_expr: ( ID ":=" binary_expr ) | binary_expr
binary_expr: cmp_expr LOGICAL_OPERATOR cmp_expr | cmp_expr
cmp_expr: ( term CMP_OPERATOR term ) | term
term: factor TERM_OPERATOR factor | factor
Expand Down Expand Up @@ -76,7 +80,7 @@
ID.2: /[a-zA-Z_]([a-zA-Z0-9_])*/
UNARY_OPERATOR.3: /not[\n\t ]/ | "-" | "+"
LOGICAL_OPERATOR: /and[\n\t ]/ | /or[\n\t ]/
CMP_OPERATOR: ":=" | "==" | "!=" | ">" | "<" | ">=" | "<=" | /is[\n\t ]/ | /contains_only[\n\t ]/ | /in[\n\t ]/ | "->"
CMP_OPERATOR: "==" | "!=" | ">" | "<" | ">=" | "<=" | /is[\n\t ]/ | /contains_only[\n\t ]/ | /in[\n\t ]/ | "->"
VALUE_TYPE: /<[a-zA-Z_:]+>/
TERM_OPERATOR: "+" | "-"
Expand All @@ -97,8 +101,8 @@ def indent_level(line, unit=1):
# count the number of leading spaces
return (len(line) - len(line.lstrip())) // unit


def derive_indentation_units(text):
# derive the indentation unit from the first non-empty line
lines = text.split("\n")
indents = set()
for line in lines:
Expand All @@ -108,8 +112,29 @@ def derive_indentation_units(text):
return 1
return min(indents)


def parse_indents(text):
"""
This function parses an intended snippet of IPL code and returns a version of the code
where indents are replaced by |INDENT| and |DEDENT| tokens.
This allows our actual language grammar above to be context-free, as it does not need to
handle indentation, but can rely on the |INDENT| and |DEDENT| tokens instead.
|INDENT| and |DEDENT| tokens fulfill the same role as e.g. `{` and `}` in C-like languages.
Example:
```
def foo:
bar
```
is transformed into:
```
def foo: |INDENT|
bar |DEDENT|
```
"""
indent_unit = derive_indentation_units(text)
lines = text.split("\n")

Expand Down Expand Up @@ -217,6 +242,11 @@ def parameter_decl(self, items):
def expr(self, items):
return items[0]

def assignment_expr(self, items):
if len(items) == 1:
return items[0]
return BinaryExpr(items[0], ":=", items[1]).with_location(self.loc(items))

def block_expr(self, items):
return self.filter(items)

Expand Down Expand Up @@ -355,22 +385,10 @@ def NUMBER(self, items):
def loc(self, items):
return Location.from_items(items, self.line_mappings, self.source_code)


def chain(ind_mappings, ml_mappings):
result = {}
for k, (line, char) in ind_mappings.items():
if line in ml_mappings:
result[k] = (ml_mappings[line], char)
else:
result[k] = (line, char)
return result


def transform(policy):
"""
Basic AST transformations to simplify the AST
Basic transformations to simplify the AST
"""

class PostParsingTransformations(Transformation):
# transforms FunctionCall with a ToolReference target into a SemanticPattern
def visit_FunctionCall(self, node: FunctionCall):
Expand All @@ -388,9 +406,22 @@ def visit_FunctionCall(self, node: FunctionCall):


def parse(text, path=None, verbose=True):
"""
Multi-stage parsing process to transform a string of IPL code into an Invariant Policy AST.
The parsing stages are as follows:
1. Indentation parsing: The input code is transformed into a version where indents are marked with |INDENT| and |DEDENT| tokens, instead of actual indentation.
2. Lark parsing: The indented code is parsed using the Lark parser as defined by the grammar above.
3. AST construction: The Lark parse tree is transformed into an AST.
4. AST post-processing: The AST is simplified and transformed.
5. Type checking: The AST is type-checked.
"""

# removes common leading indent (e.g. when parsing from an indented multiline string)
text = textwrap.dedent(text)
# creates source code hjandle
# creates source code handle
source_code = SourceCode(text, path=path, verbose=verbose)

# translates an indent-based code into code in which indented
Expand Down
19 changes: 0 additions & 19 deletions invariant/runtime/utils/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import Enum
from invariant.runtime.utils.base import BaseDetector, DetectorResult
from pydantic.dataclasses import dataclass, Field
from invariant.extras import codeshield_extra


class CodeSeverity(str, Enum):
Expand Down Expand Up @@ -133,24 +132,6 @@ def detect(self, text: str, ipython_mode=False) -> PythonDetectorResult:
return ast_visitor.res


class CodeShieldDetector(BaseDetector):
"""Detector which uses Llama CodeShield for safety (currently based on regex and semgrep rules)"""

async def scan_llm_output(self, llm_output_code):
self.CodeShield = codeshield_extra.package("codeshield.cs").import_names("CodeShield")
result = await self.CodeShield.scan_code(llm_output_code)
return result

def detect_all(self, text: str) -> list[CodeIssue]:
res = asyncio.run(self.scan_llm_output(text))
if res.issues_found is None:
return []
return [
CodeIssue(description=issue.description, severity=str(issue.severity).lower())
for issue in res.issues_found
]


class SemgrepDetector(BaseDetector):
"""Detector which uses Semgrep for safety evaluation."""

Expand Down
20 changes: 0 additions & 20 deletions invariant/stdlib/invariant/detectors/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from invariant.runtime.functions import cache

PYTHON_ANALYZER = None
CODE_SHIELD_DETECTOR = None
SEMGREP_DETECTOR = None

@cache
Expand Down Expand Up @@ -30,25 +29,6 @@ def ipython_code(data: str | list | dict, **config: dict) -> PythonDetectorResul
"""Predicate used to extract entities from IPython cell code."""
return python_code(data, ipython_mode=True, **config)

@cache
def code_shield(data: str | list | dict, **config: dict) -> list[CodeIssue]:
"""Predicate used to run CodeShield on code."""

global CODE_SHIELD_DETECTOR
if CODE_SHIELD_DETECTOR is None:
CODE_SHIELD_DETECTOR = CodeShieldDetector()

if type(data) is str:
return CODE_SHIELD_DETECTOR.detect_all(data, **config)
if type(data) is not list:
data = [data]

res = []
for message in data:
if message.content is None:
continue
res.extend(CODE_SHIELD_DETECTOR.detect_all(message.content, **config))
return res

@cache
def semgrep(data: str | list | dict, **config: dict) -> list[CodeIssue]:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dev-dependencies = [
"presidio-analyzer>=2.2.354",
"transformers>=4.41.1",
"torch>=2.3.0",
"codeshield>=1.0.1",
"python-dotenv>=1.0.1",
]

Expand Down
3 changes: 0 additions & 3 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ click-option-group==0.5.6
# via semgrep
cloudpathlib==0.18.1
# via weasel
codeshield==1.0.1
colorama==0.4.6
# via semgrep
confection==0.1.5
Expand Down Expand Up @@ -194,7 +193,6 @@ pygments==2.18.0
pytest==8.2.2
python-dotenv==1.0.1
pyyaml==6.0.1
# via codeshield
# via huggingface-hub
# via langchain
# via langchain-core
Expand Down Expand Up @@ -234,7 +232,6 @@ ruamel-yaml-clib==0.2.8
safetensors==0.4.3
# via transformers
semgrep==1.78.0
# via codeshield
# via invariant
setuptools==70.0.0
# via marisa-trie
Expand Down
2 changes: 1 addition & 1 deletion tests/test_html_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from invariant import Policy
from invariant.policy import analyze_trace
from invariant.traces import *
from invariant.extras import extras_available, presidio_extra, transformers_extra, codeshield_extra
from invariant.extras import extras_available, presidio_extra, transformers_extra
from invariant.traces import user, assistant, tool, tool_call

class TestHTMLParsing(unittest.TestCase):
Expand Down
10 changes: 9 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ def test_with_member_access_call_in_addition(self):
(to: ToolOutput)
"File " + tc.function.arguments.arg.strip() + " not found" in to.content
""")

self.assertIsInstance(policy.statements[0].body[1], ast.BinaryExpr)
self.assertIsInstance(policy.statements[0].body[1].left, ast.BinaryExpr)
self.assertIsInstance(policy.statements[0].body[1].left.left, ast.BinaryExpr)
Expand All @@ -387,5 +386,14 @@ def test_with_member_access_call_in_addition(self):
self.assertEqual(policy.statements[0].body[1].left.left.right.name.expr.expr.expr.member, "function")
self.assertIsInstance(policy.statements[0].body[1].left.left.right.name.expr.expr.expr.expr, ast.Identifier)

def test_assign_in(self):
policy = parse("""
raise "error" if:
(msg: Message)
flag := ["a", "b"] in ["a"]
""")
self.assertEqual(policy.statements[0].body[1].op, ":=")


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit ae55c0d

Please sign in to comment.