From 9feda04719a50028858790ae33467c7aa61fb06a Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Fri, 25 Aug 2023 18:05:57 +0000 Subject: [PATCH] WIP refactor and add tests --- DeepResearchTool/deep_research/__init__.py | 0 .../deep_research/summary_generator.py | 79 +++++++ DeepResearchTool/deep_research/topics.py | 10 + DeepResearchTool/deep_research_writer_tool.py | 71 +------ DeepResearchTool/topic_managers.py | 40 ++-- poetry.lock | 192 ++++++++++++++---- pyproject.toml | 12 +- tests/test_deep_research_writer_tool.py | 56 +++++ 8 files changed, 333 insertions(+), 127 deletions(-) create mode 100644 DeepResearchTool/deep_research/__init__.py create mode 100644 DeepResearchTool/deep_research/summary_generator.py create mode 100644 DeepResearchTool/deep_research/topics.py create mode 100644 tests/test_deep_research_writer_tool.py diff --git a/DeepResearchTool/deep_research/__init__.py b/DeepResearchTool/deep_research/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/DeepResearchTool/deep_research/summary_generator.py b/DeepResearchTool/deep_research/summary_generator.py new file mode 100644 index 0000000..a7232e6 --- /dev/null +++ b/DeepResearchTool/deep_research/summary_generator.py @@ -0,0 +1,79 @@ +import logging +from typing import Callable, Optional + +from langchain.chat_models import ChatOpenAI +from langchain.schema import SystemMessage + + +class SummaryGenerator: + _user_query: str + _topics: list[dict] + _notes_getter: Callable[[str], str] + _chat: Callable[[str], str] + + def __init__( + self, + user_query: str, + topics: list[dict], + notes_getter: Callable[[str], str], + chat: Optional[Callable[[str], str]] = None, + ) -> None: + self._user_query = user_query + self._topics = topics + self._notes_getter = notes_getter + self._chat = chat or self._get_response_from_openai + + def _format_topic(self, topic: dict) -> str: + notes = self._notes_getter(topic["notes_file"]) + topic_str = f""" + Topic name: {topic["name"]} + Topic description: {topic["description"]} + Relevant because: {topic["relevant_because"]} + Notes: {notes} + """ + return topic_str + + def _generate_markdown_prompt(self, user_query: str, topics: list[dict]) -> str: + topics_str_list = [self._format_topic(topic) for topic in topics] + markdown_prompt = f""" + The user query is: {user_query} + + ### + + Given the following topics and notes about the topic, write an article addressing the user query + the best you can. If there is a question, try to answer it. If the user query has incorrect + facts or assumptions, address that. + + Start with a problem statement of some sort based on the user query, then follow up with a conclusion. + After the conclusion, explain how that conclusion was derived from the + topics researched. If needed, create a section for relevant topic, if it is important enough, + and explain how the topic contributes to the conclusion. You do not need to specifically mention + the conclusion when describing topics. + + When you can, cite your sources + + ### The topics are: + + {" # next topic # ".join(topics_str_list)} + + # Reminder! The conclusion should be helpful and specific. If there are upper and lower bounds or circumstances where something + may be true or false, then define it. If you cannot, then identify further research needed to get there. Do not make anything up! + If you do not know why you know something, then do not mention it, or identify further research needed to confirm it. + + Use inline citations. + + Markdown file contents: + """ + return markdown_prompt + + def get_markdown_summary(self) -> str: + markdown_prompt = self._generate_markdown_prompt(self._user_query, self._topics) + logging.warning(markdown_prompt) + return self._chat(markdown_prompt) + + def _get_response_from_openai(self, markdown_prompt: str) -> str: + OPEN_AI_MODEL = "gpt-4" + chat = ChatOpenAI(model=OPEN_AI_MODEL, temperature=0) + system_message_prompt = SystemMessage(content=markdown_prompt) + response = chat([system_message_prompt]) + return response.content diff --git a/DeepResearchTool/deep_research/topics.py b/DeepResearchTool/deep_research/topics.py new file mode 100644 index 0000000..80c505b --- /dev/null +++ b/DeepResearchTool/deep_research/topics.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class Topic: + name: str + description: str + notes_file: str + relevant_because: str + researched: bool = False diff --git a/DeepResearchTool/deep_research_writer_tool.py b/DeepResearchTool/deep_research_writer_tool.py index 96d32ed..840dbe6 100644 --- a/DeepResearchTool/deep_research_writer_tool.py +++ b/DeepResearchTool/deep_research_writer_tool.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Type +from typing import Any, Optional, Type from langchain.chat_models import ChatOpenAI from langchain.schema import SystemMessage @@ -10,6 +10,7 @@ from superagi.tools.base_tool import BaseTool from DeepResearchTool.const import SINGLE_FILE_OUTPUT_FILE, TOPICS_FILE, USER_QUERY_FILE +from DeepResearchTool.deep_research.summary_generator import SummaryGenerator class DeepResearchWriter(BaseModel): @@ -30,73 +31,21 @@ class DeepResearchWriterTool(BaseTool): llm: Optional[BaseLlm] = None resource_manager: Optional[FileManager] = None + def _notes_getter(self, notes_file: str) -> str: + assert self.resource_manager + + return self.resource_manager.read_file(notes_file) + def _execute(self, desired_output_format: str | None = None) -> str: assert self.resource_manager assert self.llm self.llm.temperature = 0 - user_query = self.resource_manager.read_file(USER_QUERY_FILE) - topics = self.resource_manager.read_file(TOPICS_FILE) - - topics_str_list = [] - - for topic in json.loads(topics): - notes = self.resource_manager.read_file(topic["notes_file"]) - # format is: - # name, description, notes_file, relevant_because, researched - topic_str = f""" -Topic name: {topic["name"]} -Topic description: {topic["description"]} -Relevant because: {topic["relevant_because"]} -Notes: {notes} - """ - topics_str_list.append(topic_str) - - markdown_prompt = f""" - The user query is: {user_query} - - ### - - Given the following topics and notes about the topic, write an article addressing the user query - the best you can. If there is an question, try to answer it. If the user query has incorrect - facts or assumptions, address that. - - Start with a problem statement of some sort based on the user query, then follow up with a conclusion. - After the conclusion, explain how that conclusion was derived from the - topics researched. If needed, create a section for relevant topic, if it is important enough, - and explain how the topic contributes to the conclusion. You do not need to specifically mention - the conclusion when describing topics. - - When you can, cite your sources - - ### The topics are: - - {" # next topic # ".join(topics_str_list)} - - # Reminder! The conclusion should be helpful and specific. If there are upper and lower bounds or circumstances where something - may be true or false, then define it. If you cannot, then identify further research needed to get there. Do not make anything up! - If you do not know why you know something, then do not mention it, or identify further research needed to confirm it. - - Use inline citations. - - Markdown file contents: - """ - logging.warning(markdown_prompt) - - OPEN_AI_MODEL = "gpt-4-32k" # not yet available - OPEN_AI_MODEL = "gpt-4" - - chat = ChatOpenAI(model=OPEN_AI_MODEL, temperature=0) - - system_message_prompt = SystemMessage(content=markdown_prompt) - response = chat([system_message_prompt]) - content = response.content - - # content = self.llm.chat_completion([{"role": "system", "content": markdown_prompt}])[ - # "content" - # ] + user_query, topics = self._read_files() + summary_writer = SummaryGenerator(user_query, topics, self._notes_getter) + content = summary_writer.get_markdown_summary() self.resource_manager.write_file(SINGLE_FILE_OUTPUT_FILE, content) return f"Deep research completed! Check the resource manager for {SINGLE_FILE_OUTPUT_FILE} to view the result!" diff --git a/DeepResearchTool/topic_managers.py b/DeepResearchTool/topic_managers.py index b040f6f..3d20c4c 100644 --- a/DeepResearchTool/topic_managers.py +++ b/DeepResearchTool/topic_managers.py @@ -1,41 +1,43 @@ import json import logging -from dataclasses import asdict, dataclass +from dataclasses import asdict from superagi.resource_manager.file_manager import FileManager from DeepResearchTool.const import TOPICS_FILE +from DeepResearchTool.deep_research.topics import Topic -@dataclass -class Topic: - name: str - description: str - notes_file: str - relevant_because: str - researched: bool = False +class ManagedTopic: + def __init__(self, topic: Topic, file_manager: FileManager) -> None: + self.topic = topic + self.file_manager = file_manager - def initialize_notes_file(self, file_manager: FileManager) -> None: - logging.info(f"Initializing notes file: {self.notes_file}") - file_manager.write_file(self.notes_file, json.dumps([])) + def initialize_notes_file(self) -> None: + logging.info(f"Initializing notes file: {self.topic.notes_file}") + self.file_manager.write_file(self.topic.notes_file, json.dumps([])) - def mark_as_researched(self, file_manager: FileManager) -> None: - topics_file = json.loads(file_manager.read_file(TOPICS_FILE)) + def mark_as_researched(self) -> None: + topics_file = json.loads(self.file_manager.read_file(TOPICS_FILE)) for topic in topics_file: - if topic["name"] == self.name: + if topic["name"] == self.topic.name: topic["researched"] = True break - file_manager.write_file(TOPICS_FILE, json.dumps(topics_file)) + self.file_manager.write_file(TOPICS_FILE, json.dumps(topics_file)) class TopicsManager: def __init__(self, file_manager: FileManager) -> None: self._file_manager = file_manager - def load_topics(self) -> list[Topic]: - return [Topic(**topic) for topic in json.loads(self._file_manager.read_file(TOPICS_FILE))] + def load_topics(self) -> list[ManagedTopic]: + return [ + ManagedTopic(Topic(**topic), self._file_manager) + for topic in json.loads(self._file_manager.read_file(TOPICS_FILE)) + ] - def write_topics(self, topics: list[Topic]) -> None: + def write_topics(self, topics: list[ManagedTopic]) -> None: + writing_topics = [topic.topic for topic in topics] self._file_manager.write_file( - TOPICS_FILE, json.dumps([asdict(topic) for topic in topics]) + TOPICS_FILE, json.dumps([asdict(topic) for topic in writing_topics]) ) diff --git a/poetry.lock b/poetry.lock index bdee626..d4e9f04 100644 --- a/poetry.lock +++ b/poetry.lock @@ -122,6 +122,23 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "asttokens" +version = "2.2.1" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.2.1-py2.py3-none-any.whl", hash = "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c"}, + {file = "asttokens-2.2.1.tar.gz", hash = "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3"}, +] + +[package.dependencies] +six = "*" + +[package.extras] +test = ["astroid", "pytest"] + [[package]] name = "async-timeout" version = "4.0.3" @@ -334,6 +351,20 @@ typing-inspect = ">=0.4.0" [package.extras] dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest (>=7.2.0)", "setuptools", "simplejson", "twine", "types-dataclasses", "wheel"] +[[package]] +name = "executing" +version = "1.2.0" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = "*" +files = [ + {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, + {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, +] + +[package.extras] +tests = ["asttokens", "littleutils", "pytest", "rich"] + [[package]] name = "frozenlist" version = "1.4.0" @@ -585,6 +616,21 @@ files = [ [package.dependencies] marshmallow = ">=2.0.0" +[[package]] +name = "megamock" +version = "0.1.0b7" +description = "Mega mocking capabilities - stop using dot-notated paths!" +optional = false +python-versions = ">=3.10,<4.0" +files = [ + {file = "megamock-0.1.0b7-py3-none-any.whl", hash = "sha256:e26960e482fef279255dc7639854b53d1d216416eac85b52e68c315e9d1abe83"}, + {file = "megamock-0.1.0b7.tar.gz", hash = "sha256:cd3a8b2c60d19a32e5b7b2b751b74d83908c96845234c9f447128e551fba8361"}, +] + +[package.dependencies] +asttokens = ">=2.2.1,<2.3.0" +varname = {version = ">=0.10.0,<0.11.0", extras = ["asttokens"]} + [[package]] name = "multidict" version = "6.0.4" @@ -919,6 +965,35 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-dotenv" +version = "0.5.2" +description = "A py.test plugin that parses environment files before running tests" +optional = false +python-versions = "*" +files = [ + {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, + {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, +] + +[package.dependencies] +pytest = ">=5.0.0" +python-dotenv = ">=0.9.1" + +[[package]] +name = "python-dotenv" +version = "1.0.0" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, + {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pyyaml" version = "6.0" @@ -1015,54 +1090,65 @@ files = [ {file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"}, ] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sqlalchemy" -version = "2.0.20" +version = "2.0.15" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759b51346aa388c2e606ee206c0bc6f15a5299f6174d1e10cadbe4530d3c7a98"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1506e988ebeaaf316f183da601f24eedd7452e163010ea63dbe52dc91c7fc70e"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5768c268df78bacbde166b48be788b83dddaa2a5974b8810af422ddfe68a9bc8"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3f0dd6d15b6dc8b28a838a5c48ced7455c3e1fb47b89da9c79cc2090b072a50"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:243d0fb261f80a26774829bc2cee71df3222587ac789b7eaf6555c5b15651eed"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6eb6d77c31e1bf4268b4d61b549c341cbff9842f8e115ba6904249c20cb78a61"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-win32.whl", hash = "sha256:bcb04441f370cbe6e37c2b8d79e4af9e4789f626c595899d94abebe8b38f9a4d"}, - {file = "SQLAlchemy-2.0.20-cp310-cp310-win_amd64.whl", hash = "sha256:d32b5ffef6c5bcb452723a496bad2d4c52b346240c59b3e6dba279f6dcc06c14"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd81466bdbc82b060c3c110b2937ab65ace41dfa7b18681fdfad2f37f27acdd7"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6fe7d61dc71119e21ddb0094ee994418c12f68c61b3d263ebaae50ea8399c4d4"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4e571af672e1bb710b3cc1a9794b55bce1eae5aed41a608c0401885e3491179"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3364b7066b3c7f4437dd345d47271f1251e0cfb0aba67e785343cdbdb0fff08c"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1be86ccea0c965a1e8cd6ccf6884b924c319fcc85765f16c69f1ae7148eba64b"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1d35d49a972649b5080557c603110620a86aa11db350d7a7cb0f0a3f611948a0"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-win32.whl", hash = "sha256:27d554ef5d12501898d88d255c54eef8414576f34672e02fe96d75908993cf53"}, - {file = "SQLAlchemy-2.0.20-cp311-cp311-win_amd64.whl", hash = "sha256:411e7f140200c02c4b953b3dbd08351c9f9818d2bd591b56d0fa0716bd014f1e"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3c6aceebbc47db04f2d779db03afeaa2c73ea3f8dcd3987eb9efdb987ffa09a3"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d3f175410a6db0ad96b10bfbb0a5530ecd4fcf1e2b5d83d968dd64791f810ed"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea8186be85da6587456c9ddc7bf480ebad1a0e6dcbad3967c4821233a4d4df57"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3d99ba99007dab8233f635c32b5cd24fb1df8d64e17bc7df136cedbea427897"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:76fdfc0f6f5341987474ff48e7a66c3cd2b8a71ddda01fa82fedb180b961630a"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-win32.whl", hash = "sha256:d3793dcf5bc4d74ae1e9db15121250c2da476e1af8e45a1d9a52b1513a393459"}, - {file = "SQLAlchemy-2.0.20-cp37-cp37m-win_amd64.whl", hash = "sha256:79fde625a0a55220d3624e64101ed68a059c1c1f126c74f08a42097a72ff66a9"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:599ccd23a7146e126be1c7632d1d47847fa9f333104d03325c4e15440fc7d927"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1a58052b5a93425f656675673ef1f7e005a3b72e3f2c91b8acca1b27ccadf5f4"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79543f945be7a5ada9943d555cf9b1531cfea49241809dd1183701f94a748624"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63e73da7fb030ae0a46a9ffbeef7e892f5def4baf8064786d040d45c1d6d1dc5"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3ce5e81b800a8afc870bb8e0a275d81957e16f8c4b62415a7b386f29a0cb9763"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb0d3e94c2a84215532d9bcf10229476ffd3b08f481c53754113b794afb62d14"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-win32.whl", hash = "sha256:8dd77fd6648b677d7742d2c3cc105a66e2681cc5e5fb247b88c7a7b78351cf74"}, - {file = "SQLAlchemy-2.0.20-cp38-cp38-win_amd64.whl", hash = "sha256:6f8a934f9dfdf762c844e5164046a9cea25fabbc9ec865c023fe7f300f11ca4a"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:26a3399eaf65e9ab2690c07bd5cf898b639e76903e0abad096cd609233ce5208"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4cde2e1096cbb3e62002efdb7050113aa5f01718035ba9f29f9d89c3758e7e4e"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b09ba72e4e6d341bb5bdd3564f1cea6095d4c3632e45dc69375a1dbe4e26ec"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b74eeafaa11372627ce94e4dc88a6751b2b4d263015b3523e2b1e57291102f0"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:77d37c1b4e64c926fa3de23e8244b964aab92963d0f74d98cbc0783a9e04f501"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:eefebcc5c555803065128401a1e224a64607259b5eb907021bf9b175f315d2a6"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-win32.whl", hash = "sha256:3423dc2a3b94125094897118b52bdf4d37daf142cbcf26d48af284b763ab90e9"}, - {file = "SQLAlchemy-2.0.20-cp39-cp39-win_amd64.whl", hash = "sha256:5ed61e3463021763b853628aef8bc5d469fe12d95f82c74ef605049d810f3267"}, - {file = "SQLAlchemy-2.0.20-py3-none-any.whl", hash = "sha256:63a368231c53c93e2b67d0c5556a9836fdcd383f7e3026a39602aad775b14acf"}, - {file = "SQLAlchemy-2.0.20.tar.gz", hash = "sha256:ca8a5ff2aa7f3ade6c498aaafce25b1eaeabe4e42b73e25519183e4566a16fc6"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78303719c6f72af97814b0072ad18bee72e70adca8d95cf8fecd59c5e1ddb040"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9d810b4aacd5ef4e293aa4ea01f19fca53999e9edcfc4a8ef1146238b30bdc28"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3fb5d09f1d51480f711b69fe28ad42e4f8b08600a85ab2473baee669e1257800"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51b19887c96d405599880da6a7cbdf8545a7e78ec5683e46a43bac8885e32d0f"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d6b17cb86908e7f88be14007d6afe7d2ab11966e373044137f96a6a4d83eb21c"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df25052b92bd514357a9b370d74f240db890ea79aaa428fb893520e10ee5bc18"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-win32.whl", hash = "sha256:55ec62ddc0200b4fee94d11abbec7aa25948d5d21cb8df8807f4bdd3c51ba44b"}, + {file = "SQLAlchemy-2.0.15-cp310-cp310-win_amd64.whl", hash = "sha256:ae1d8deb391ab39cc8f0d5844e588a115ae3717e607d91482023917f920f777f"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4670ce853cb25f72115a1bbe366ae13cf3f28fc5c87222df14f8d3d55d51816e"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cea7c4a3dfc2ca61f88a2b1ddd6b0bfbd116c9b1a361b3b66fd826034b833142"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f5784dfb2d45c19cde03c45c04a54bf47428610106197ed6e6fa79f33bc63d3"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b31ebde27575b3b0708673ec14f0c305c4564d995b545148ab7ac0f4d9b847a"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6b42913a0259267e9ee335da0c36498077799e59c5e332d506e72b4f32de781d"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6a3f8020e013e9b3b7941dcf20b0fc8f7429daaf7158760846731cbd8caa5e45"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-win32.whl", hash = "sha256:88ab245ed2c96265441ed2818977be28c840cfa5204ba167425d6c26eb67b7e7"}, + {file = "SQLAlchemy-2.0.15-cp311-cp311-win_amd64.whl", hash = "sha256:5cc48a7fda2b5c5b8860494d6c575db3a101a68416492105fed6591dc8a2728a"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f6fd3c88ea4b170d13527e93be1945e69facd917661d3725a63470eb683fbffe"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e885dacb167077df15af2f9ccdacbd7f5dd0d538a6d74b94074f2cefc7bb589"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:201a99f922ac8c780b3929128fbd9df901418877c70e160e19adb05665e51c31"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e17fdcb8971e77c439113642ca8861f9465e21fc693bd3916654ceef3ac26883"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db269f67ed17b07e80aaa8fba1f650c0d84aa0bdd9d5352e4ac38d5bf47ac568"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-win32.whl", hash = "sha256:994a75b197662e0608b6a76935d7c345f7fd874eac0b7093d561033db61b0e8c"}, + {file = "SQLAlchemy-2.0.15-cp37-cp37m-win_amd64.whl", hash = "sha256:4d61731a35eddb0f667774fe15e5a4831e444d066081d1e809e1b8a0e3f97cae"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f7f994a53c0e6b44a2966fd6bfc53e37d34b7dca34e75b6be295de6db598255e"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:79bfe728219239bdc493950ea4a4d15b02138ecb304771f9024d0d6f5f4e3706"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6320a1d175447dce63618ec997a53836de48ed3b44bbe952f0b4b399b19941"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f80a9c9a9af0e4bd5080cc0955ce70274c28e9b931ad7e0fb07021afcd32af6"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4a75fdb9a84072521bb2ebd31eefe1165d4dccea3039dda701a864f4b5daa17f"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:21c89044fc48a25c2184eba332edeffbbf9367913bb065cd31538235d828f06f"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-win32.whl", hash = "sha256:1a0754c2d9f0c7982bec0a31138e495ed1f6b8435d7e677c45be60ec18370acf"}, + {file = "SQLAlchemy-2.0.15-cp38-cp38-win_amd64.whl", hash = "sha256:bc5c2b0da46c26c5f73f700834f871d0723e1e882641932468d56833bab09775"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:670ecf74ee2e70b917028a06446ad26ff9b1195e84b09c3139c215123d57dc30"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d14282bf5b4de87f922db3c70858953fd081ef4f05dba6cca3dd705daffe1cc9"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:256b2b9660e51ad7055a9835b12717416cf7288afcf465107413917b6bb2316f"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810199d1c5b43603a9e815ae9487aef3ab1ade7ed9c0c485e12519358929fbfe"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:536c86ec81ca89291d533ff41a3a05f9e4e88e01906dcee0751fc7082f3e8d6c"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:435f6807fa6a0597d84741470f19db204a7d34625ea121abd63e8d95f673f0c4"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-win32.whl", hash = "sha256:da7381a883aee20b7d2ffda17d909b38134b6a625920e65239a1c681881df800"}, + {file = "SQLAlchemy-2.0.15-cp39-cp39-win_amd64.whl", hash = "sha256:788d1772fb8dcd12091ca82809eef504ce0f2c423e45284bc351b872966ff554"}, + {file = "SQLAlchemy-2.0.15-py3-none-any.whl", hash = "sha256:933d30273861fe61f014ce2a7e3c364915f5efe9ed250ec1066ca6ea5942c0bd"}, + {file = "SQLAlchemy-2.0.15.tar.gz", hash = "sha256:2e940a8659ef870ae10e0d9e2a6d5aaddf0ff6e91f7d0d7732afc9e8c4be9bbc"}, ] [package.dependencies] @@ -1070,7 +1156,7 @@ greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or pl typing-extensions = ">=4.2.0" [package.extras] -aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] @@ -1089,7 +1175,6 @@ postgresql-pg8000 = ["pg8000 (>=1.29.1)"] postgresql-psycopg = ["psycopg (>=3.0.7)"] postgresql-psycopg2binary = ["psycopg2-binary"] postgresql-psycopg2cffi = ["psycopg2cffi"] -postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3-binary"] @@ -1165,6 +1250,23 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "varname" +version = "0.10.0" +description = "Dark magics about variable names in python." +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "varname-0.10.0-py3-none-any.whl", hash = "sha256:20748d5cd3e125350726cd39d2cbd0e3000f30b3e0d3d5fe827efa0e71729809"}, + {file = "varname-0.10.0.tar.gz", hash = "sha256:045f7a409b3e91a760ab10a5539aabbb292db9d685f3011920b85fd4dbc5b9e3"}, +] + +[package.dependencies] +executing = ">=1.1,<2.0" + +[package.extras] +all = ["asttokens (>=2.0.0,<3.0.0)", "pure_eval (<1.0.0)"] + [[package]] name = "yarl" version = "1.9.2" @@ -1255,4 +1357,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "d49fbce1f8a118b01ca7d0b64928f6d5a051adc472e3f6ea68ac08400e182cfa" +content-hash = "b19c71878ee8d1123b1af3f0d565a823ce4e6f91efa005976489e96f2bffb287" diff --git a/pyproject.toml b/pyproject.toml index e3725ed..ffd4466 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,14 +10,18 @@ line-length = 98 name = "research-agi-agent" version = "0.1.0" description = "" -authors = ["James Hutchison ", "Muhammad Ibrahim Laeeq "] +authors = [ + "James Hutchison ", + "Muhammad Ibrahim Laeeq ", +] readme = "README.md" -packages = [{ include = "research_agi_agent" }] +packages = [{ include = "DeepResearchTool" }] [tool.poetry.dependencies] python = "^3.11" superagi-tools = "^1.0.7" langchain = "^0.0.268" +megamock = "^0.1.0b7" [tool.poetry.group.dev.dependencies] @@ -25,7 +29,11 @@ mypy = "^1.5.0" ruff = "^0.0.284" black = "^23.7.0" pytest = "^7.4.0" +pytest-dotenv = "^0.5.2" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/tests/test_deep_research_writer_tool.py b/tests/test_deep_research_writer_tool.py new file mode 100644 index 0000000..575251d --- /dev/null +++ b/tests/test_deep_research_writer_tool.py @@ -0,0 +1,56 @@ +import textwrap + +from megamock import Mega, MegaMock + +from DeepResearchTool.deep_research.summary_generator import SummaryGenerator + + +class TestDeepResearchWriterTool: + class TestGenerateMarkdownPrompt: + def test_is_dedented(self): + mock = MegaMock.it(SummaryGenerator) + mock._format_topic = lambda topic: f"formatted {topic}" + Mega(mock._generate_markdown_prompt).use_real_logic() + generated_prompt = mock._generate_markdown_prompt( + user_query="test query", + topics=[ + { + "name": "test topic", + "description": "test description", + "relevant_because": "test relevant because", + "notes_file": "test notes file", + } + ], + ) + + # ensure it is dedented + for line in generated_prompt.splitlines(): + # TODO: fails because logic isn't implemented + assert not line.startswith(" ") + + class TestFormatTopic: + def test_formats_topic(self): + mock = MegaMock.it(SummaryGenerator, spec_set=False) + mock._notes_getter = lambda notes_file: f"notes from {notes_file}" + Mega(mock._format_topic).use_real_logic() + formatted_topic = mock._format_topic( + { + "name": "test topic", + "description": "test description", + "relevant_because": "test relevant because", + "notes_file": "test_notes_file.json", + } + ) + + assert textwrap.dedent(formatted_topic) == textwrap.dedent( + """ + Topic name: test topic + Topic description: test description + Relevant because: test relevant because + Notes: notes from test_notes_file.json + """ + ) + + class TestGetMarkdownSummary: + def test_takes_prompt_and_plugs_it_into_chat(self) -> None: + pass # TODO