diff --git a/comparison.py b/comparison.py new file mode 100644 index 00000000..841bdd6f --- /dev/null +++ b/comparison.py @@ -0,0 +1,76 @@ +import asyncio +import itertools +from pprint import pprint + +from tqdm import tqdm +from tqdm.asyncio import tqdm_asyncio + +from ragna import Rag, assistants, source_storages + +SOURCE_STORAGES = [ + source_storages.Chroma, + source_storages.LanceDB, +] +ASSISTANTS = [ + assistants.Gpt35Turbo16k, + assistants.Gpt4, +] + + +async def main(): + rag = Rag() + + # Pre-load all components to enable a fair comparison + for component in itertools.chain(SOURCE_STORAGES, ASSISTANTS): + rag._load_component(component) + + document = "ragna.txt" + with open(document, "w") as file: + file.write("Ragna is an open source RAG orchestration framework\n") + + prompt = "What is Ragna?" + + experiments = make_experiments(rag=rag, prompt=prompt, document=document) + pprint( + { + name: await experiment + for name, experiment in tqdm(experiments.items(), desc="sync") + }, + sort_dicts=False, + ) + + experiments = make_experiments(rag=rag, prompt=prompt, document=document) + pprint( + dict( + zip( + experiments.keys(), + await tqdm_asyncio.gather(*experiments.values(), desc="async"), + ) + ), + sort_dicts=False, + ) + + +def make_experiments(*, rag, document, prompt): + return { + f"{source_storage.display_name()} / {assistant.display_name()}": answer( + rag=rag, + documents=[document], + source_storage=source_storage, + assistant=assistant, + prompt=prompt, + ) + for source_storage, assistant in itertools.product(SOURCE_STORAGES, ASSISTANTS) + } + + +async def answer(*, rag, documents, source_storage, assistant, prompt): + chat = rag.chat( + documents=documents, source_storage=source_storage, assistant=assistant + ) + await chat.prepare() + return (await chat.answer(prompt)).content + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index b1af0739..45d26a67 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -45,7 +45,7 @@ def _load_component(self, component: Union[Type[C], C]) -> C: cls = component instance = None else: - raise RagnaException + raise RagnaException("Unknown component", component=component) if cls not in self._components: if instance is None: