Skip to content

Commit

Permalink
Merge pull request #9 from invariantlabs-ai/workspace
Browse files Browse the repository at this point in the history
[Feat]: Functionality for sensitivity of files in the workspace
  • Loading branch information
mbalunovic authored Aug 6, 2024
2 parents 47e5c10 + ae55c0d commit 2eb95a2
Show file tree
Hide file tree
Showing 12 changed files with 326 additions and 17 deletions.
4 changes: 0 additions & 4 deletions invariant/language/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,6 @@ def raise_stmt(self, items):
body = items[1:]
# filter hidden body tokens
body = self.filter(body)
# flatten exprs
while len(body) == 1:
body = body[0]

return RaisePolicy(items[0], body).with_location(self.loc(items))

def def_stmt(self, items):
Expand Down
2 changes: 1 addition & 1 deletion invariant/language/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
IPL_BUILTINS = [
"LLM", "Message", "ToolCall", "Function", "ToolOutput",
"PolicyViolation", "UpdateMessage", "UpdateMessageHandler",
"any",
"any", "empty",
"match", "len",
"min", "max", "sum",
"print"
Expand Down
10 changes: 5 additions & 5 deletions invariant/runtime/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
import inspect
import json
import warnings
import textwrap
import termcolor
from collections.abc import KeysView, ValuesView, ItemsView
from copy import deepcopy
from typing import Optional
from invariant.stdlib.invariant.nodes import Message, ToolCall, ToolOutput, Event
from rich.pretty import pprint as rich_print
#from rich import print as rich_print

import invariant.language.types as types

Expand Down Expand Up @@ -108,11 +106,13 @@ def merge(self, lists):
return [item for sublist in lists for item in sublist]

def select(self, selector, data="<root>"):
if data == "<root>":
data = self.data
if self.should_ignore(data):
return []
if isinstance(data, (KeysView, ValuesView, ItemsView)):
return self.select(selector, list(data))
type_name = self.type_name(selector)
if data == "<root>":
data = self.data

if type(data).__name__ == type_name:
return [data]
Expand Down
53 changes: 53 additions & 0 deletions invariant/runtime/utils/copyright/copyright.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from invariant.runtime.utils.base import BaseDetector, DetectorResult
from invariant.runtime.utils.copyright.software_licenses import *

# TODO: Maybe want to use more sophisticated approach like https://github.com/licensee/licensee at some point

# Maps a license identifier to the license header text it is recognised by.
# The identifier is what CopyrightAnalyzer.detect_software_licenses passes as
# the first argument of the DetectorResult it reports.
# NOTE(review): software_licenses.py also defines BSL_LICENSE, which is not
# registered here — confirm whether that omission is intentional.
SOFTWARE_LICENSES = {
    "GNU_AGPL_V3": GNU_AGPL_V3,
    "GNU_GPL_V2": GNU_GPL_V2,
    "GNU_LGPL_V3": GNU_LGPL_V3,
    "MOZILLA_PUBLIC_LICENSE_2.0": MOZILLA_PUBLIC_LICENSE_2_0,
    "APACHE_LICENSE_2.0": APACHE_LICENSE_2_0,
    "MIT_LICENSE": MIT_LICENSE,
    "BOOST_SOFTWARE_LICENSE": BOOST_SOFTWARE_LICENSE,
}

# Literal, case-sensitive substrings searched for by
# CopyrightAnalyzer.detect_copyright_patterns.
COPYRIGHT_PATTERNS = [
    "Copyright (C)",
    "Copyright ©",
]

class CopyrightAnalyzer(BaseDetector):
    """Detector for software-license headers and copyright notices in text."""

    def detect_software_licenses(self, text: str, threshold: float = 0.5) -> list[DetectorResult]:
        """Detect known software licenses in ``text``.

        Two strategies are applied in order:

        1. If the stripped text starts with the stripped header of a license in
           ``SOFTWARE_LICENSES``, that single license is reported immediately.
        2. Otherwise every license whose header tokens appear in the text with
           a ratio of at least ``threshold`` is reported.

        Args:
            text: The text to analyse.
            threshold: Minimum fraction (0.0-1.0) of a license header's tokens
                that must occur in ``text`` for the heuristic to report it.

        Returns:
            One ``DetectorResult`` per detected license, built from the license
            identifier, ``0`` and the length of the (unstripped) license text.
        """
        # First check if text starts with the license string
        stripped_text = text.strip()
        for license_name, license_text in SOFTWARE_LICENSES.items():
            if stripped_text.startswith(license_text.strip()):
                return [DetectorResult(license_name, 0, len(license_text))]

        # Next, use a heuristic that checks how many tokens of the license text
        # are in the given text. Tokens are split on any whitespace run so that
        # line breaks and indentation do not glue neighbouring words together
        # (e.g. "LICENSE\nVersion") and the result does not depend on layout.
        res = []
        text_tokens = set(text.split())
        for license_name, license_text in SOFTWARE_LICENSES.items():
            tokens = license_text.split()
            in_ratio = sum(token in text_tokens for token in tokens) / float(len(tokens))
            if in_ratio >= threshold:
                res.append(DetectorResult(license_name, 0, len(license_text)))
        return res

    def detect_copyright_patterns(self, text: str, threshold: float = 0.5) -> list[DetectorResult]:
        """Detect copyright notices in ``text``.

        Args:
            text: The text to analyse.
            threshold: Unused here; accepted so the signature matches the other
                ``detect_*`` methods.

        Returns:
            One ``"COPYRIGHT"`` result, with start and end offsets, for the
            first occurrence of each entry of ``COPYRIGHT_PATTERNS`` found.
        """
        res = []
        for pattern in COPYRIGHT_PATTERNS:
            pos = text.find(pattern)
            if pos != -1:
                res.append(DetectorResult("COPYRIGHT", pos, pos + len(pattern)))
        return res

    def detect_all(self, text: str, threshold: float = 0.5) -> list[DetectorResult]:
        """Run every detector of this class on ``text``.

        Args:
            text: The text to analyse.
            threshold: Forwarded unchanged to each detector.

        Returns:
            Software-license results followed by copyright-pattern results.
        """
        res = []
        res.extend(self.detect_software_licenses(text, threshold))
        res.extend(self.detect_copyright_patterns(text, threshold))
        return res

36 changes: 36 additions & 0 deletions invariant/runtime/utils/copyright/software_licenses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Header snippets of well-known licenses. Each constant holds only the opening
# lines of the license, not its full text.

# GNU Affero General Public License, version 3.
GNU_AGPL_V3 = """
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
"""

# NOTE(review): the constant is named V2, but the text below is the
# "Version 3, 29 June 2007" header — confirm which of the two is intended.
GNU_GPL_V2 = """
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
"""

# GNU Lesser General Public License, version 3.
GNU_LGPL_V3 = """
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
"""

# Mozilla Public License, version 2.0.
MOZILLA_PUBLIC_LICENSE_2_0 = """
Mozilla Public License Version 2.0
"""

# Apache License, version 2.0.
APACHE_LICENSE_2_0 = """
Apache License
Version 2.0, January 2004
"""

# MIT License.
MIT_LICENSE = """
MIT License
"""

# Boost Software License, version 1.0.
BOOST_SOFTWARE_LICENSE = """
Boost Software License - Version 1.0 - August 17th, 2003
"""

# Business Source License notice.
BSL_LICENSE = """
License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved.
“Business Source License” is a trademark of MariaDB Corporation Ab.
"""
6 changes: 3 additions & 3 deletions invariant/stdlib/invariant/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# Utilities

def any(iterable):
    # Policy-language replacement for the Python builtin of the same name.
    # As listed: a non-empty list whose first element is a DetectorResult is
    # treated as True right away; everything else is delegated to
    # py_builtins.any.
    # NOTE(review): this listing comes from a rendered diff without +/-
    # markers, so the DetectorResult special case may consist of removed
    # lines — confirm against the checked-in file.
    if isinstance(iterable, list) and len(iterable) > 0:
        if isinstance(iterable[0], DetectorResult):
            return True
    return py_builtins.any(iterable)

def empty(iterable) -> bool:
    """Tell whether ``iterable`` contains no elements.

    The argument must support ``len()``; the result is ``True`` exactly when
    that length is zero.
    """
    size = len(iterable)
    return size == 0

# String operations

Expand Down
1 change: 1 addition & 0 deletions invariant/stdlib/invariant/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from invariant.stdlib.invariant.detectors.secrets import *
from invariant.stdlib.invariant.detectors.code import *
from invariant.stdlib.invariant.detectors.pii import *
from invariant.stdlib.invariant.detectors.copyright import *

30 changes: 30 additions & 0 deletions invariant/stdlib/invariant/detectors/copyright.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from invariant.runtime.functions import cache

COPYRIGHT_ANALYZER = None

@cache
def copyright(data: str | list, **config) -> list[str]:
    """Predicate which detects copyrighted material in the given data.

    Runs ``CopyrightAnalyzer.detect_all`` (software-license headers and
    copyright notices) and returns the detections converted by the analyzer's
    ``get_entities``.

    Supported data types:
        - str: A single piece of text.
        - list: A list of message objects; each non-``None`` ``content``
          attribute is analysed.
        - any other object: Treated as a single message.

    Args:
        data: The text or message(s) to analyse.
        **config: Accepted for signature parity with the other detectors;
            not used by this function.

    Returns:
        The entities detected across all analysed content.
    """
    global COPYRIGHT_ANALYZER
    if COPYRIGHT_ANALYZER is None:
        # The analyzer is imported and constructed on first use and then
        # shared by every later call.
        from invariant.runtime.utils.copyright.copyright import CopyrightAnalyzer
        COPYRIGHT_ANALYZER = CopyrightAnalyzer()

    if isinstance(data, str):
        return COPYRIGHT_ANALYZER.get_entities(COPYRIGHT_ANALYZER.detect_all(data))
    if not isinstance(data, list):
        data = [data]

    all_copyright = []
    for message in data:
        # Messages without content have nothing to scan.
        if message.content is None:
            continue
        res = COPYRIGHT_ANALYZER.detect_all(message.content)
        all_copyright.extend(COPYRIGHT_ANALYZER.get_entities(res))
    return all_copyright
2 changes: 1 addition & 1 deletion invariant/stdlib/invariant/detectors/pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class PIIException(Exception):
llm_call: LLM

@cache
def pii(data: str | list, **config):
def pii(data: str | list, **config) -> list[str]:
"""Predicate which detects PII in the given data.
Returns the list of PII detected in the data.
Expand Down
94 changes: 94 additions & 0 deletions invariant/stdlib/invariant/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import re
from invariant.stdlib.invariant.errors import PolicyViolation
from pathlib import Path
from pydantic.dataclasses import dataclass
from typing import Optional, Callable

@dataclass
class File:
    """A file read from the agent workspace, held as a path and its text."""
    # Path of the file, as a string (get_file_content stores the joined path).
    path: str
    # Full text content of the file.
    content: str


def filter_path(path: Path, pattern: Optional[str]) -> bool:
    """Return whether ``path`` passes the optional pattern filter.

    Args:
        path: The path to test.
        pattern: A glob-style pattern as understood by ``Path.match``
            (for example ``"*.docx"``), or ``None`` to accept every path.

    Returns:
        ``True`` if ``pattern`` is ``None`` or ``path`` matches it.
    """
    return pattern is None or path.match(pattern)


def join_paths(workspace_path: str, path: str) -> Path:
    """Join ``path`` onto ``workspace_path`` and validate the result.

    Args:
        workspace_path: The root directory of the agent workspace.
        path: A path interpreted relative to the workspace.

    Returns:
        The joined (unresolved) path.

    Raises:
        FileNotFoundError: If the joined path does not exist, or if it points
            outside the workspace once ``..`` components and symbolic links
            are resolved.
    """
    workspace = Path(workspace_path)
    joined_path = workspace / Path(path)
    # Containment must be checked on resolved paths: a purely lexical check
    # accepts "../outside" because the joined path still *starts with* the
    # workspace components.
    inside_workspace = joined_path.resolve().is_relative_to(workspace.resolve())
    if (not inside_workspace) or (not joined_path.exists()):
        raise FileNotFoundError("Path does not exist or is not inside the workspace.")
    return joined_path


def get_files(workspace_path: str, path: str = ".", pattern: Optional[str] = None) -> list[Path]:
    """Return the files located directly inside a workspace directory.

    Sub-directories are not descended into.

    Args:
        workspace_path: The path to the agent workspace.
        path: The directory to list, relative to the workspace.
        pattern: Optional glob-style pattern (see ``Path.match``); when given,
            only matching files are returned.

    Returns:
        The matching files as ``Path`` objects prefixed with the workspace path.

    Raises:
        FileNotFoundError: If ``path`` does not exist or is outside the workspace.
    """
    directory = join_paths(workspace_path, path)
    return [file for file in directory.iterdir() if file.is_file() and filter_path(file, pattern)]


def get_tree_files(workspace_path: str, path: str = ".", pattern: Optional[str] = None) -> list[Path]:
    """Return the files in the whole directory tree below a workspace directory.

    Args:
        workspace_path: The path to the agent workspace.
        path: The directory to search, relative to the workspace.
        pattern: Optional glob-style pattern (see ``Path.match``); when given,
            only matching files are returned.

    Returns:
        The matching files as ``Path`` objects prefixed with the workspace path.

    Raises:
        FileNotFoundError: If ``path`` does not exist or is outside the workspace.
    """
    directory = join_paths(workspace_path, path)
    return [file for file in directory.glob("**/*") if file.is_file() and filter_path(file, pattern)]


def get_file_content(workspace_path: str, file_path: str) -> File:
    """Read one workspace file and wrap it in a ``File``.

    The location is validated through ``join_paths``, so a missing file or one
    outside the workspace raises ``FileNotFoundError``. The returned ``File``
    carries the joined path as a string together with the file's text.
    """
    target = join_paths(workspace_path, file_path)
    with open(target, "r") as handle:
        text = handle.read()
    return File(str(target), text)


def get_file_contents(workspace_path: str, path: str = ".", pattern: Optional[str] = None, tree: bool = True) -> list[File]:
    """Returns the content of all files in the given path in the agent workspace.

    Args:
        workspace_path: The path to the agent workspace.
        path: The path to the directory to search for files.
        pattern: A glob-style pattern (see ``Path.match``) to filter the files;
            ``None`` selects every file.
        tree: If True, search the whole directory tree below ``path``;
            otherwise only the files directly inside it.

    Returns:
        One ``File`` per selected file.

    Raises:
        FileNotFoundError: If ``path`` does not exist or is outside the workspace.
    """
    # Forward the pattern so the filter is actually applied.
    if tree:
        files = get_tree_files(workspace_path, path, pattern)
    else:
        files = get_files(workspace_path, path, pattern)
    # The listing helpers return paths already prefixed with the workspace,
    # while get_file_content joins its argument onto the workspace again.
    # Strip the prefix first so a relative workspace is not duplicated
    # ("ws/ws/file.txt").
    return [get_file_content(workspace_path, file.relative_to(workspace_path)) for file in files]


def is_sensitive(file: File, func: Callable[[str], bool | list]) -> bool:
    """Returns True if the file content is sensitive according to the given function.

    Args:
        file: The file to check for content sensitivity.
        func: The function that determines sensitivity. It is called with the
            file content and must return either a bool or a list of sensitive
            results; a non-empty list counts as sensitive.

    Returns:
        The bool returned by ``func``, or whether the returned list is non-empty.

    Raises:
        ValueError: If ``func`` returns neither a bool nor a list.
    """
    res = func(file.content)
    if isinstance(res, bool):
        return res
    if isinstance(res, list):
        return len(res) > 0
    raise ValueError("The sensitivity filter function must return bool or list, found: " + str(type(res)))


def is_sensitive_dir(workspace_path: str,
                     funcs: list[Callable[[str], bool | list]],
                     path: str = ".",
                     pattern: Optional[str] = None,
                     tree: bool = True) -> bool:
    """Report whether a workspace directory holds at least one sensitive file.

    The files are loaded with ``get_file_contents`` and each one is checked
    with every function in ``funcs`` through ``is_sensitive``; evaluation
    stops at the first positive check.

    Args:
        workspace_path: The path to the agent workspace.
        funcs: Sensitivity functions; each takes the file content and returns
            a bool or a list of sensitive results.
        path: The directory inside the workspace whose files are checked.
        pattern: File filter pattern, forwarded to ``get_file_contents``.
        tree: Forwarded to ``get_file_contents``; if True the whole directory
            tree is searched.

    Returns:
        True if any function flags any file, False otherwise.
    """
    candidates = get_file_contents(workspace_path, path, pattern, tree)
    return any(is_sensitive(candidate, check) for candidate in candidates for check in funcs)
67 changes: 64 additions & 3 deletions tests/test_stdlib_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest
import json
from invariant import Policy, RuleSet
import tempfile
from invariant import Policy

class TestStdlibFunctions(unittest.TestCase):

def test_simple(self):
policy = Policy.from_string(
"""
Expand All @@ -25,5 +26,65 @@ def test_simple(self):
analysis_result = policy.analyze(input)
assert len(analysis_result.errors) == 1, "Expected one error, but got: " + str(analysis_result.errors)


class TestFiles(unittest.TestCase):
    """Tests for the workspace file helpers importable as ``invariant.files``."""

    def test_sensitive_types(self):
        # A policy using get_tree_files with a "*.docx" pattern is expected to
        # report one error when the workspace contains a .docx file.
        with tempfile.TemporaryDirectory() as temp_dir:
            with open(temp_dir + "/file1.docx", "w") as f:
                f.write("test")
            policy = Policy.from_string(
"""
from invariant.files import get_tree_files
raise "error" if:
not empty(get_tree_files(input.workspace, pattern="*.docx"))
""")
            res = policy.analyze([], workspace=temp_dir)
            self.assertEqual(len(res.errors), 1)


    def test_sensitive_contents(self):
        # The workspace holds two files: one containing an e-mail address and
        # one containing the text "AB".
        with tempfile.TemporaryDirectory() as temp_dir:
            with open(temp_dir + "/file1.txt", "w") as f:
                f.write("bob@gmail.com")
            with open(temp_dir + "/file2.txt", "w") as f:
                f.write("AB")

            # First policy: expected to fire when a message's content occurs
            # inside the content of some workspace file.
            policy = Policy.from_string(
"""
from invariant.files import get_file_contents, File
raise "error" if:
(msg: Message)
file_contents := get_file_contents(input.workspace)
(file: File) in file_contents
msg.content in file.content
""")
            res = policy.analyze([{"role": "user", "content": "AB"}], workspace=temp_dir)
            self.assertEqual(len(res.errors), 1)
            res = policy.analyze([{"role": "user", "content": "GH"}], workspace=temp_dir)
            self.assertEqual(len(res.errors), 0)

            # Second policy: combines is_sensitive_dir (with the pii detector)
            # and a check that the message contains "AB".
            policy2 = Policy.from_string(
"""
from invariant.files import is_sensitive_dir
from invariant.detectors import pii
raise "error" if:
(msg: Message)
is_sensitive_dir(input.workspace, [pii])
"AB" in msg.content
""")
            input = [{"role": "user", "content": "AB"}]
            res = policy2.analyze(input, workspace=temp_dir)
            self.assertEqual(len(res.errors), 1)

            # After file1.txt is overwritten with "CD", the test expects the
            # same policy to report no error for the same input.
            with open(temp_dir + "/file1.txt", "w") as f:
                f.write("CD")
            input = [{"role": "user", "content": "AB"}]
            res = policy2.analyze(input, workspace=temp_dir)
            self.assertEqual(len(res.errors), 0)

if __name__ == "__main__":
unittest.main()
unittest.main()
Loading

0 comments on commit 2eb95a2

Please sign in to comment.