From 684dc35ae09572d390e5eb5c5c3ecd3aefd83f71 Mon Sep 17 00:00:00 2001 From: dillonroach Date: Thu, 13 Jun 2024 01:44:04 -0700 Subject: [PATCH 01/15] WIP draft of chat.generate() and worked example with exllamav2 in _exl2.py --- environment-dev-exl.yml | 43 ++++++++++++ ragna/assistants/__init__.py | 2 + ragna/assistants/_exl2.py | 123 +++++++++++++++++++++++++++++++++++ ragna/core/_components.py | 16 ++++- ragna/core/_rag.py | 20 ++++++ 5 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 environment-dev-exl.yml create mode 100644 ragna/assistants/_exl2.py diff --git a/environment-dev-exl.yml b/environment-dev-exl.yml new file mode 100644 index 00000000..1eb561be --- /dev/null +++ b/environment-dev-exl.yml @@ -0,0 +1,43 @@ +name: ragna-dev-exl +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python =3.11 + - pip + - git-lfs + - jupyterlab>=3 + - pandas + - numpy + - panel + - tokenizers + - pytorch=2.3.1 + - pytorch-cuda=12.1 + - cuda-nvcc + - rich + - pip: + - python-dotenv + - pytest >=6 + - pytest-mock + - pytest-asyncio + - pytest-playwright + - mypy ==1.10.0 + - pre-commit + - types-aiofiles + - sqlalchemy-stubs + - setuptools-scm + - pip-tools + # documentation + - mkdocs + - mkdocs-material + - mkdocstrings[python] + - mkdocs-gen-files + - material-plausible-plugin + - mkdocs-gallery >=0.10 + - mdx_truly_sane_lists + # exl2 + - ninja + - packaging + - exllamav2@https://github.com/turboderp/exllamav2/releases/download/v0.1.5/exllamav2-0.1.5+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py index bcf5ead6..412c7e82 100644 --- a/ragna/assistants/__init__.py +++ b/ragna/assistants/__init__.py @@ -18,8 +18,10 @@ "Jurassic2Ultra", "LlamafileAssistant", "RagnaDemoAssistant", + "Exl2Assistant", ] +from ._exl2 import Exl2Assistant from ._ai21labs import Jurassic2Ultra from ._anthropic import ClaudeHaiku, ClaudeOpus, ClaudeSonnet from ._cohere import Command, CommandLight diff --git a/ragna/assistants/_exl2.py b/ragna/assistants/_exl2.py new file mode 100644 index 00000000..1bbf8cf8 --- /dev/null +++ b/ragna/assistants/_exl2.py @@ -0,0 +1,123 @@ +import re +import textwrap +from typing import Iterator, Union, cast + +from ragna.core import Assistant, Source +from pathlib import Path + +from exllamav2 import( + ExLlamaV2, + ExLlamaV2Config, + ExLlamaV2Cache, + ExLlamaV2Tokenizer, + ExLlamaV2Cache_Q4, +) + +from exllamav2.generator import ( + ExLlamaV2BaseGenerator, + ExLlamaV2Sampler, + ExLlamaV2DynamicGenerator, + ExLlamaV2DynamicJob, + ExLlamaV2DynamicGeneratorAsync, + ExLlamaV2DynamicJobAsync, +) + +import time + +class Exl2Assistant(Assistant): + """Exl2Assistant - example to instantiate and run inference in process + """ + + @classmethod + def display_name(cls) -> str: + return "Ragna/Exl2Assistant" + + # TODO; known needs: - pytorch, pytorch-cuda, [cuda-nvcc, rich, ninja, packaging, flash-attn] for paged attention batching, exllamav2 + # @classmethod + # def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]: + # streaming_requirements: dict[HttpStreamingProtocol, list[Requirement]] = { + # HttpStreamingProtocol.SSE: [PackageRequirement("httpx_sse")], + # } + # return streaming_requirements.get(protocol, []) + + def __init__( + self, + ) -> None: + self._stream = False + self._paged = False + self._max_seq_length = 8192 + self._max_new_tokens = 512 + self._model_directory = "" + self._load() + + def _load(self): + self._config = 
ExLlamaV2Config(self._model_directory) + self._config.prepare() + self._config.max_seq_len = self._max_seq_length + self._model = ExLlamaV2(self._config) + self._cache = ExLlamaV2Cache_Q4(self._model, lazy = True, max_seq_len=self._config.max_seq_len) + self._model.load_autosplit(self._cache) + self._tokenizer = ExLlamaV2Tokenizer(self._config) + self.settings = ExLlamaV2Sampler.Settings() + self.settings.temperature = 0.35 + self.settings.top_k = 50 + self.settings.top_p = 0.8 + self.settings.token_repetition_penalty = 1.05 + + def _render_prompt(self, prompt: str) -> str: + system_prompt="You are an unbiased, helpful assistant." + texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] + texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') + return ''.join(texts) + + + async def generate(self, prompt: str) -> str: + full_prompt = self._render_prompt(prompt) + + self._generator = ExLlamaV2DynamicGenerator( + model = self._model, + cache = self._cache, + tokenizer = self._tokenizer, + gen_settings = self.settings, + paged = self._paged, + ) + outputs = self._generator.generate( + prompt = full_prompt, + max_new_tokens = self._max_new_tokens, + stop_conditions = ["","<|eot_id|>", self._tokenizer.eos_token_id], + completion_only = True, + ) + yield outputs + + + async def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: + full_prompt = self._render_prompt(prompt) + input_ids = self._tokenizer.encode(full_prompt) + #examples at https://github.com/turboderp/exllamav2/blob/master/examples/inference_stream.py + #and https://github.com/turboderp/exllamav2/blob/master/examples/inference_async.py + if self._stream: + self._generator = ExLlamaV2DynamicGeneratorAsync( + model = self._model, + cache = self._cache, + tokenizer = self._tokenizer, + paged = self._paged, + gen_settings = self.settings, + ) + job = ExLlamaV2DynamicJobAsync( + generator = self._generator, + input_ids = input_ids, + max_new_tokens = self._max_new_tokens, + token_healing = True, + stop_conditions = ["","<|eot_id|>", self._tokenizer.eos_token_id], + completion_only = True, + ) + async for result in job: + text_chunk = result.get("text", "") + if not result["eos"]: + yield cast(str, text_chunk) + await self._generator.close() + else: + yield [i async for i in self.generate(prompt)][0] + + + diff --git a/ragna/core/_components.py b/ragna/core/_components.py index d237c1b8..8bc390c2 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -210,7 +210,7 @@ def __repr__(self) -> str: class Assistant(Component, abc.ABC): """Abstract base class for assistants used in [ragna.core.Chat][]""" - __ragna_protocol_methods__ = ["answer"] + __ragna_protocol_methods__ = ["answer","generate"] @abc.abstractmethod def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: @@ -224,3 +224,17 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: Answer. """ ... + + @abc.abstractmethod + def generate(self, prompt: str) -> str: + #TODO + """Answer a prompt given some sources. + + Args: + prompt: Prompt to be answered. + sources: Sources to use when answering answer the prompt. + + Returns: + Answer. + """ + ... 
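For orientation, a minimal assistant written against this draft interface might look roughly like the sketch below. This is an illustrative sketch only, not part of the patch: the class name, the echo logic, and the way sources are folded into the prompt are placeholders; only the method shapes follow the abstract answer()/generate() pair defined above, and generate() is written as an async generator in the same style the concrete ragna assistants use.

    from typing import AsyncIterator

    from ragna.core import Assistant, Source


    class EchoAssistant(Assistant):
        """Toy assistant illustrating the draft answer()/generate() split."""

        @classmethod
        def display_name(cls) -> str:
            return "Demo/EchoAssistant"

        async def generate(self, prompt: str) -> AsyncIterator[str]:
            # One-off inference with no retrieval step; a real assistant would
            # call its model or API here instead of echoing the prompt back.
            yield f"echo: {prompt}"

        async def answer(self, prompt: str, sources: list[Source]) -> AsyncIterator[str]:
            # RAG entry point: fold the retrieved sources into the prompt, then
            # reuse generate() so both paths share one inference implementation.
            context = "\n\n".join(source.content for source in sources)
            async for chunk in self.generate(f"{context}\n\n{prompt}"):
                yield chunk

Once a chat has been prepared, the Chat.generate() helper added in the following _rag.py hunk forwards a one-off prompt (for example a query-rewriting or other pre-processing call) straight to the assistant's generate() method, bypassing retrieval.
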
diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 6cdff127..66dcb98f 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -232,6 +232,26 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: return answer + async def generate(self, *, prompt: str) -> str: + """Run Inference on assistant endpoint + + Returns: + Answer. + + Raises: + ragna.core.RagnaException: If chat is not + [`prepare`][ragna.core.Chat.prepare]d. + """ + if not self._prepared: + raise RagnaException( + "Chat is not prepared", + chat=self, + http_status_code=400, + detail=RagnaException.EVENT, + ) + + return self._run_gen(self.assistant.generate, prompt) + def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: documents_ = [] for document in documents: From 290f4c7741dba989266919c1f0a497d89f05f02c Mon Sep 17 00:00:00 2001 From: dillonroach Date: Sun, 7 Jul 2024 21:29:50 -0700 Subject: [PATCH 02/15] draft update assistants with generate() - needs chat history integrate to ingest messages, if appropriate. --- ragna/assistants/_ai21labs.py | 12 +++++++++--- ragna/assistants/_anthropic.py | 13 ++++++++++--- ragna/assistants/_cohere.py | 16 +++++++++++++--- ragna/assistants/_google.py | 14 ++++++++++---- ragna/assistants/_openai.py | 15 ++++++++++++--- 5 files changed, 54 insertions(+), 16 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 3e0c56b5..bd647338 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -22,8 +22,8 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + async def generate( + self, prompt: str, system_prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters @@ -46,10 +46,16 @@ async def answer( "role": "user", } ], - "system": self._make_system_content(sources), + "system": system_prompt, }, ): yield cast(str, data["outputs"][0]["text"]) + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + system_prompt = self._make_system_content(sources) + yield generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) # The Jurassic2Mid assistant receives a 500 internal service error from the remote diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index d74fc840..3e2bb234 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -36,8 +36,8 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: + "" ) - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + async def generate( + self, prompt: str, system_prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming @@ -52,7 +52,7 @@ async def answer( }, json={ "model": self._MODEL, - "system": self._instructize_system_prompt(sources), + "system": system, "messages": [{"role": "user", "content": prompt}], "max_tokens": max_new_tokens, "temperature": 0.0, @@ -68,6 +68,13 @@ async def answer( continue yield cast(str, data["delta"].pop("text")) + + async def answer( + self, prompt: str, sources: 
list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + system_prompt = self._instructize_system_prompt(sources) + yield self.generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) + class ClaudeOpus(AnthropicAssistant): diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 4108d31b..29987c9c 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -24,8 +24,8 @@ def _make_preamble(self) -> str: def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + async def generate( + self, prompt: str, system_prompt: str, source_documents: list[dict[str, str]], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat @@ -39,7 +39,7 @@ async def answer( "authorization": f"Bearer {self._api_key}", }, json={ - "preamble_override": self._make_preamble(), + "preamble_override": system_prompt, "message": prompt, "model": self._MODEL, "stream": True, @@ -55,6 +55,16 @@ async def answer( raise RagnaException(event["error_message"]) if "text" in event: yield cast(str, event["text"]) + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + # See https://docs.cohere.com/docs/cochat-beta + # See https://docs.cohere.com/reference/chat + # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag + system_prompt = self._make_preamble() + source_documents = self._make_source_documents(sources) + yield generate(prompt=prompt,system_prompt=system_prompt,source_documents=source_documents,max_new_tokens=max_new_tokens) class Command(CohereAssistant): diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 70c82936..0af50999 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -24,9 +24,9 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: *[f"\n{source.content}" for source in sources], ] ) - - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + + async def generate( + self, prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: async for chunk in self._call_api( "POST", @@ -35,7 +35,7 @@ async def answer( headers={"Content-Type": "application/json"}, json={ "contents": [ - {"parts": [{"text": self._instructize_prompt(prompt, sources)}]} + {"parts": [{"text": prompt}]} ], # https://ai.google.dev/docs/safety_setting_gemini "safetySettings": [ @@ -60,6 +60,12 @@ async def answer( ): yield chunk + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + expanded_prompt = self._instructize_prompt(prompt, sources) + yield generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens) class GeminiPro(GoogleAssistant): """[Google Gemini Pro](https://ai.google.dev/models/gemini) diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 0f51d6d9..1e965ccf 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -23,8 +23,8 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - def _stream( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + def generate( + 
self, prompt: str, system_prompt: str, *, max_new_tokens: int ) -> AsyncIterator[dict[str, Any]]: # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming @@ -38,7 +38,7 @@ def _stream( "messages": [ { "role": "system", - "content": self._make_system_content(sources), + "content": system_prompt, }, { "role": "user", @@ -53,6 +53,15 @@ def _stream( json_["model"] = self._MODEL return self._call_api("POST", self._url, headers=headers, json=json_) + + def _stream( + self, prompt: str, sources: list[Source], *, max_new_tokens: int + ) -> AsyncIterator[dict[str, Any]]: + # See https://platform.openai.com/docs/api-reference/chat/create + # and https://platform.openai.com/docs/api-reference/chat/streaming + system_prompt = self._make_system_content(sources) + + return generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) async def answer( self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 From 954c4afc58918b24e0a5e488929ef361b6d98f9f Mon Sep 17 00:00:00 2001 From: dillonroach Date: Sun, 21 Jul 2024 22:14:28 -0700 Subject: [PATCH 03/15] update generate pattern with _render_prompt in order to handle strings or messages --- ragna/assistants/_ai21labs.py | 18 ++++++++------- ragna/assistants/_anthropic.py | 13 ++++++++--- ragna/assistants/_cohere.py | 13 ++++++++--- ragna/assistants/_exl2.py | 41 +++++++++++++++++++++++++++++----- ragna/assistants/_google.py | 18 ++++++++++----- ragna/assistants/_openai.py | 23 +++++++++++++++++-- 6 files changed, 98 insertions(+), 28 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 3f00db86..32d80cfb 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast +from typing import AsyncIterator, cast, Union from ragna.core import Message, Source @@ -22,8 +22,15 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) + def _render_prompt(self, prompt: Union[str,list[Message]]) -> Union[str,list]: + if isinstance(prompt,str): + return [{"text": prompt, "role": "user",}] + else: + messages = [{"text":i["content"], "role":i["role"]} for i in prompt if i["role"] != "system"] + return messages + async def generate( - self, prompt: str, system_prompt: str, *, max_new_tokens: int = 256 + self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters @@ -41,12 +48,7 @@ async def generate( "numResults": 1, "temperature": 0.0, "maxTokens": max_new_tokens, - "messages": [ - { - "text": prompt, - "role": "user", - } - ], + "messages": _render_prompt(prompt), "system": system_prompt, }, ): diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 6beb5509..89f18e58 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast +from typing import AsyncIterator, cast, Union from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source @@ -36,8 +36,15 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: + "" ) + def _render_prompt(self, prompt: Union[str,list[Message]]) -> Union[str,list]: + if isinstance(prompt,str): + return 
[{"content": prompt, "role": "user",}] + else: + messages = [{"content":i["content"], "role":i["role"]} for i in prompt if i["role"] != "system"] + return messages + async def generate( - self, prompt: str, system_prompt: str, *, max_new_tokens: int = 256 + self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming @@ -54,7 +61,7 @@ async def generate( json={ "model": self._MODEL, "system": system, - "messages": [{"role": "user", "content": prompt}], + "messages": _render_prompt(prompt), "max_tokens": max_new_tokens, "temperature": 0.0, "stream": True, diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index f5fc7a97..e1ed46dc 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast +from typing import AsyncIterator, cast, Union from ragna.core import Message, RagnaException, Source @@ -24,8 +24,15 @@ def _make_preamble(self) -> str: def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] + def _render_prompt(self, prompt: Union[str,list[Message]]) -> str: + if isinstance(prompt,str): + return prompt + else: + messages = [i["content"] for i in prompt if i["role"] == "user"][-1] + return messages + async def generate( - self, prompt: str, system_prompt: str, source_documents: list[dict[str, str]], *, max_new_tokens: int = 256 + self, prompt: Union[str,list[Message]], system_prompt: str, source_documents: list[dict[str, str]], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat @@ -41,7 +48,7 @@ async def generate( }, json={ "preamble_override": system_prompt, - "message": prompt, + "message": _render_prompt(prompt), "model": self._MODEL, "stream": True, "temperature": 0.0, diff --git a/ragna/assistants/_exl2.py b/ragna/assistants/_exl2.py index 70b9a1e5..e4955675 100644 --- a/ragna/assistants/_exl2.py +++ b/ragna/assistants/_exl2.py @@ -64,14 +64,43 @@ def _load(self): self.settings.top_p = 0.8 self.settings.token_repetition_penalty = 1.05 - def _render_prompt(self, prompt: str) -> str: - system_prompt="You are an unbiased, helpful assistant." - texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] - texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') - return ''.join(texts) + def _render_prompt(self, prompt: Union[str,list[Message]]) -> str: + """ + Llama3 style prompt compile + + Currently Assuming Enums: + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + """ + if isinstance(prompt,str): + system_prompt="You are an unbiased, helpful assistant." + texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] + texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') + return ''.join(texts) + else: + system_prompt=[i['content'] for i in prompt if i['role'] == 'system'] + if len(system_prompt) == 0: + system_prompt="You are an unbiased, helpful assistant." 
+ else: + system_prompt = system_prompt[0] + texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] + for i in prompt: + if i['role'] == "user": + texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{i["content"]}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') + elif i['role'] == "assistant": + if i['content'][-10:] == '<|eot_id|>': + texts.append(f'{i["content"]}\n') + elif i['content'][-12:] == '<|eot_id|>\n': + texts.append(f'{i["content"]}') + else: + texts.append(f'{i["content"]}<|eot_id|>\n') + else: + pass + return ''.join(texts) - async def generate(self, prompt: str) -> str: + async def generate(self, prompt: Union[str,list[Message]]) -> str: full_prompt = self._render_prompt(prompt) self._generator = ExLlamaV2DynamicGenerator( diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 28b2b962..bebd0e61 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator +from typing import AsyncIterator, Union from ragna.core import Message, Source @@ -24,9 +24,17 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: *[f"\n{source.content}" for source in sources], ] ) - + def _render_prompt(self, prompt: Union[str,list[Message]]) -> list[dict]: + #need to verify against https://ai.google.dev/api/generate-content#chat_1 + role_mapping = {"user":"user","assistant":"model"} + if isinstance(prompt,str): + return [{"parts": [{"text": prompt}]}] + else: + messages = [{"parts":[{"text":i["content"]}], "role":role_mapping[i["role"]]} for i in prompt if i["role"] != "system"] + return messages + async def generate( - self, prompt: str, *, max_new_tokens: int = 256 + self, prompt: Union[str,list[Message]], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: async for chunk in self._call_api( "POST", @@ -34,9 +42,7 @@ async def generate( params={"key": self._api_key}, headers={"Content-Type": "application/json"}, json={ - "contents": [ - {"parts": [{"text": prompt}]} - ], + "contents": _render_prompt(prompt), # https://ai.google.dev/docs/safety_setting_gemini "safetySettings": [ { diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 6d234653..19ed2b25 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,6 +1,6 @@ import abc from functools import cached_property -from typing import Any, AsyncIterator, Optional, cast +from typing import Any, AsyncIterator, Optional, cast, Union from ragna.core import Message, Source @@ -23,8 +23,27 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) + def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) -> list[dict]: + #need to verify against https://ai.google.dev/api/generate-content#chat_1 + if isinstance(prompt,str): + messages = [ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": prompt, + }, + ] + return messages + else: + system_message = [{"role":"system", "content":system_prompt}] + messages = [{"role":i["role"],"content":i["content"]} for i in prompt if i["role"] != "system"] + return system_message.extend(messages) + async def generate( - self, prompt: str, system_prompt: str, *, max_new_tokens: int + self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int ) -> AsyncIterator[dict[str, Any]]: # See https://platform.openai.com/docs/api-reference/chat/create 
# and https://platform.openai.com/docs/api-reference/chat/streaming From 82637dda6450184ec411ce55d0442dc87ffb6496 Mon Sep 17 00:00:00 2001 From: dillonroach Date: Sun, 11 Aug 2024 22:17:48 -0700 Subject: [PATCH 04/15] linting, docstrings, minor cleanup --- ragna/assistants/_ai21labs.py | 49 ++++++++++-- ragna/assistants/_anthropic.py | 52 ++++++++++--- ragna/assistants/_cohere.py | 41 ++++++++-- ragna/assistants/_exl2.py | 135 ++++++++++++++++++--------------- ragna/assistants/_google.py | 34 ++++++--- ragna/assistants/_openai.py | 19 ++++- 6 files changed, 234 insertions(+), 96 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 32d80cfb..0f343762 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast, Union +from typing import AsyncIterator, Union, cast from ragna.core import Message, Source @@ -22,16 +22,47 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - def _render_prompt(self, prompt: Union[str,list[Message]]) -> Union[str,list]: - if isinstance(prompt,str): - return [{"text": prompt, "role": "user",}] + def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: + """ + Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. + + Returns: + ordered list of dicts with 'text' and 'role' keys + """ + if isinstance(prompt, str): + return [ + { + "text": prompt, + "role": "user", + } + ] else: - messages = [{"text":i["content"], "role":i["role"]} for i in prompt if i["role"] != "system"] + messages = [ + {"text": i["content"], "role": i["role"]} + for i in prompt + if i["role"] != "system" + ] return messages async def generate( - self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 + self, + prompt: Union[str, list[Message]], + system_prompt: str, + *, + max_new_tokens: int = 256, ) -> AsyncIterator[str]: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
+ + Args: + prompt: Either a single prompt string or a list of ragna messages + system_prompt: System prompt string + max_new_tokens: Max number of completion tokens (default 256_ + + Returns: + async streamed inference response string chunks + """ # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response @@ -53,13 +84,15 @@ async def generate( }, ): yield cast(str, data["outputs"][0]["text"]) - + async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: prompt, sources = (message := messages[-1]).content, message.sources system_prompt = self._make_system_content(sources) - yield generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) + yield generate( + prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens + ) # The Jurassic2Mid assistant receives a 500 internal service error from the remote diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 89f18e58..9979300a 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast, Union +from typing import AsyncIterator, Union, cast from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source @@ -36,16 +36,47 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: + "" ) - def _render_prompt(self, prompt: Union[str,list[Message]]) -> Union[str,list]: - if isinstance(prompt,str): - return [{"content": prompt, "role": "user",}] + def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: + """ + Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. + + Returns: + ordered list of dicts with 'content' and 'role' keys + """ + if isinstance(prompt, str): + return [ + { + "content": prompt, + "role": "user", + } + ] else: - messages = [{"content":i["content"], "role":i["role"]} for i in prompt if i["role"] != "system"] + messages = [ + {"content": i["content"], "role": i["role"]} + for i in prompt + if i["role"] != "system" + ] return messages - + async def generate( - self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 + self, + prompt: Union[str, list[Message]], + system_prompt: str, + *, + max_new_tokens: int = 256, ) -> AsyncIterator[str]: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
+ + Args: + prompt: Either a single prompt string or a list of ragna messages + system_prompt: System prompt string + max_new_tokens: Max number of completion tokens (default 256) + + Returns: + async streamed inference response string chunks + """ # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming @@ -76,14 +107,15 @@ async def generate( continue yield cast(str, data["delta"].pop("text")) - + async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: prompt, sources = (message := messages[-1]).content, message.sources system_prompt = self._instructize_system_prompt(sources) - yield self.generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) - + yield self.generate( + prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens + ) class ClaudeOpus(AnthropicAssistant): diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index e1ed46dc..83c55500 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, cast, Union +from typing import AsyncIterator, Union, cast from ragna.core import Message, RagnaException, Source @@ -24,16 +24,40 @@ def _make_preamble(self) -> str: def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] - def _render_prompt(self, prompt: Union[str,list[Message]]) -> str: - if isinstance(prompt,str): + def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: + """ + Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. + + Returns: + prompt string + """ + if isinstance(prompt, str): return prompt else: messages = [i["content"] for i in prompt if i["role"] == "user"][-1] return messages async def generate( - self, prompt: Union[str,list[Message]], system_prompt: str, source_documents: list[dict[str, str]], *, max_new_tokens: int = 256 + self, + prompt: Union[str, list[Message]], + system_prompt: str, + source_documents: list[dict[str, str]], + *, + max_new_tokens: int = 256, ) -> AsyncIterator[str]: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
+ + Args: + prompt: Either a single prompt string or a list of ragna messages + system_prompt: System prompt string + source_documents: List of source content dicts with 'title' and 'snippet' keys + max_new_tokens: Max number of completion tokens (default 256) + + Returns: + async streamed inference response string chunks + """ # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag @@ -63,7 +87,7 @@ async def generate( raise RagnaException(event["error_message"]) if "text" in event: yield cast(str, event["text"]) - + async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: @@ -73,7 +97,12 @@ async def answer( prompt, sources = (message := messages[-1]).content, message.sources system_prompt = self._make_preamble() source_documents = self._make_source_documents(sources) - yield generate(prompt=prompt,system_prompt=system_prompt,source_documents=source_documents,max_new_tokens=max_new_tokens) + yield generate( + prompt=prompt, + system_prompt=system_prompt, + source_documents=source_documents, + max_new_tokens=max_new_tokens, + ) class Command(CohereAssistant): diff --git a/ragna/assistants/_exl2.py b/ragna/assistants/_exl2.py index e4955675..0090eaaa 100644 --- a/ragna/assistants/_exl2.py +++ b/ragna/assistants/_exl2.py @@ -1,32 +1,23 @@ -import re -import textwrap -from typing import Iterator, Union, cast +from typing import Union, cast -from ragna.core import Assistant, Source -from pathlib import Path - -from exllamav2 import( +from exllamav2 import ( ExLlamaV2, + ExLlamaV2Cache_Q4, ExLlamaV2Config, - ExLlamaV2Cache, ExLlamaV2Tokenizer, - ExLlamaV2Cache_Q4, ) - from exllamav2.generator import ( - ExLlamaV2BaseGenerator, - ExLlamaV2Sampler, ExLlamaV2DynamicGenerator, - ExLlamaV2DynamicJob, ExLlamaV2DynamicGeneratorAsync, ExLlamaV2DynamicJobAsync, + ExLlamaV2Sampler, ) -import time +from ragna.core import Assistant + class Exl2Assistant(Assistant): - """Exl2Assistant - example to instantiate and run inference in process - """ + """Exl2Assistant - example to instantiate and run inference in process""" @classmethod def display_name(cls) -> str: @@ -49,13 +40,15 @@ def __init__( self._max_new_tokens = 512 self._model_directory = "" self._load() - + def _load(self): self._config = ExLlamaV2Config(self._model_directory) self._config.prepare() self._config.max_seq_len = self._max_seq_length self._model = ExLlamaV2(self._config) - self._cache = ExLlamaV2Cache_Q4(self._model, lazy = True, max_seq_len=self._config.max_seq_len) + self._cache = ExLlamaV2Cache_Q4( + self._model, lazy=True, max_seq_len=self._config.max_seq_len + ) self._model.load_autosplit(self._cache) self._tokenizer = ExLlamaV2Tokenizer(self._config) self.settings = ExLlamaV2Sampler.Settings() @@ -64,82 +57,103 @@ def _load(self): self.settings.top_p = 0.8 self.settings.token_repetition_penalty = 1.05 - def _render_prompt(self, prompt: Union[str,list[Message]]) -> str: + def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: """ + Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. + + Returns: + Single string containing full rendered chat history with formatting tokens + Llama3 style prompt compile - + Currently Assuming Enums: SYSTEM = "system" USER = "user" ASSISTANT = "assistant" """ - if isinstance(prompt,str): - system_prompt="You are an unbiased, helpful assistant." 
- texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] - texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') - return ''.join(texts) + if isinstance(prompt, str): + system_prompt = "You are an unbiased, helpful assistant." + texts = [ + f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n" + ] + texts.append( + f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>" + ) + return "".join(texts) else: - system_prompt=[i['content'] for i in prompt if i['role'] == 'system'] + system_prompt = [i["content"] for i in prompt if i["role"] == "system"] if len(system_prompt) == 0: - system_prompt="You are an unbiased, helpful assistant." + system_prompt = "You are an unbiased, helpful assistant." else: system_prompt = system_prompt[0] - texts = [f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n"] + texts = [ + f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n" + ] for i in prompt: - if i['role'] == "user": - texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{i["content"]}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>') - elif i['role'] == "assistant": - if i['content'][-10:] == '<|eot_id|>': + if i["role"] == "user": + texts.append( + f'<|start_header_id|>user<|end_header_id|>\n\n{i["content"]}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>' + ) + elif i["role"] == "assistant": + if i["content"][-10:] == "<|eot_id|>": texts.append(f'{i["content"]}\n') - elif i['content'][-12:] == '<|eot_id|>\n': + elif i["content"][-12:] == "<|eot_id|>\n": texts.append(f'{i["content"]}') else: texts.append(f'{i["content"]}<|eot_id|>\n') else: pass - return ''.join(texts) + return "".join(texts) + + async def generate(self, prompt: Union[str, list[Message]]) -> str: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
+ Args: + prompt: Either a single prompt string or a list of ragna messages - async def generate(self, prompt: Union[str,list[Message]]) -> str: + Returns: + async streamed inference response string chunks + """ full_prompt = self._render_prompt(prompt) self._generator = ExLlamaV2DynamicGenerator( - model = self._model, - cache = self._cache, - tokenizer = self._tokenizer, - gen_settings = self.settings, - paged = self._paged, + model=self._model, + cache=self._cache, + tokenizer=self._tokenizer, + gen_settings=self.settings, + paged=self._paged, ) outputs = self._generator.generate( - prompt = full_prompt, - max_new_tokens = self._max_new_tokens, - stop_conditions = ["","<|eot_id|>", self._tokenizer.eos_token_id], - completion_only = True, + prompt=full_prompt, + max_new_tokens=self._max_new_tokens, + stop_conditions=["", "<|eot_id|>", self._tokenizer.eos_token_id], + completion_only=True, ) yield outputs - - + async def answer(self, messages: list[Message]) -> AsyncIterator[str]: prompt, sources = (message := messages[-1]).content, message.sources full_prompt = self._render_prompt(prompt) input_ids = self._tokenizer.encode(full_prompt) - #examples at https://github.com/turboderp/exllamav2/blob/master/examples/inference_stream.py - #and https://github.com/turboderp/exllamav2/blob/master/examples/inference_async.py + # examples at https://github.com/turboderp/exllamav2/blob/master/examples/inference_stream.py + # and https://github.com/turboderp/exllamav2/blob/master/examples/inference_async.py if self._stream: self._generator = ExLlamaV2DynamicGeneratorAsync( - model = self._model, - cache = self._cache, - tokenizer = self._tokenizer, - paged = self._paged, - gen_settings = self.settings, + model=self._model, + cache=self._cache, + tokenizer=self._tokenizer, + paged=self._paged, + gen_settings=self.settings, ) job = ExLlamaV2DynamicJobAsync( - generator = self._generator, - input_ids = input_ids, - max_new_tokens = self._max_new_tokens, - token_healing = True, - stop_conditions = ["","<|eot_id|>", self._tokenizer.eos_token_id], - completion_only = True, + generator=self._generator, + input_ids=input_ids, + max_new_tokens=self._max_new_tokens, + token_healing=True, + stop_conditions=["", "<|eot_id|>", self._tokenizer.eos_token_id], + completion_only=True, ) async for result in job: text_chunk = result.get("text", "") @@ -148,6 +162,3 @@ async def answer(self, messages: list[Message]) -> AsyncIterator[str]: await self._generator.close() else: yield [i async for i in self.generate(prompt)][0] - - - diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index bebd0e61..88893a99 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -24,18 +24,34 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: *[f"\n{source.content}" for source in sources], ] ) - def _render_prompt(self, prompt: Union[str,list[Message]]) -> list[dict]: - #need to verify against https://ai.google.dev/api/generate-content#chat_1 - role_mapping = {"user":"user","assistant":"model"} - if isinstance(prompt,str): - return [{"parts": [{"text": prompt}]}] + + def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: + # need to verify against https://ai.google.dev/api/generate-content#chat_1 + role_mapping = {"user": "user", "assistant": "model"} + if isinstance(prompt, str): + return [{"role": "user", "parts": [{"text": prompt}]}] else: - messages = [{"parts":[{"text":i["content"]}], "role":role_mapping[i["role"]]} for i in prompt if i["role"] != 
"system"] + messages = [ + {"parts": [{"text": i["content"]}], "role": role_mapping[i["role"]]} + for i in prompt + if i["role"] != "system" + ] return messages - + async def generate( - self, prompt: Union[str,list[Message]], *, max_new_tokens: int = 256 + self, prompt: Union[str, list[Message]], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. + + Args: + prompt: Either a single prompt string or a list of ragna messages + max_new_tokens: Max number of completion tokens (default 256) + + Returns: + async streamed inference response string chunks + """ async for chunk in self._call_api( "POST", f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent", @@ -66,7 +82,6 @@ async def generate( ): yield chunk - async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: @@ -74,6 +89,7 @@ async def answer( expanded_prompt = self._instructize_prompt(prompt, sources) yield generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens) + class GeminiPro(GoogleAssistant): """[Google Gemini Pro](https://ai.google.dev/models/gemini) diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 19ed2b25..3ce5bf57 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -24,7 +24,12 @@ def _make_system_content(self, sources: list[Source]) -> str: return instruction + "\n\n".join(source.content for source in sources) def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) -> list[dict]: - #need to verify against https://ai.google.dev/api/generate-content#chat_1 + """ + Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. + + Returns: + ordered list of dicts with 'content' and 'role' keys + """ if isinstance(prompt,str): messages = [ { @@ -45,6 +50,18 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) - async def generate( self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int ) -> AsyncIterator[dict[str, Any]]: + """ + Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() + This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
+ + Args: + prompt: Either a single prompt string or a list of ragna messages + system_prompt: System prompt string + max_new_tokens: Max number of completion tokens (default 256) + + Returns: + yield call to self._call_api with formatted headers and json + """ # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming headers = { From a11919e38a5cbcd4bb81f885264f804df01c764b Mon Sep 17 00:00:00 2001 From: dillonroach Date: Mon, 19 Aug 2024 10:55:58 -0700 Subject: [PATCH 05/15] addressing comments, first pass --- ragna/assistants/_anthropic.py | 2 +- ragna/assistants/_openai.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 9979300a..43e71f72 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -36,7 +36,7 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: + "" ) - def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: + def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: """ Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 3ce5bf57..3615ff1c 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -48,7 +48,7 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) - return system_message.extend(messages) async def generate( - self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int + self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() From ae59aee93f8cb9f52e495878621d588d466cdd2d Mon Sep 17 00:00:00 2001 From: dillonroach Date: Mon, 19 Aug 2024 11:09:51 -0700 Subject: [PATCH 06/15] disentangle exl2 from the generate PR - will reopen in another, specific, PR --- environment-dev-exl.yml | 43 --------- ragna/assistants/__init__.py | 2 - ragna/assistants/_exl2.py | 164 ----------------------------------- 3 files changed, 209 deletions(-) delete mode 100644 environment-dev-exl.yml delete mode 100644 ragna/assistants/_exl2.py diff --git a/environment-dev-exl.yml b/environment-dev-exl.yml deleted file mode 100644 index 1eb561be..00000000 --- a/environment-dev-exl.yml +++ /dev/null @@ -1,43 +0,0 @@ -name: ragna-dev-exl -channels: - - pytorch - - nvidia - - conda-forge - - defaults -dependencies: - - python =3.11 - - pip - - git-lfs - - jupyterlab>=3 - - pandas - - numpy - - panel - - tokenizers - - pytorch=2.3.1 - - pytorch-cuda=12.1 - - cuda-nvcc - - rich - - pip: - - python-dotenv - - pytest >=6 - - pytest-mock - - pytest-asyncio - - pytest-playwright - - mypy ==1.10.0 - - pre-commit - - types-aiofiles - - sqlalchemy-stubs - - setuptools-scm - - pip-tools - # documentation - - mkdocs - - mkdocs-material - - mkdocstrings[python] - - mkdocs-gen-files - - material-plausible-plugin - - mkdocs-gallery >=0.10 - - mdx_truly_sane_lists - # exl2 - - ninja - - packaging - - exllamav2@https://github.com/turboderp/exllamav2/releases/download/v0.1.5/exllamav2-0.1.5+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py index 412c7e82..bcf5ead6 100644 --- a/ragna/assistants/__init__.py 
+++ b/ragna/assistants/__init__.py @@ -18,10 +18,8 @@ "Jurassic2Ultra", "LlamafileAssistant", "RagnaDemoAssistant", - "Exl2Assistant", ] -from ._exl2 import Exl2Assistant from ._ai21labs import Jurassic2Ultra from ._anthropic import ClaudeHaiku, ClaudeOpus, ClaudeSonnet from ._cohere import Command, CommandLight diff --git a/ragna/assistants/_exl2.py b/ragna/assistants/_exl2.py deleted file mode 100644 index 0090eaaa..00000000 --- a/ragna/assistants/_exl2.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Union, cast - -from exllamav2 import ( - ExLlamaV2, - ExLlamaV2Cache_Q4, - ExLlamaV2Config, - ExLlamaV2Tokenizer, -) -from exllamav2.generator import ( - ExLlamaV2DynamicGenerator, - ExLlamaV2DynamicGeneratorAsync, - ExLlamaV2DynamicJobAsync, - ExLlamaV2Sampler, -) - -from ragna.core import Assistant - - -class Exl2Assistant(Assistant): - """Exl2Assistant - example to instantiate and run inference in process""" - - @classmethod - def display_name(cls) -> str: - return "Ragna/Exl2Assistant" - - # TODO; known needs: - pytorch, pytorch-cuda, [cuda-nvcc, rich, ninja, packaging, flash-attn] for paged attention batching, exllamav2 - # @classmethod - # def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]: - # streaming_requirements: dict[HttpStreamingProtocol, list[Requirement]] = { - # HttpStreamingProtocol.SSE: [PackageRequirement("httpx_sse")], - # } - # return streaming_requirements.get(protocol, []) - - def __init__( - self, - ) -> None: - self._stream = False - self._paged = False - self._max_seq_length = 8192 - self._max_new_tokens = 512 - self._model_directory = "" - self._load() - - def _load(self): - self._config = ExLlamaV2Config(self._model_directory) - self._config.prepare() - self._config.max_seq_len = self._max_seq_length - self._model = ExLlamaV2(self._config) - self._cache = ExLlamaV2Cache_Q4( - self._model, lazy=True, max_seq_len=self._config.max_seq_len - ) - self._model.load_autosplit(self._cache) - self._tokenizer = ExLlamaV2Tokenizer(self._config) - self.settings = ExLlamaV2Sampler.Settings() - self.settings.temperature = 0.35 - self.settings.top_k = 50 - self.settings.top_p = 0.8 - self.settings.token_repetition_penalty = 1.05 - - def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: - """ - Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. - - Returns: - Single string containing full rendered chat history with formatting tokens - - Llama3 style prompt compile - - Currently Assuming Enums: - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - """ - if isinstance(prompt, str): - system_prompt = "You are an unbiased, helpful assistant." - texts = [ - f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n" - ] - texts.append( - f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>" - ) - return "".join(texts) - else: - system_prompt = [i["content"] for i in prompt if i["role"] == "system"] - if len(system_prompt) == 0: - system_prompt = "You are an unbiased, helpful assistant." 
- else: - system_prompt = system_prompt[0] - texts = [ - f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>\n" - ] - for i in prompt: - if i["role"] == "user": - texts.append( - f'<|start_header_id|>user<|end_header_id|>\n\n{i["content"]}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>' - ) - elif i["role"] == "assistant": - if i["content"][-10:] == "<|eot_id|>": - texts.append(f'{i["content"]}\n') - elif i["content"][-12:] == "<|eot_id|>\n": - texts.append(f'{i["content"]}') - else: - texts.append(f'{i["content"]}<|eot_id|>\n') - else: - pass - return "".join(texts) - - async def generate(self, prompt: Union[str, list[Message]]) -> str: - """ - Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() - This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. - - Args: - prompt: Either a single prompt string or a list of ragna messages - - Returns: - async streamed inference response string chunks - """ - full_prompt = self._render_prompt(prompt) - - self._generator = ExLlamaV2DynamicGenerator( - model=self._model, - cache=self._cache, - tokenizer=self._tokenizer, - gen_settings=self.settings, - paged=self._paged, - ) - outputs = self._generator.generate( - prompt=full_prompt, - max_new_tokens=self._max_new_tokens, - stop_conditions=["", "<|eot_id|>", self._tokenizer.eos_token_id], - completion_only=True, - ) - yield outputs - - async def answer(self, messages: list[Message]) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - full_prompt = self._render_prompt(prompt) - input_ids = self._tokenizer.encode(full_prompt) - # examples at https://github.com/turboderp/exllamav2/blob/master/examples/inference_stream.py - # and https://github.com/turboderp/exllamav2/blob/master/examples/inference_async.py - if self._stream: - self._generator = ExLlamaV2DynamicGeneratorAsync( - model=self._model, - cache=self._cache, - tokenizer=self._tokenizer, - paged=self._paged, - gen_settings=self.settings, - ) - job = ExLlamaV2DynamicJobAsync( - generator=self._generator, - input_ids=input_ids, - max_new_tokens=self._max_new_tokens, - token_healing=True, - stop_conditions=["", "<|eot_id|>", self._tokenizer.eos_token_id], - completion_only=True, - ) - async for result in job: - text_chunk = result.get("text", "") - if not result["eos"]: - yield cast(str, text_chunk) - await self._generator.close() - else: - yield [i async for i in self.generate(prompt)][0] From 986b34de197b029fab1bd2c944dc12cee7666d67 Mon Sep 17 00:00:00 2001 From: dillonroach Date: Mon, 26 Aug 2024 10:55:50 -0700 Subject: [PATCH 07/15] updated assistant dtype logic pattern - google will not work here, needs merge with corpus-dev --- ragna/assistants/_ai21labs.py | 21 +++++++++------------ ragna/assistants/_anthropic.py | 21 +++++++++------------ ragna/assistants/_cohere.py | 8 +++++--- ragna/assistants/_google.py | 15 ++++++++------- ragna/assistants/_openai.py | 21 ++++++--------------- ragna/core/_components.py | 16 +--------------- ragna/core/_rag.py | 20 -------------------- 7 files changed, 38 insertions(+), 84 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 0f343762..a32be7da 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -30,19 +30,16 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: ordered list of dicts with 
'text' and 'role' keys """ if isinstance(prompt, str): - return [ - { - "text": prompt, - "role": "user", - } - ] + messages = [Message(content=prompt, role=MessageRole.USER)] else: - messages = [ - {"text": i["content"], "role": i["role"]} - for i in prompt - if i["role"] != "system" - ] - return messages + messages = prompt + + messages = [ + {"text": i["content"], "role": i["role"]} + for i in messages + if i["role"] != "system" + ] + return messages async def generate( self, diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 43e71f72..e8b43dd7 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -44,19 +44,16 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: ordered list of dicts with 'content' and 'role' keys """ if isinstance(prompt, str): - return [ - { - "content": prompt, - "role": "user", - } - ] + messages = [Message(content=prompt, role=MessageRole.USER)] else: - messages = [ - {"content": i["content"], "role": i["role"]} - for i in prompt - if i["role"] != "system" - ] - return messages + messages = prompt + + messages = [ + {"content": i["content"], "role": i["role"]} + for i in messages + if i["role"] != "system" + ] + return messages async def generate( self, diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 83c55500..13842105 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -32,10 +32,12 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: prompt string """ if isinstance(prompt, str): - return prompt + messages = [Message(content=prompt, role=MessageRole.USER)] else: - messages = [i["content"] for i in prompt if i["role"] == "user"][-1] - return messages + messages = prompt + + messages = [i["content"] for i in messages if i["role"] == "user"][-1] + return messages async def generate( self, diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 88893a99..70f851fc 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -29,14 +29,15 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: # need to verify against https://ai.google.dev/api/generate-content#chat_1 role_mapping = {"user": "user", "assistant": "model"} if isinstance(prompt, str): - return [{"role": "user", "parts": [{"text": prompt}]}] + messages = [Message(content=prompt, role=MessageRole.USER)] else: - messages = [ - {"parts": [{"text": i["content"]}], "role": role_mapping[i["role"]]} - for i in prompt - if i["role"] != "system" - ] - return messages + messages = prompt + messages = [ + {"parts": [{"text": i["content"]}], "role": role_mapping[i["role"]]} + for i in messages + if i["role"] != "system" + ] + return messages async def generate( self, prompt: Union[str, list[Message]], *, max_new_tokens: int = 256 diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 3615ff1c..a847f3bb 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -30,22 +30,13 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) - Returns: ordered list of dicts with 'content' and 'role' keys """ - if isinstance(prompt,str): - messages = [ - { - "role": "system", - "content": system_prompt, - }, - { - "role": "user", - "content": prompt, - }, - ] - return messages + if isinstance(prompt, str): + messages = [Message(content=prompt, role=MessageRole.USER)] else: - system_message = [{"role":"system", "content":system_prompt}] - messages = 
[{"role":i["role"],"content":i["content"]} for i in prompt if i["role"] != "system"] - return system_message.extend(messages) + messages = prompt + system_message = [{"role":"system", "content":system_prompt}] + messages = [{"role":i["role"],"content":i["content"]} for i in prompt if i["role"] != "system"] + return system_message.extend(messages) async def generate( self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 diff --git a/ragna/core/_components.py b/ragna/core/_components.py index e4382743..62ee6df9 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -235,7 +235,7 @@ def __repr__(self) -> str: class Assistant(Component, abc.ABC): """Abstract base class for assistants used in [ragna.core.Chat][]""" - __ragna_protocol_methods__ = ["answer","generate"] + __ragna_protocol_methods__ = ["answer", "generate"] @abc.abstractmethod def answer(self, messages: list[Message]) -> Iterator[str]: @@ -249,17 +249,3 @@ def answer(self, messages: list[Message]) -> Iterator[str]: Answer. """ ... - - @abc.abstractmethod - def generate(self, prompt: str) -> str: - #TODO - """Answer a prompt given some sources. - - Args: - prompt: Prompt to be answered. - sources: Sources to use when answering answer the prompt. - - Returns: - Answer. - """ - ... diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 7b76486e..15154ea2 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -237,26 +237,6 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: return answer - async def generate(self, *, prompt: str) -> str: - """Run Inference on assistant endpoint - - Returns: - Answer. - - Raises: - ragna.core.RagnaException: If chat is not - [`prepare`][ragna.core.Chat.prepare]d. - """ - if not self._prepared: - raise RagnaException( - "Chat is not prepared", - chat=self, - http_status_code=400, - detail=RagnaException.EVENT, - ) - - return self._run_gen(self.assistant.generate, prompt) - def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: documents_ = [] for document in documents: From 084b28727d7e1eca0a780e247b1d4d00050a94ec Mon Sep 17 00:00:00 2001 From: dillonroach Date: Mon, 26 Aug 2024 11:00:24 -0700 Subject: [PATCH 08/15] revert precommit modify to components --- ragna/core/_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/core/_components.py b/ragna/core/_components.py index 62ee6df9..bff49790 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -235,7 +235,7 @@ def __repr__(self) -> str: class Assistant(Component, abc.ABC): """Abstract base class for assistants used in [ragna.core.Chat][]""" - __ragna_protocol_methods__ = ["answer", "generate"] + __ragna_protocol_methods__ = ["answer"] @abc.abstractmethod def answer(self, messages: list[Message]) -> Iterator[str]: From 2e8ae2cf5ad0fb0d653f163d4dfc92e8b97fc28d Mon Sep 17 00:00:00 2001 From: dillonroach Date: Thu, 5 Sep 2024 12:31:23 -0700 Subject: [PATCH 09/15] get default system_prompt strings on each generate call --- ragna/assistants/_ai21labs.py | 2 +- ragna/assistants/_anthropic.py | 2 +- ragna/assistants/_cohere.py | 2 +- ragna/assistants/_openai.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index a32be7da..f448b582 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -44,8 +44,8 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: async def 
generate( self, prompt: Union[str, list[Message]], - system_prompt: str, *, + system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, ) -> AsyncIterator[str]: """ diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index e8b43dd7..ab54b747 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -58,8 +58,8 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: async def generate( self, prompt: Union[str, list[Message]], - system_prompt: str, *, + system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, ) -> AsyncIterator[str]: """ diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 13842105..73cb7e40 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -42,9 +42,9 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: async def generate( self, prompt: Union[str, list[Message]], - system_prompt: str, source_documents: list[dict[str, str]], *, + system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, ) -> AsyncIterator[str]: """ diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index a847f3bb..a407222f 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -39,7 +39,7 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) - return system_message.extend(messages) async def generate( - self, prompt: Union[str,list[Message]], system_prompt: str, *, max_new_tokens: int = 256 + self, prompt: Union[str,list[Message]], *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256 ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() From c9e1a7e338740b36a953fb9204db00b222740eee Mon Sep 17 00:00:00 2001 From: dillonroach Date: Thu, 5 Sep 2024 14:57:49 -0700 Subject: [PATCH 10/15] fix pre-commit issues --- ragna/assistants/_ai21labs.py | 8 +++---- ragna/assistants/_anthropic.py | 17 ++++++++++----- ragna/assistants/_cohere.py | 10 ++++----- ragna/assistants/_google.py | 6 +++--- ragna/assistants/_openai.py | 38 ++++++++++++++++++++++++++-------- 5 files changed, 53 insertions(+), 26 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 4017e942..6c686ad9 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,6 +1,6 @@ from typing import AsyncIterator, Union, cast -from ragna.core import Message, Source +from ragna.core import Message, MessageRole, Source from ._http_api import HttpApiAssistant @@ -33,7 +33,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - + messages = [ {"text": i["content"], "role": i["role"]} for i in messages @@ -76,7 +76,7 @@ async def generate( "numResults": 1, "temperature": 0.0, "maxTokens": max_new_tokens, - "messages": _render_prompt(prompt), + "messages": self._render_prompt(prompt), "system": system_prompt, }, ) as stream: @@ -88,7 +88,7 @@ async def answer( ) -> AsyncIterator[str]: prompt, sources = (message := messages[-1]).content, message.sources system_prompt = self._make_system_content(sources) - yield generate( + yield self.generate( prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens ) diff --git a/ragna/assistants/_anthropic.py 
b/ragna/assistants/_anthropic.py index 942e9e0b..77376a2e 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,6 +1,13 @@ from typing import AsyncIterator, Union, cast -from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source +from ragna.core import ( + Message, + MessageRole, + PackageRequirement, + RagnaException, + Requirement, + Source, +) from ._http_api import HttpApiAssistant, HttpStreamingProtocol @@ -44,9 +51,9 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: ordered list of dicts with 'content' and 'role' keys """ if isinstance(prompt, str): - messages = [Message(content=prompt, role=MessageRole.USER)] + messages = [Message(content=prompt, role=MessageRole.USER)] else: - messages = prompt + messages = prompt messages = [ {"content": i["content"], "role": i["role"]} @@ -88,8 +95,8 @@ async def generate( }, json={ "model": self._MODEL, - "system": system, - "messages": _render_prompt(prompt), + "system": system_prompt, + "messages": self._render_prompt(prompt), "max_tokens": max_new_tokens, "temperature": 0.0, "stream": True, diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 69d6fe7a..36925ed9 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,6 +1,6 @@ from typing import AsyncIterator, Union, cast -from ragna.core import Message, RagnaException, Source +from ragna.core import Message, MessageRole, RagnaException, Source from ._http_api import HttpApiAssistant, HttpStreamingProtocol @@ -35,7 +35,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - + messages = [i["content"] for i in messages if i["role"] == "user"][-1] return messages @@ -74,12 +74,12 @@ async def generate( }, json={ "preamble_override": system_prompt, - "message": _render_prompt(prompt), + "message": self._render_prompt(prompt), "model": self._MODEL, "stream": True, "temperature": 0.0, "max_tokens": max_new_tokens, - "documents": self._make_source_documents(sources), + "documents": source_documents, }, ) as stream: async for event in stream: @@ -100,7 +100,7 @@ async def answer( prompt, sources = (message := messages[-1]).content, message.sources system_prompt = self._make_preamble() source_documents = self._make_source_documents(sources) - yield generate( + yield self.generate( prompt=prompt, system_prompt=system_prompt, source_documents=source_documents, diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index cc1bc002..ecc532eb 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,6 +1,6 @@ from typing import AsyncIterator, Union -from ragna.core import Message, Source +from ragna.core import Message, MessageRole, Source from ._http_api import HttpApiAssistant, HttpStreamingProtocol @@ -59,7 +59,7 @@ async def generate( params={"key": self._api_key}, headers={"Content-Type": "application/json"}, json={ - "contents": _render_prompt(prompt), + "contents": self._render_prompt(prompt), # https://ai.google.dev/docs/safety_setting_gemini "safetySettings": [ { @@ -89,7 +89,7 @@ async def answer( ) -> AsyncIterator[str]: prompt, sources = (message := messages[-1]).content, message.sources expanded_prompt = self._instructize_prompt(prompt, sources) - yield generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens) + yield self.generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens) class 
GeminiPro(GoogleAssistant): diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index dbb875f5..6bc74c54 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,6 +1,14 @@ import abc from functools import cached_property -from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast, Union +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + MessageRole, + Optional, + Union, + cast, +) from ragna.core import Message, Source @@ -23,7 +31,9 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) -> list[dict]: + def _render_prompt( + self, prompt: Union[str, list[Message]], system_prompt: str + ) -> list[dict]: """ Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format. @@ -34,14 +44,22 @@ def _render_prompt(self, prompt: Union[str,list[Message]], system_prompt: str) - messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - system_message = [{"role":"system", "content":system_prompt}] - messages = [{"role":i["role"],"content":i["content"]} for i in prompt if i["role"] != "system"] + system_message = [{"role": "system", "content": system_prompt}] + messages = [ + {"role": i["role"], "content": i["content"]} + for i in prompt + if i["role"] != "system" + ] return system_message.extend(messages) - + async def generate( - self, prompt: Union[str,list[Message]], *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256 + self, + prompt: Union[str, list[Message]], + *, + system_prompt: str = "You are a helpful assistant.", + max_new_tokens: int = 256, ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]: - """ + """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -77,8 +95,10 @@ def _call_openai_api( ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]: system_prompt = self._make_system_content(sources) - yield self.generate(prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens) - + yield self.generate( + prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens + ) + async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: From a1c912f770d831eff9233997cfbce3e72f50022e Mon Sep 17 00:00:00 2001 From: dillonroach Date: Thu, 5 Sep 2024 15:03:34 -0700 Subject: [PATCH 11/15] import in the wrong place --- ragna/assistants/_openai.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 6bc74c54..c4fb2d00 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -4,13 +4,12 @@ Any, AsyncContextManager, AsyncIterator, - MessageRole, Optional, Union, cast, ) -from ragna.core import Message, Source +from ragna.core import Message, MessageRole, Source from ._http_api import HttpApiAssistant, HttpStreamingProtocol From 8e120c6cc4db9b1ec911ea7a101519e8aacf1163 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 6 Sep 2024 10:12:23 +0200 Subject: [PATCH 12/15] cleanup --- ragna/assistants/_ai21labs.py | 28 ++++++++--------- ragna/assistants/_anthropic.py | 57 ++++++++++++++-------------------- ragna/assistants/_cohere.py | 45 +++++++++++++-------------- ragna/assistants/_google.py | 30 +++++++++--------- ragna/assistants/_ollama.py | 23 +++++++------- ragna/assistants/_openai.py | 54 ++++++++++++++------------------ tests/assistants/test_api.py | 12 ++++--- 7 files changed, 118 insertions(+), 131 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 6c686ad9..fef9762b 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Union, cast +from typing import Any, AsyncIterator, Union, cast from ragna.core import Message, MessageRole, Source @@ -33,13 +33,11 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - - messages = [ - {"text": i["content"], "role": i["role"]} - for i in messages - if i["role"] != "system" + return [ + {"role": message.role.value, "content": message.content} + for message in messages + if message.role is not MessageRole.SYSTEM ] - return messages async def generate( self, @@ -47,7 +45,7 @@ async def generate( *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -81,16 +79,18 @@ async def generate( }, ) as stream: async for data in stream: - yield cast(str, data["outputs"][0]["text"]) + yield data async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - system_prompt = self._make_system_content(sources) - yield self.generate( - prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens - ) + message = messages[-1] + async for data in self.generate( + [message], + system_prompt=self._make_system_content(message.sources), + max_new_tokens=max_new_tokens, + ): + yield cast(str, data["outputs"][0]["text"]) # The Jurassic2Mid assistant receives a 500 internal service error from the remote diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 77376a2e..b58cef5c 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,13 +1,6 @@ -from typing import AsyncIterator, Union, cast +from typing import Any, AsyncIterator, Union, cast -from ragna.core import ( - Message, - MessageRole, - PackageRequirement, - RagnaException, - Requirement, - Source, -) +from ragna.core import Message, MessageRole, RagnaException, Source from ._http_api import HttpApiAssistant, HttpStreamingProtocol @@ -17,10 +10,6 @@ class AnthropicAssistant(HttpApiAssistant): _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE _MODEL: str - @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [PackageRequirement("httpx_sse")] - @classmethod def display_name(cls) -> str: return f"Anthropic/{cls._MODEL}" @@ -54,13 +43,11 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - - messages = [ - {"content": i["content"], "role": i["role"]} - for i in messages - if i["role"] != "system" + return [ + {"role": message.role.value, "content": message.content} + for message in messages + if message.role is not MessageRole.SYSTEM ] - return messages async def generate( self, @@ -68,7 +55,7 @@ async def generate( *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -103,24 +90,26 @@ async def generate( }, ) as stream: async for data in stream: - # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response - if "error" in data: - raise RagnaException(data["error"].pop("message"), **data["error"]) - elif data["type"] == "message_stop": - break - elif data["type"] != "content_block_delta": - continue - - yield cast(str, data["delta"].pop("text")) + yield data async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - system_prompt = self._instructize_system_prompt(sources) - yield self.generate( - prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens - ) + message = messages[-1] + async for data in self.generate( + [message], + system_prompt=self._instructize_system_prompt(message.sources), + max_new_tokens=max_new_tokens, + ): + # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response + if "error" in data: + raise RagnaException(data["error"].pop("message"), **data["error"]) + elif data["type"] == "message_stop": + break + elif data["type"] != "content_block_delta": + continue + + yield cast(str, data["delta"].pop("text")) class ClaudeOpus(AnthropicAssistant): diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 36925ed9..9550dfb6 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Union, cast +from typing import Any, AsyncIterator, Union, cast from ragna.core import Message, MessageRole, RagnaException, Source @@ -36,8 +36,11 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: else: messages = prompt - messages = [i["content"] for i in messages if i["role"] == "user"][-1] - return messages + for message in reversed(messages): + if message.role is MessageRole.USER: + return message.content + else: + raise RagnaException async def generate( self, @@ -46,7 +49,7 @@ async def generate( *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -82,30 +85,26 @@ async def generate( "documents": source_documents, }, ) as stream: - async for event in stream: - if event["event_type"] == "stream-end": - if event["event_type"] == "COMPLETE": - break - - raise RagnaException(event["error_message"]) - if "text" in event: - yield cast(str, event["text"]) + async for data in stream: + yield data async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - # See https://docs.cohere.com/docs/cochat-beta - # See https://docs.cohere.com/reference/chat - # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag - prompt, sources = (message := messages[-1]).content, message.sources - system_prompt = self._make_preamble() - source_documents = self._make_source_documents(sources) - yield self.generate( - prompt=prompt, - system_prompt=system_prompt, - source_documents=source_documents, + message = messages[-1] + async for data in self.generate( + prompt=message.content, + system_prompt=self._make_preamble(), + source_documents=self._make_source_documents(message.sources), max_new_tokens=max_new_tokens, - ) + ): + if data["event_type"] == "stream-end": + if data["event_type"] == "COMPLETE": + break + + raise RagnaException(data["error_message"]) + if "text" in data: + yield cast(str, data["text"]) class Command(CohereAssistant): diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index ecc532eb..ccc2f83b 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Union +from typing import Any, AsyncIterator, Union from ragna.core import Message, MessageRole, Source @@ -26,22 +26,19 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: ) def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: - # need to verify against https://ai.google.dev/api/generate-content#chat_1 - role_mapping = {"user": "user", "assistant": "model"} if isinstance(prompt, str): messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - messages = [ - {"parts": [{"text": i["content"]}], "role": role_mapping[i["role"]]} - for i in messages - if i["role"] != "system" + return [ + {"parts": [{"text": message.content}]} + for message in messages + if message.role is not MessageRole.SYSTEM ] - return messages async def generate( self, prompt: Union[str, list[Message]], *, max_new_tokens: int = 256 - ) -> AsyncIterator[str]: + ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -79,17 +76,20 @@ async def generate( "maxOutputTokens": max_new_tokens, }, }, - parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"), + parse_kwargs=dict(item="item"), # .candidates.item.content.parts.item.text ) as stream: - async for chunk in stream: - yield chunk + async for data in stream: + yield data async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - expanded_prompt = self._instructize_prompt(prompt, sources) - yield self.generate(prompt=expanded_prompt, max_new_tokens=max_new_tokens) + message = messages[-1] + async for data in self.generate( + self._instructize_prompt(message.content, message.sources), + max_new_tokens=max_new_tokens, + ): + yield data["candidates"][0]["content"]["parts"][0]["text"] class GeminiPro(GoogleAssistant): diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py index 591c7ed1..aaee53af 100644 --- a/ragna/assistants/_ollama.py +++ b/ragna/assistants/_ollama.py @@ -32,17 +32,18 @@ def _url(self) -> str: async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - async with self._call_openai_api( - prompt, sources, max_new_tokens=max_new_tokens - ) as stream: - async for data in stream: - # Modeled after - # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62 - if "error" in data: - raise RagnaException(data["error"]) - if not data["done"]: - yield cast(str, data["message"]["content"]) + message = messages[-1] + async for data in self.generate( + [message], + system_prompt=self._make_system_content(message.sources), + max_new_tokens=max_new_tokens, + ): + # Modeled after + # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62 + if "error" in data: + raise RagnaException(data["error"]) + if not data["done"]: + yield cast(str, data["message"]["content"]) class OllamaGemma2B(OllamaAssistant): diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index c4fb2d00..3831cf57 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -2,7 +2,6 @@ from functools import cached_property from typing import ( Any, - AsyncContextManager, AsyncIterator, Optional, Union, @@ -43,13 +42,14 @@ def _render_prompt( messages = [Message(content=prompt, role=MessageRole.USER)] else: messages = prompt - system_message = [{"role": "system", "content": system_prompt}] - messages = [ - {"role": i["role"], "content": i["content"]} - for i in prompt - if i["role"] != "system" + return [ + {"role": "system", "content": system_prompt}, + *( + {"role": message.role.value, "content": message.content} + for message in messages + if message.role is not MessageRole.SYSTEM + ), ] - return system_message.extend(messages) async def generate( self, @@ -57,7 +57,7 @@ async def generate( *, system_prompt: str = "You are a helpful assistant.", max_new_tokens: int = 256, - ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]: + ) -> AsyncIterator[dict[str, Any]]: """ Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer() This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls. 
@@ -70,8 +70,6 @@ async def generate( Returns: yield call to self._call_api with formatted headers and json """ - messages = self._render_prompt(prompt, system_prompt) - headers = { "Content-Type": "application/json", } @@ -79,7 +77,7 @@ async def generate( headers["Authorization"] = f"Bearer {self._api_key}" json_ = { - "messages": messages, + "messages": self._render_prompt(prompt, system_prompt), "temperature": 0.0, "max_tokens": max_new_tokens, "stream": True, @@ -87,29 +85,25 @@ async def generate( if self._MODEL is not None: json_["model"] = self._MODEL - yield self._call_api("POST", self._url, headers=headers, json=json_) - - def _call_openai_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 - ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]: - system_prompt = self._make_system_content(sources) - - yield self.generate( - prompt=prompt, system_prompt=system_prompt, max_new_tokens=max_new_tokens - ) + async with self._call_api( + "POST", self._url, headers=headers, json=json_ + ) as stream: + async for data in stream: + yield data async def answer( self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - prompt, sources = (message := messages[-1]).content, message.sources - async with self._call_openai_api( - prompt, sources, max_new_tokens=max_new_tokens - ) as stream: - async for data in stream: - choice = data["choices"][0] - if choice["finish_reason"] is not None: - break - yield cast(str, choice["delta"]["content"]) + message = messages[-1] + async for data in self.generate( + [message], + system_prompt=self._make_system_content(message.sources), + max_new_tokens=max_new_tokens, + ): + choice = data["choices"][0] + if choice["finish_reason"] is not None: + break + yield cast(str, choice["delta"]["content"]) class OpenaiAssistant(OpenaiLikeHttpApiAssistant): diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index f7c9c594..a9f9d52b 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -82,7 +82,7 @@ def __init__(self, base_url): super().__init__() self._endpoint = f"{base_url}/{self._STREAMING_PROTOCOL.name.lower()}" - async def answer(self, messages): + async def generate(self, messages): if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON: parse_kwargs = dict(item="item") else: @@ -95,11 +95,15 @@ async def answer(self, messages): parse_kwargs=parse_kwargs, ) as stream: async for chunk in stream: - if chunk.get("break"): - break - yield chunk + async def answer(self, messages): + async for chunk in self.generate(messages): + if chunk.get("break"): + break + + yield chunk + @skip_on_windows @pytest.mark.parametrize("streaming_protocol", list(HttpStreamingProtocol)) From 6288bb10630ed2ad7d502569e8f648a1a80d9331 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 6 Sep 2024 10:50:56 +0200 Subject: [PATCH 13/15] more cleanup --- ragna/assistants/_ai21labs.py | 2 +- ragna/assistants/_google.py | 3 ++- ragna/assistants/_openai.py | 8 +------- tests/assistants/test_api.py | 10 +++++----- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index fef9762b..da148cdc 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -34,7 +34,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]: else: messages = prompt return [ - {"role": message.role.value, "content": message.content} + {"text": message.content, "role": message.role.value} for 
message in messages if message.role is not MessageRole.SYSTEM ] diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index ccc2f83b..bbeef04c 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -50,6 +50,7 @@ async def generate( Returns: async streamed inference response string chunks """ + # See https://ai.google.dev/api/generate-content#v1beta.models.streamGenerateContent async with self._call_api( "POST", f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent", @@ -76,7 +77,7 @@ async def generate( "maxOutputTokens": max_new_tokens, }, }, - parse_kwargs=dict(item="item"), # .candidates.item.content.parts.item.text + parse_kwargs=dict(item="item"), ) as stream: async for data in stream: yield data diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 3831cf57..c75d49b1 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,12 +1,6 @@ import abc from functools import cached_property -from typing import ( - Any, - AsyncIterator, - Optional, - Union, - cast, -) +from typing import Any, AsyncIterator, Optional, Union, cast from ragna.core import Message, MessageRole, Source diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index a9f9d52b..b6dc0b9f 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -94,15 +94,15 @@ async def generate(self, messages): content=messages[-1].content, parse_kwargs=parse_kwargs, ) as stream: - async for chunk in stream: - yield chunk + async for data in stream: + yield data async def answer(self, messages): - async for chunk in self.generate(messages): - if chunk.get("break"): + async for data in self.generate(messages): + if data.get("break"): break - yield chunk + yield data @skip_on_windows From c290d0719aea736fad319ec8f174167d3f83b13e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 6 Sep 2024 12:54:12 +0200 Subject: [PATCH 14/15] debug --- ragna/assistants/_http_api.py | 6 ++++++ tests/assistants/test_api.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py index adc794b8..0969ec35 100644 --- a/ragna/assistants/_http_api.py +++ b/ragna/assistants/_http_api.py @@ -97,6 +97,8 @@ async def stream() -> AsyncIterator[Any]: yield stream() + assert False + @contextlib.asynccontextmanager async def _stream_jsonl( self, @@ -115,6 +117,8 @@ async def stream() -> AsyncIterator[Any]: yield stream() + assert False + # ijson does not support reading from an (async) iterator, but only from file-like # objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects. # See https://github.com/ICRAR/ijson/issues/44 for details. 
@@ -158,6 +162,8 @@ async def stream() -> AsyncIterator[Any]: yield stream() + assert False + async def _assert_api_call_is_success(self, response: httpx.Response) -> None: if response.is_success: return diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index b6dc0b9f..067b7269 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -82,6 +82,27 @@ def __init__(self, base_url): super().__init__() self._endpoint = f"{base_url}/{self._STREAMING_PROTOCOL.name.lower()}" + # def generate(self, messages): + # if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON: + # parse_kwargs = dict(item="item") + # else: + # parse_kwargs = dict() + # + # return self._call_api( + # "POST", + # self._endpoint, + # content=messages[-1].content, + # parse_kwargs=parse_kwargs, + # ) + # + # async def answer(self, messages): + # async with self.generate(messages) as stream: + # async for data in stream: + # if data.get("break"): + # break + # + # yield data + async def generate(self, messages): if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON: parse_kwargs = dict(item="item") From 5b3bfd5c0a458a8274a17c9b18de326e316b6640 Mon Sep 17 00:00:00 2001 From: dillonroach Date: Sun, 15 Sep 2024 13:04:16 -0700 Subject: [PATCH 15/15] rename/reference system prompt functions for clarity, message.role.value cleanup --- ragna/assistants/_ai21labs.py | 4 ++-- ragna/assistants/_anthropic.py | 4 ++-- ragna/assistants/_cohere.py | 8 ++++---- ragna/assistants/_google.py | 2 +- ragna/assistants/_openai.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index da148cdc..d3c01fc3 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -14,7 +14,7 @@ class Ai21LabsAssistant(HttpApiAssistant): def display_name(cls) -> str: return f"AI21Labs/jurassic-2-{cls._MODEL_TYPE}" - def _make_system_content(self, sources: list[Source]) -> str: + def _make_rag_system_content(self, sources: list[Source]) -> str: instruction = ( "You are a helpful assistant that answers user questions given the context below. " "If you don't know the answer, just say so. Don't try to make up an answer. 
" @@ -87,7 +87,7 @@ async def answer( message = messages[-1] async for data in self.generate( [message], - system_prompt=self._make_system_content(message.sources), + system_prompt=self._make_rag_system_content(message.sources), max_new_tokens=max_new_tokens, ): yield cast(str, data["outputs"][0]["text"]) diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index b58cef5c..06183675 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -14,7 +14,7 @@ class AnthropicAssistant(HttpApiAssistant): def display_name(cls) -> str: return f"Anthropic/{cls._MODEL}" - def _instructize_system_prompt(self, sources: list[Source]) -> str: + def _make_rag_system_prompt(self, sources: list[Source]) -> str: # See https://docs.anthropic.com/claude/docs/system-prompts # See https://docs.anthropic.com/claude/docs/long-context-window-tips#tips-for-document-qa instruction = ( @@ -98,7 +98,7 @@ async def answer( message = messages[-1] async for data in self.generate( [message], - system_prompt=self._instructize_system_prompt(message.sources), + system_prompt=self._make_rag_system_prompt(message.sources), max_new_tokens=max_new_tokens, ): # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 9550dfb6..4da28e69 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -14,14 +14,14 @@ class CohereAssistant(HttpApiAssistant): def display_name(cls) -> str: return f"Cohere/{cls._MODEL}" - def _make_preamble(self) -> str: + def _make_rag_preamble(self) -> str: return ( "You are a helpful assistant that answers user questions given the included context. " "If you don't know the answer, just say so. Don't try to make up an answer. " "Only use the included documents below to generate the answer." ) - def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: + def _make_rag_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] def _render_prompt(self, prompt: Union[str, list[Message]]) -> str: @@ -94,8 +94,8 @@ async def answer( message = messages[-1] async for data in self.generate( prompt=message.content, - system_prompt=self._make_preamble(), - source_documents=self._make_source_documents(message.sources), + system_prompt=self._make_rag_preamble(), + source_documents=self._make_rag_source_documents(message.sources), max_new_tokens=max_new_tokens, ): if data["event_type"] == "stream-end": diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index bbeef04c..95b094c7 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -31,7 +31,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]: else: messages = prompt return [ - {"parts": [{"text": message.content}]} + {"parts": [{"text": message.content, "role": message.role.value}]} for message in messages if message.role is not MessageRole.SYSTEM ] diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index c75d49b1..7674393f 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -14,7 +14,7 @@ class OpenaiLikeHttpApiAssistant(HttpApiAssistant): @abc.abstractmethod def _url(self) -> str: ... 
- def _make_system_content(self, sources: list[Source]) -> str: + def _make_rag_system_content(self, sources: list[Source]) -> str: # See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb instruction = ( "You are an helpful assistants that answers user questions given the context below. " @@ -91,7 +91,7 @@ async def answer( message = messages[-1] async for data in self.generate( [message], - system_prompt=self._make_system_content(message.sources), + system_prompt=self._make_rag_system_content(message.sources), max_new_tokens=max_new_tokens, ): choice = data["choices"][0]
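
With generate() split out of answer(), an assistant can also be driven directly for one-off, non-RAG calls (query rewriting, agentic pre-processing, and similar). Below is a minimal usage sketch, not part of the patches above, assuming the series is applied, that an OpenAI-style assistant such as Gpt4 is exported from ragna.assistants, and that the matching API key is set in the environment; the per-chunk handling mirrors what answer() does in _openai.py:

    import asyncio

    from ragna.assistants import Gpt4  # assumed export; any OpenAI-like assistant from this series


    async def rephrase(query: str) -> str:
        # One-off, non-RAG inference call: no Chat, no sources, no RAG system prompt.
        assistant = Gpt4()
        chunks = []
        async for data in assistant.generate(
            f"Rewrite this search query to be more specific: {query}",
            system_prompt="You rewrite search queries. Reply with the rewritten query only.",
            max_new_tokens=128,
        ):
            # Same chunk layout handled in OpenaiLikeHttpApiAssistant.answer() above.
            choice = data["choices"][0]
            if choice["finish_reason"] is not None:
                break
            chunks.append(choice["delta"]["content"])
        return "".join(chunks)


    print(asyncio.run(rephrase("tell me about the docs")))

Providers whose generate() yields differently shaped chunks (Anthropic, Cohere, Google) would need the corresponding per-chunk handling from their own answer() implementations instead.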