diff --git a/experimental/functionsmith/README.md b/experimental/functionsmith/README.md new file mode 100644 index 000000000..11e232870 --- /dev/null +++ b/experimental/functionsmith/README.md @@ -0,0 +1,55 @@ + + +# Functionsmith + +## Overview + +Functionsmith is a general-purpose problem-solving agent. + +USING THIS AGENT IS UNSAFE. It directly runs LLM-produced code, and thus +should only be used for demonstration purposes. + +This agent uses free-form function calling, which means that instead of +relying on a fixed set of tools predefined in the agent, as +[in normal LLM function calling](https://ai.google.dev/gemini-api/docs/function-calling), +we let the agent itself write all the functions it needs. + +The functionsmith system prompt asks the agent to first write any low-level +functions it needs, as well as tests for them. The agent loop will try +to run these functions and ask the LLM to make corrections if necessary. +Once all the functions are ready, the agent will write and run the code +to solve the actual user task. + +The agent does not use function calling features of LLM clients. Instead, +it simply tries to parse all the ```python or ```tool_code sections +present in the raw LLM output. It keeps all function definitions as well +as their source code in memory. Each call to the LLM is preceded +by the function definitions to let the LLM know what functions are available +locally. + +The functions are not saved permanently, though this feature can be added. + +<!-- TODO(simonf): add a notebook version of the general-purpose agent, +as well as an Earth Engine-specific notebook agent. --> + +## Attribution + +Functionsmith was written by Simon Ilyushchenko (simonf@google.com). +I am grateful to Renee Johnston and other Googlers for implementation advice, +as well as to Earth Engine expert advisors Jeffrey Cardille, Erin Trochim, +Morgan Crowley, and Samapriya Roy, who helped me choose the right training +tasks. 
diff --git a/experimental/functionsmith/agent.py b/experimental/functionsmith/agent.py new file mode 100644 index 000000000..887de3021 --- /dev/null +++ b/experimental/functionsmith/agent.py @@ -0,0 +1,239 @@ +"""Functionsmith is a general-purpose problem-solving agent. + +It writes functions for its own future use, tests the functions, and then uses +them to solve a user-specified problem. + +USING THIS AGENT IS UNSAFE. It directly runs LLM-produced code, and thus +should only be used for demonstration purposes. + +To get started, get a data file: +``` +wget https://raw.githubusercontent.com/davidmegginson/ourairports-data/refs/heads/main/airports.csv +``` +then run +``` +python3 agent.py +``` + +See README.md for more information. + +To execute the sample task analyzing airports, get the airports.csv file from +https://raw.githubusercontent.com/davidmegginson/ourairports-data/refs/heads/main/airports.csv + +The agent will ask the LLM to find "something interesting" about the data +given its schema. Then the LLM will probably create one or two sets of +low-level functions with tests, then actually analyze the files, then stop +and ask the user if they want to do anything else. + +Before each code execution phase, the agent will print the code and ask +the user to hit "enter" to confirm the code looks safe to run. +""" + +import copy +import inspect +import logging +import os +import sys + +import code_parser +import executor +import llm + + +STARS = '*' * 20 + '\n' + + +class CustomLoggingHandler(logging.Handler): + + def emit(self, record): + msg = self.format(record) + print(msg) + + +system_prompt = """ +To solve the task given below, write first low-level python functions with +tests for each of them in a ```python block. Include all the necessary imports. +The tests should be as simple as possible and not rely on anything external. +All asserts in tests should have an error message to make sure their failure is +easy to detect. 
+ +In later responses, never omit parts of the code referring to earlier output - +if you need to do this, define a function and then call it later. + +I will save the functions locally, and you can write higher-level code that +will invoke them later. I will pass you the output from the code or any error +messages. + +Call the task_done() function when you consider the task done. +Ask the user questions if you need additional input. + +If I ask you to compute factorial of 10 and then prompt the user if they want +more factorials computed, your responses should be like this (return one +response at a time): Example chat session (each response should be returned in +a separate answer): + + Question 1: + Please compute the factorial of 10 + + Response 1: + Let's define the requested function and test it. + ```python + import math + def factorial(x): + return math.factorial(x) + def test_factorial(): + assert factorial(3) == 6 + assert factorial(4) == 24 + print('success') + + test_factorial() + ``` + + Question 2: + The code output was "success" + + Response 2: + + Now let's call the previously defined function to solve the user task. + ```python + print(factorial(10)) + ``` + + Question 3: + The code output was "3628800" + + Response 3: + + The computed answer looks reasonable. Please enter a number if you want another factorial to be computed, or instruct me to exit. + + Question 4: + You can exit here + + Response 4: + + ```python + task_done('We can exit') + ``` + + +""" +if not os.path.exists('airports.csv'): + print( + """Download +https://raw.githubusercontent.com/davidmegginson/ourairports-data/refs/heads/main/airports.csv +if you'd like to run this task. 
+""", + file=sys.stderr, + ) + sys.exit(1) + +schema = """ +"id","ident","type","name","latitude_deg","longitude_deg","elevation_ft","continent","iso_country","iso_region","municipality","scheduled_service","gps_code","iata_code","local_code","home_link","wikipedia_link","keywords" +""" + +task = f""" +Please explore a local file airports.csv. First, make some hypotheses about the +data, and then write code to test them to learn something interesting about the +data. By 'interesting', I mean something you wouldn't have guessed from first +principles - eg, finding that the largest countries have the most airports is +not interesting. Explain why what you discovered seems interesting. When done, +ask the user if they want to find out something else about this file. Output +findings in text form, not as images or plots. + +Do not overwrite the original file in your code or tests. +The file has the following schema {schema}""" + +# If you need to debug the agent, use this simple task. +# task = """ +# Compute the factorial of 20. When done, ask the user in a chat response +# if they want to compute another factorial and compute it if they give you +# a new value""" + +# This code works with several different LLMs. Uncomment the one you +# have access to. Make sure to set the API key in the appropriate +# environment variable (GOOGLE_API_KEY, ANTHROPIC_API_KEY, or OPENAI_API_KEY). +agent = llm.Gemini(system_prompt, model_name='gemini-2.0-flash-exp') +# agent = llm.Claude(system_prompt, model_name='claude-3-5-sonnet-20241022') +# agent = llm.ChatGPT(system_prompt, model_name='o1-mini') + + +def task_done(agent_message: str) -> None: + """Returns control back to the user when the agent thinks the task is done. + + This function must always be invoked in a separate response, not at the end + of a code snippet doing something else. + + Args: + agent_message(str): the message that the agent wants to print before exit. 
+ """ + print(agent_message) + import sys # pylint:disable=g-import-not-at-top,redefined-outer-name,reimported + + sys.exit(0) + + +syscalls = {} + +# Set up a custom logger to be passed to helper objects. +# This is an overkill for the command-line agent, but makes more sense +# for the notebook version of this agent. +logger = logging.getLogger('functionsmith') +logger.handlers = [] +logger.addHandler(CustomLoggingHandler()) +logger.propagate = False + +# Create the object that parses Python code out of LLM responses. +parser = code_parser.Parser(logger) +# Create the object that runs the LLM-generated Python code. +code_executor = executor.Executor(logger) + +# 'Syscalls' are functions for which stdout/stderr won't be intercepted. +# For now we only have one syscall, 'task_done'. +for f in [task_done]: + starting_tools = parser.extract_functions(inspect.getsource(f)) + syscalls.update(starting_tools.functions) + +tools = {} + +question = task + +while True: + print(STARS) + all_tools = copy.deepcopy(tools) + all_tools.update(syscalls) + question_with_tools = ( + question + + 'The following functions are available:\n' + + '\n'.join([str(x) for x in all_tools.values()]) + ) + response = agent.chat(question_with_tools) + print(response) + + parsed_response = parser.extract_functions(response) + if not parsed_response.code and not parsed_response.functions: + if parsed_response.error: + question = parsed_response.error + continue + question = input('> ') + continue + + tools.update(parsed_response.functions) + + if parsed_response.code: + # Concatenate all known source code together. 
+ # Functions that were defined in the most recent response will be repeated, + # which is okay + code_with_tools = ( + '\n'.join([x.code for x in tools.values()]) + + '\n' + + parsed_response.code + ) + + print(STARS) + input('HIT ENTER TO RUN THIS CODE') + print(STARS) + question = code_executor.run_code(code_with_tools, {'task_done': task_done}) + else: + # The response had functions but no code. The agent wanted to define them. + # We tell it to go on (that is, to keep writing code). + question = 'go on' diff --git a/experimental/functionsmith/code_parser.py b/experimental/functionsmith/code_parser.py new file mode 100644 index 000000000..d6778041b --- /dev/null +++ b/experimental/functionsmith/code_parser.py @@ -0,0 +1,249 @@ +"""Parser for (probably LLM-generated) Python code.""" +import ast +import dataclasses +import logging +import re +from typing import Optional +import docstring_parser + +# pylint:disable=logging-fstring-interpolation + + +@dataclasses.dataclass +class Function: + code: str + name: str + docstring: str + parameters: list[dict[str, str]] + return_type: str + + +@dataclasses.dataclass +class MaybeCode: + code: str + code_block_found: bool + + +@dataclasses.dataclass +class ParsedResponse: + code: str = '' + functions: dict[str, Function] = dataclasses.field(default_factory=dict) + error: str = '' + + +class ImportFinder(ast.NodeVisitor): + """AST Visitor that finds and preserves the source text of all imports.""" + + def __init__(self, source_lines: list[str]): + self.source_lines = source_lines + self.import_statements = [] + + def _node_source(self, node: ast.AST) -> str: + """Extracts the source code for a node using line numbers.""" + if hasattr(node, 'lineno') and hasattr(node, 'end_lineno'): + start = node.lineno - 1 + end = node.end_lineno + if start == end - 1: + return self.source_lines[start].strip() + else: + return '\n'.join(line.strip() for line in self.source_lines[start:end]) + else: + # Fallback for nodes without position 
info + return ast.unparse(node) + + def visit_Import(self, node: ast.Import) -> None: # pylint:disable=invalid-name + """Handles regular import statements.""" + self.import_statements.append(self._node_source(node)) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # pylint:disable=invalid-name + """Handles 'from ... import ...' statements.""" + self.import_statements.append(self._node_source(node)) + self.generic_visit(node) + + +class Parser: + """Parser of (probably LLM-generated) Python code.""" + + _logger: logging.Logger + + def __init__(self, logger: Optional[logging.Logger] = None): + self._logger = logger or logging.getLogger() + + def _extract_function_details(self, func_def: ast.FunctionDef) -> Function: + """Extracts details of a function.""" + name = func_def.name + docstring = ast.get_docstring(func_def) or '' + parameters = self._extract_parameters(func_def) + return_type = self._extract_return_type(func_def) + + return Function( + code=ast.unparse(func_def), + name=name, + docstring=docstring, + parameters=parameters, + return_type=return_type, + ) + + def _extract_parameters( + self, func_def: ast.FunctionDef + ) -> list[dict[str, str]]: + """Extracts function parameter details.""" + parameters = [] + parsed_docstring = self._parse_docstring(ast.get_docstring(func_def)) + + # Handle positional arguments, default values, and keyword-only arguments + args = func_def.args + defaults = dict( + zip( + [arg.arg for arg in args.args[::-1]], + [ast.unparse(d) for d in args.defaults[::-1]], + ) + ) + + # Include keyword-only arguments + kwonly_args = { + arg.arg: ast.unparse(d) + for arg, d in zip(args.kwonlyargs, args.kw_defaults) + if d is not None + } + defaults.update(kwonly_args) + all_args = args.args + args.kwonlyargs + + for arg in all_args: + param_name = arg.arg + param_type = self._extract_type_hint(arg) + param_description = self._find_param_description( + parsed_docstring, param_name + ) + param_default = 
defaults.get(param_name, '') + + parameters.append({ + 'name': param_name, + 'type': param_type, + 'description': param_description, + 'default': param_default, + }) + + return parameters + + def _extract_type_hint(self, node: ast.arg) -> str: + """Extracts the type hint of a parameter as a string.""" + if node.annotation: + return ast.unparse(node.annotation) + else: + return '' + + def _find_param_description(self, parsed_docstring, param_name: str) -> str: + """Finds the description of a parameter in the parsed docstring. + + Handles cases where the docstring parser might not find descriptions + reliably, especially with complex type hints. + + Args: + parsed_docstring: the output of docstring_parser + param_name: the name of the parameter to look for + + Returns: + the description of the parameter, if found, or an empty string + """ + for param in parsed_docstring.params: + if param.arg_name == param_name: + return param.description + + # More precise workaround for missing descriptions: + if parsed_docstring.long_description: + # This regex uses a (positive) lookahead assertion to ensure that + # we only match the description of the current parameter + # and not the descriptions of subsequent parameters. + match = re.search( + rf'(?m)^\s*{param_name}\s*\(?.*?\)?:\s*(.*?)(?=\n\s*[\w-]+\s*\(?.*?\)?:\s*|$)', + parsed_docstring.long_description, + re.DOTALL, + ) + if match: + return match.group(1).strip() + + return '' + + def _extract_return_type(self, func_def: ast.FunctionDef) -> str: + """Extracts the return type of a function as a string.""" + if func_def.returns: + return ast.unparse(func_def.returns) + else: + return '' + + def _parse_docstring(self, docstring: str): + return docstring_parser.parse(docstring) + + def _reduce_indentation(self, code: str) -> str: + """Reduces indentation of the whole code block if it makes sense.""" + if not code: + return code + + lines = code.split('\n') + + # Filter out empty lines first. 
+ non_empty_lines = [line for line in lines if line.strip()] + if not non_empty_lines: + return code + + # Find the minimum indentation level + min_indent = min(len(line) - len(line.lstrip()) for line in non_empty_lines) + + # Only remove indentation if ALL non-empty lines have at least + # this indentation + if all(line.startswith(' ' * min_indent) for line in non_empty_lines): + result = '\n'.join( + line[min_indent:] if line.strip() else '' for line in lines + ) + return result + + return code + + def extract_python_code_blocks(self, text: str) -> MaybeCode: + pattern = re.compile(r'```(?:python|tool_code)\n*(.*?)\n*```', re.DOTALL) + code_blocks = pattern.findall(text) + if code_blocks: + result = MaybeCode('\n'.join(code_blocks), code_block_found=True) + else: + result = text + result = MaybeCode(text, code_block_found=False) + result.code = self._reduce_indentation(result.code) + return result + + def extract_functions(self, response: str) -> ParsedResponse: + """Extracts functions from the code, including their details.""" + extracted = self.extract_python_code_blocks(response) + if not extracted.code: + return ParsedResponse() + + try: + tree = ast.parse(extracted.code) + except SyntaxError as e: + if not extracted.code_block_found: + return ParsedResponse() + error = f'ERROR PARSING CODE: {e}' + self._logger.warning(error) + return ParsedResponse('', {}, error) + + functions = {} + import_finder = ImportFinder(extracted.code.splitlines()) + import_finder.visit(tree) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_name = node.name + if func_name.startswith('test_'): + continue + + try: + function_details = self._extract_function_details(node) + functions[func_name] = function_details + self._logger.info(f'FUNCTION FOUND: {func_name}') + except Exception as e: # pylint:disable=broad-exception-caught + self._logger.error(f'ERROR DEFINING FUNCTION: {func_name} - {e}') + return ParsedResponse( + extracted.code, {}, 'Error extracting 
function details' + ) + + return ParsedResponse(extracted.code, functions) diff --git a/experimental/functionsmith/code_parser_test.py b/experimental/functionsmith/code_parser_test.py new file mode 100644 index 000000000..4fe350e8b --- /dev/null +++ b/experimental/functionsmith/code_parser_test.py @@ -0,0 +1,327 @@ +import ast +from typing import Any +import unittest + +import code_parser + + +class TestParser(unittest.TestCase): + + def setUp(self): + super().setUp() + self._parser = code_parser.Parser() + self.maxDiff = 5000 + + def create_dummy_func_def(self, annotation_string: str) -> ast.FunctionDef: + """Creates an AST FunctionDef node with a specific parameter annotation.""" + code = f'def dummy_func(param: {annotation_string}):\n pass' + tree = ast.parse(code) + return tree.body[0] + + def annotation_from_func_def(self, func_def: ast.FunctionDef) -> Any: + param_node = func_def.args.args[0] + return self._parser._extract_type_hint(param_node) + + def test_extract_single_block(self): + text = "```python\ndef foo():\n print('hello')\n```" + expected = "def foo():\n print('hello')" + result = self._parser.extract_python_code_blocks(text) + self.assertEqual(expected, result.code) + self.assertEqual(True, result.code_block_found) + + def test_extract_multiple_blocks(self): + text = ( + "```python\ndef foo():\n print('hello')\n```\n```python\ndef bar():\n" + " print('world')\n```" + ) + expected = "def foo():\n print('hello')\ndef bar():\n print('world')" + self.assertEqual( + self._parser.extract_python_code_blocks(text).code, expected) + + def test_extract_code_no_blocks(self): + text = 'print(2)' + result = self._parser.extract_python_code_blocks(text) + self.assertEqual(text, result.code) + self.assertEqual(False, result.code_block_found) + + def test_extract_empty_block(self): + text = '```python\n```' + self.assertEqual(self._parser.extract_python_code_blocks(text).code, '') + + def test_extract_with_leading_and_trailing_whitespace(self): + text = 
'```python \n def foo():\n print("hello")\n ```' + expected = '\ndef foo():\n print("hello")\n' + self.assertEqual( + self._parser.extract_python_code_blocks(text).code, expected + ) + + def test_no_annotation(self): + code = 'def dummy_func(param):\n pass' + tree = ast.parse(code) + func_def = tree.body[0] + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, '') + + def test_int_annotation(self): + func_def = self.create_dummy_func_def('int') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'int') + + def test_str_annotation(self): + func_def = self.create_dummy_func_def('str') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'str') + + def test_float_annotation(self): + func_def = self.create_dummy_func_def('float') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'float') + + def test_bool_annotation(self): + func_def = self.create_dummy_func_def('bool') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'bool') + + def test_list_int_annotation(self): + func_def = self.create_dummy_func_def('list[int]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'list[int]') + + def test_dict_str_float_annotation(self): + func_def = self.create_dummy_func_def('dict[str, float]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'dict[str, float]') + + def test_tuple_annotation(self): + func_def = self.create_dummy_func_def('tuple[int, str, bool]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'tuple[int, str, bool]') + + def test_set_annotation(self): + func_def = self.create_dummy_func_def('set[str]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'set[str]') + + def test_union_annotation(self): + func_def = self.create_dummy_func_def('Union[int, 
str]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'Union[int, str]') + + def test_optional_annotation(self): + func_def = self.create_dummy_func_def('Optional[int]') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'Optional[int]') + + def test_any_annotation(self): + func_def = self.create_dummy_func_def('Any') + annotation = self.annotation_from_func_def(func_def) + self.assertEqual(annotation, 'Any') + + def test_unhandled_annotation_type(self): + # Create a custom AST node that is not handled by _parameter_annotation + class CustomAnnotation(ast.AST): + _fields = () + + param_node = ast.arg(arg='param', annotation=CustomAnnotation()) + annotation = self._parser._extract_type_hint(param_node) + self.assertEqual(annotation, ast.unparse(param_node.annotation)) + + def test_default_arg(self): + func_def = self.create_dummy_func_def('int = 10') + annotation = self.annotation_from_func_def(func_def) + default_value = func_def.args.defaults[0] + self.assertEqual(annotation, 'int') + self.assertEqual(ast.unparse(default_value), '10') + + def test_basic_function(self): + code = ''' +def my_function(a: str, b: int = 22): + """This is a docstring. 
+ + Args: + a (str): first argument + b (int, optional): second argument + """ + ''' + tree = ast.parse(code) + func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)) + function = self._parser._extract_function_details(func_def) + self.assertEqual(function.name, 'my_function') + self.assertEqual(function.return_type, '') + self.assertEqual(len(function.parameters), 2) + self.assertEqual(function.parameters[0]['name'], 'a') + self.assertEqual(function.parameters[0]['type'], 'str') + self.assertEqual(function.parameters[0]['description'], 'first argument') + + def test_import_finder(self): + code = """ +import foo +from bar import baz +def my_function(): + pass +""" + tree = ast.parse(code) + import_finder = code_parser.ImportFinder(code.splitlines()) + import_finder.visit(tree) + expected_imports = ['import foo', 'from bar import baz'] + self.assertEqual(import_finder.import_statements, expected_imports) + + def test_extract_functions_as_tools(self): + code = '''import typing +from typing import List + +def my_function(a: int, b: str = 'default') -> str: + """This is a docstring.""" + return f'{a} {b}' + +def test_my_function(a: int) -> str: + return "test"''' + + code_block = f""" +```python +{code} +``` +""" + parsed_response = self._parser.extract_functions(code_block) + self.assertEqual(code, parsed_response.code) + + self.assertEqual(len(parsed_response.functions), 1) + + saved_function = parsed_response.functions['my_function'] + self.assertEqual(saved_function.name, 'my_function') + self.assertEqual(saved_function.docstring, 'This is a docstring.') + self.assertEqual(saved_function.return_type, 'str') + + def test_function_with_typing_and_docstring_types(self): + code = ''' +def my_function(a: int, b: str = 'default', c: list[int] = None): + """This is a docstring. 
+ + Args: + a (int): first argument + b (str): second argument + c (list[int]): third argument + """ + pass +''' + tree = ast.parse(code) + func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)) + func = self._parser._extract_function_details(func_def) + self.assertEqual(func.parameters[0]['name'], 'a') + self.assertEqual(func.parameters[0]['type'], 'int') + self.assertEqual(func.parameters[0]['description'], 'first argument') + + self.assertEqual(func.parameters[1]['name'], 'b') + self.assertEqual(func.parameters[1]['type'], 'str') + self.assertEqual(func.parameters[1]['description'], 'second argument') + + self.assertEqual(func.parameters[2]['name'], 'c') + self.assertEqual(func.parameters[2]['type'], 'list[int]') + self.assertEqual(func.parameters[2]['description'], 'third argument') + + def test_function_with_only_typing_annotations(self): + code = ''' +def my_function(a: int, b: Optional[str] = None): + """This is a docstring. + + Args: + a: The first argument. + b: The second argument. + """ + pass + ''' + tree = ast.parse(code) + func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)) + func = self._parser._extract_function_details(func_def) + self.assertEqual(func.parameters[0]['name'], 'a') + self.assertEqual(func.parameters[0]['type'], 'int') + self.assertEqual(func.parameters[0]['description'], 'The first argument.') + + self.assertEqual(func.parameters[1]['name'], 'b') + self.assertEqual(func.parameters[1]['type'], 'Optional[str]') + self.assertEqual(func.parameters[1]['description'], 'The second argument.') + + def test_function_with_only_docstring_annotations(self): + code = ''' +def my_function(a, b='foo'): + """This is a docstring. + + Args: + a (int): The first argument. + b (str, optional): The second argument. 
+ """ + pass +''' + tree = ast.parse(code) + func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)) + func = self._parser._extract_function_details(func_def) + self.assertEqual(func.parameters[0]['name'], 'a') + self.assertEqual(func.parameters[0]['type'], '') + self.assertEqual(func.parameters[0]['description'], 'The first argument.') + self.assertEqual(func.parameters[1]['name'], 'b') + self.assertEqual(func.parameters[1]['type'], '') + self.assertEqual(func.parameters[1]['description'], 'The second argument.') + + def test_function_with_complex_typing_annotations(self): + code = ''' +def my_function(a: List[Dict[str, int]], b: Union[int, float] = 0): + """This is a docstring. + + Args: + a: A list of dictionaries. + b: Either an integer or a float. + """ + pass + ''' + tree = ast.parse(code) + func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)) + func = self._parser._extract_function_details(func_def) + self.assertEqual(func.parameters[0]['name'], 'a') + self.assertEqual(func.parameters[0]['type'], 'List[Dict[str, int]]') + self.assertEqual( + func.parameters[0]['description'], 'A list of dictionaries.' + ) + + self.assertEqual(func.parameters[1]['name'], 'b') + self.assertEqual(func.parameters[1]['type'], 'Union[int, float]') + self.assertEqual( + func.parameters[1]['description'], 'Either an integer or a float.' 
+ ) + self.assertEqual(func.parameters[1]['default'], '0') + + def test_import_finder_with_comments_and_multiline(self): + code = """ +import os # This is a comment +from typing import ( + List, + Dict, # Another comment +) + +def my_function(): + pass +""" + tree = ast.parse(code) + import_finder = code_parser.ImportFinder(code.splitlines()) + import_finder.visit(tree) + expected_imports = [ + 'import os # This is a comment', + 'from typing import (\nList,\nDict, # Another comment\n)', + ] + self.assertEqual(import_finder.import_statements, expected_imports) + + def test_extract_functions_as_tools_invalid_code(self): + code = '```python\nthis is not valid python code\n```' + result = self._parser.extract_functions(code) + self.assertEqual('', result.code) + self.assertEqual( + 'ERROR PARSING CODE: invalid syntax (, line 1)', result.error + ) + + def test_reduce_indentation_empty_lines(self): + code = 'def foo():\n\n \n print("hello")\n \n print("world")\n' + expected = 'def foo():\n\n\n print("hello")\n\n print("world")\n' + self.assertEqual(self._parser._reduce_indentation(code), expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/experimental/functionsmith/executor.py b/experimental/functionsmith/executor.py new file mode 100644 index 000000000..26c7ddd3b --- /dev/null +++ b/experimental/functionsmith/executor.py @@ -0,0 +1,89 @@ +"""Python code executor for functionsmith agents. + +USING THIS CLASS IS UNSAFE. It directly runs LLM-produced code, and thus +should only be used for demonstration purposes. Real code executions +should happen in a sandbox or in a docker image. +""" + +import io +import logging +import sys +from typing import Callable + + +# Intercepting stderr leads to the error "lost sys.stderr", +# but in practice I don't find intercepting stderr necessary. 
+class OutputManager: + """Context manager for intercepting stdout.""" + + def __init__(self): + self.buffer = io.StringIO() + self.original_stdout = None + + def __enter__(self): + self.original_stdout = sys.stdout + sys.stdout = self.buffer + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout = self.original_stdout + + def get_value(self): + return self.buffer.getvalue() + + def bypass(self): + """Temporarily restores original stdout.""" + sys.stdout = self.original_stdout + + def restore_capture(self): + """Resumes capturing to buffer.""" + sys.stdout = self.buffer + + +class Executor: + """Class for executing a Python snippet.""" + + def __init__(self, logger=None): + self._logger = logger or logging.getLogger() + + # pylint:disable=g-bare-generic + def run_code(self, python_code: str, syscalls: dict[str, Callable]) -> str: + """Runs the given code and captures its output except for syscalls.""" + try: + wrapped_syscalls = {} + for name, func in syscalls.items(): + + def wrap_syscall(f): + def wrapped(*args, **kwargs): + output_manager.bypass() + try: + result = f(*args, **kwargs) + output_manager.buffer.write(str(result) + '\n') + return result + finally: + output_manager.restore_capture() + + return wrapped + + wrapped_syscalls[name] = wrap_syscall(func) + + output_manager = OutputManager() + with output_manager: + exec(python_code, {**wrapped_syscalls}) # pylint:disable=exec-used + output = output_manager.get_value() + + if output: + self._logger.warning(f'OUTPUT: {output}\n') + else: + self._logger.warning('NO OUTPUT\n') + if output: + return output + else: + return ( + 'CODE DID NOT PRINT ANYTHING TO STDOUT. IF YOU ARE EXPECTING A' + ' VALUE, USE print(), NOT return' + ) + except Exception as e: # pylint:disable=broad-exception-caught + self._logger.error(f'ERROR: {e}') # pylint:disable=logging-fstring-interpolation + error = f'Exception occurred: {e}.' 
class TestOutputManager(unittest.TestCase):
  """Tests for executor.OutputManager."""

  def test_capture_output(self):
    with executor.OutputManager() as manager:
      print("Hello, world!")
      print("This is a test.")
    self.assertEqual(manager.get_value(), "Hello, world!\nThis is a test.\n")

  def test_empty_output(self):
    with executor.OutputManager() as manager:
      pass  # No output
    self.assertEqual(manager.get_value(), "")

  def test_bypass_and_restore(self):
    # Install a known outer stream so the bypassed text can be observed, and
    # let mock.patch put back whatever stream the test runner had installed.
    outer = io.StringIO()
    with mock.patch.object(sys, "stdout", outer):
      with executor.OutputManager() as manager:
        print("Captured 1")
        manager.bypass()
        print("Not captured", end="")
        manager.restore_capture()
        print("Captured 2")

    self.assertEqual(manager.get_value(), "Captured 1\nCaptured 2\n")
    self.assertEqual(outer.getvalue(), "Not captured")


