From 96c2761f0c31d8420185740382c8fd9129153340 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 20 Nov 2023 11:05:35 +0100 Subject: [PATCH 1/2] add comparison workflow example --- comparison.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++ ragna/core/_rag.py | 2 +- 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 comparison.py diff --git a/comparison.py b/comparison.py new file mode 100644 index 00000000..2d8f4600 --- /dev/null +++ b/comparison.py @@ -0,0 +1,68 @@ +import asyncio +import itertools +from pprint import pprint + +from tqdm import tqdm +from tqdm.asyncio import tqdm_asyncio + +from ragna import Rag, assistants, source_storages + +SOURCE_STORAGES = [ + source_storages.Chroma, + source_storages.LanceDB, +] +ASSISTANTS = [ + assistants.Gpt35Turbo16k, + assistants.Gpt4, +] + + +async def main(): + rag = Rag() + + # Pre-load all components to enable a fair comparison + for component in itertools.chain(SOURCE_STORAGES, ASSISTANTS): + rag._load_component(component) + + document = "ragna.txt" + with open(document, "w") as file: + file.write("Ragna is an open source RAG orchestration framework\n") + + prompt = "What is Ragna?" + + experiments = make_experiments(rag=rag, prompt=prompt, document=document) + pprint( + {name: await experiment for name, experiment in tqdm(experiments.items())}, + sort_dicts=False, + ) + + experiments = make_experiments(rag=rag, prompt=prompt, document=document) + pprint( + dict(zip(experiments.keys(), await tqdm_asyncio.gather(*experiments.values()))), + sort_dicts=False, + ) + + +def make_experiments(*, rag, document, prompt): + return { + f"{source_storage.display_name()} / {assistant.display_name()}": answer( + rag=rag, + documents=[document], + source_storage=source_storage, + assistant=assistant, + prompt=prompt, + ) + for source_storage, assistant in itertools.product(SOURCE_STORAGES, ASSISTANTS) + } + + +async def answer(*, rag, documents, source_storage, assistant, prompt): + chat = rag.chat( + documents=documents, source_storage=source_storage, assistant=assistant + ) + await chat.prepare() + return (await chat.answer(prompt)).content + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index b1af0739..45d26a67 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -45,7 +45,7 @@ def _load_component(self, component: Union[Type[C], C]) -> C: cls = component instance = None else: - raise RagnaException + raise RagnaException("Unknown component", component=component) if cls not in self._components: if instance is None: From 8432cfc67e2d8d71667c663d91ec394f6ab79e3b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 20 Nov 2023 11:41:11 +0100 Subject: [PATCH 2/2] add desc --- comparison.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comparison.py b/comparison.py index 2d8f4600..841bdd6f 100644 --- a/comparison.py +++ b/comparison.py @@ -32,13 +32,21 @@ async def main(): experiments = make_experiments(rag=rag, prompt=prompt, document=document) pprint( - {name: await experiment for name, experiment in tqdm(experiments.items())}, + { + name: await experiment + for name, experiment in tqdm(experiments.items(), desc="sync") + }, sort_dicts=False, ) experiments = make_experiments(rag=rag, prompt=prompt, document=document) pprint( - dict(zip(experiments.keys(), await tqdm_asyncio.gather(*experiments.values()))), + dict( + zip( + experiments.keys(), + await tqdm_asyncio.gather(*experiments.values(), desc="async"), + ) + ), sort_dicts=False, )