class TestExecutor(unittest.TestCase):
  """Tests for executor.Executor."""

  def setUp(self):
    super().setUp()
    self.logger = mock.MagicMock()
    self.executor = executor.Executor(self.logger)

  def test_simple_code_execution(self):
    code = "print('Hello from code')"
    result = self.executor.run_code(code, {})
    self.assertEqual(result, "Hello from code\n")
    self.logger.warning.assert_called_with("OUTPUT: Hello from code\n\n")

  def test_code_with_no_output(self):
    code = "x = 5"
    result = self.executor.run_code(code, {})
    self.assertEqual(
        result,
        "CODE DID NOT PRINT ANYTHING TO STDOUT. IF YOU ARE EXPECTING "
        "A VALUE, USE print(), NOT return",
    )
    self.logger.warning.assert_called_with("NO OUTPUT\n")

  def test_code_with_exception(self):
    code = "1 / 0"
    result = self.executor.run_code(code, {})
    self.assertIn("Exception occurred: division by zero", result)
    self.logger.error.assert_called()

  def test_syscall_wrapping(self):
    def mock_syscall(a, b):
      return a + b

    code = "print(add(2, 3))"
    result = self.executor.run_code(code, {"add": mock_syscall})
    # The syscall result is recorded once by the wrapper and once by print().
    self.assertEqual(result, "5\n5\n")
    self.logger.warning.assert_called_with("OUTPUT: 5\n5\n\n")

  def test_syscall_that_prints_and_returns(self):
    def mock_syscall():
      print("Hello from syscall")  # Not captured in output buffer
      return 42

    code = "my_syscall()"
    outer = io.StringIO()
    with mock.patch.object(sys, "stdout", outer):
      result = self.executor.run_code(code, {"my_syscall": mock_syscall})

    self.assertEqual(result, "42\n")
    # The syscall's own print must reach the outer stream, not the buffer.
    self.assertEqual(outer.getvalue(), "Hello from syscall\n")
    self.logger.warning.assert_called_with("OUTPUT: 42\n\n")

  def test_syscall_with_bypass_and_restore(self):
    def mock_syscall():
      print("Hello from syscall")  # Not captured in output buffer
      return 42

    code = """
print('Before syscall')
print(my_syscall())
print('After syscall')"""
    outer = io.StringIO()
    with mock.patch.object(sys, "stdout", outer):
      result = self.executor.run_code(code, {"my_syscall": mock_syscall})

    self.assertEqual(result, "Before syscall\n42\n42\nAfter syscall\n")
    self.assertEqual(outer.getvalue(), "Hello from syscall\n")
    self.logger.warning.assert_called_with(
        "OUTPUT: Before syscall\n42\n42\nAfter syscall\n\n"
    )


if __name__ == "__main__":
  unittest.main()
class LLM:
  """Parent class for LLM clients."""

  # Seconds to wait before retrying a rate-limited or transient server error.
  RETRY_DELAY_SECONDS = 10

  @staticmethod
  def _api_key(env_var):
    """Returns the API key stored in `env_var`.

    Raises:
      ValueError: if the environment variable is unset or empty.
    """
    key = os.environ.get(env_var)
    if not key:
      raise ValueError(f'Please set the environment variable {env_var}')
    return key

  def chat(self, question):
    """Sends a single message to the LLM and returns its response."""
    raise NotImplementedError


class Gemini(LLM):
  """Gemini client.

  Some model names:

  * gemini-1.5-flash
  * gemini-1.5-pro
  * gemini-exp-1206
  * gemini-2.0-flash-exp
  """

  def __init__(self, system_instruction, model_name='gemini-2.0-flash-exp'):
    client = genai.Client(api_key=self._api_key('GOOGLE_API_KEY'))
    self._chat = client.chats.create(
        model=model_name,
        config=types.GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=0.1,
        ),
    )

  def chat(self, question):
    """Sends a single message to the LLM and returns its response."""
    return self._chat.send_message(question).text


class Claude(LLM):
  """Claude client.

  Some model names:

  * claude-3-5-sonnet-20241022
  * claude-3-opus-20240229
  """

  def __init__(self, system_prompt, model_name='claude-3-5-sonnet-20241022'):
    self._client = anthropic.Anthropic(
        api_key=self._api_key('ANTHROPIC_API_KEY')
    )
    self._system_prompt = system_prompt
    self._model_name = model_name
    self._messages = []

  def _send_message(self, temperature) -> 'anthropic.types.Message':
    """Sends the accumulated chat history to the LLM.

    Retries after a delay on rate limiting and on errors whose text mentions
    500. Any other error is printed and re-raised instead of being retried
    forever.
    """
    while True:
      try:
        return self._client.messages.create(
            model=self._model_name,
            system=self._system_prompt,
            max_tokens=4096,
            messages=self._messages,
            temperature=temperature,
        )
      except anthropic.RateLimitError:
        time.sleep(self.RETRY_DELAY_SECONDS)
      except Exception as e:  # pylint:disable=broad-exception-caught
        print(f'UNEXPECTED RESPONSE: {e}')
        if '500' not in str(e):
          raise
        time.sleep(self.RETRY_DELAY_SECONDS)

  def chat(self, question, temperature=0.1):
    """Sends a single message to the LLM and returns its response.

    The question is removed from the history again if the request fails, so
    the history never ends with an unanswered user message.
    """
    self._messages.append({'role': 'user', 'content': question})
    try:
      response = self._send_message(temperature)
    except BaseException:  # Roll back on Ctrl-C too, then re-raise.
      self._messages.pop()
      raise

    self._messages.append({'role': 'assistant', 'content': response.content})
    return ' '.join(x.text for x in response.content)


class ChatGPT(LLM):
  """ChatGPT client.

  Some model names:

  * gpt-4-turbo
  * gpt-4o
  * o1-preview
  * o1-mini
  """

  def __init__(self, system_prompt, model_name='o1-mini'):
    api_key = self._api_key('OPENAI_API_KEY')
    self._model_name = model_name
    self._system_prompt = system_prompt
    self._client = openai.OpenAI(api_key=api_key)
    # NOTE(review): the system prompt is sent with the 'user' role rather than
    # 'system' - presumably for models that reject system messages; confirm.
    self._messages = [{'role': 'user', 'content': system_prompt}]

  def _one_message(self, **kwargs):
    """Sends a single message to the LLM with error handling.

    Retries after a delay when the error text mentions 429 (rate limiting).
    Any other error is printed and re-raised instead of being retried forever.
    """
    while True:
      try:
        return self._client.chat.completions.create(**kwargs)
      except Exception as e:  # pylint:disable=broad-exception-caught
        if '429' not in str(e):
          print('UNEXPECTED RESPONSE')
          print(e)
          raise
        time.sleep(self.RETRY_DELAY_SECONDS)

  def chat(self, question, temperature=1):
    """Sends a single message to the LLM and returns its response.

    The history is only extended once the request has succeeded.
    """
    user_message = {'role': 'user', 'content': question}
    response = self._one_message(
        model=self._model_name,
        messages=self._messages + [user_message],
        temperature=temperature,
    )
    content = response.choices[0].message.content
    self._messages.append(user_message)
    self._messages.append({'role': 'assistant', 'content': content})
    return content