add google assistants

pmeier committed Jan 31, 2024
1 parent 1b53e62 commit dad1024
Showing 10 changed files with 168 additions and 36 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -139,12 +139,14 @@ disallow_incomplete_defs = false

 [[tool.mypy.overrides]]
 module = [
+    "docx",
     "fitz",
+    "json_stream",
+    "json_stream.httpx",
     "lancedb",
     "param",
-    "pyarrow",
-    "docx",
     "pptx",
+    "pyarrow",
     "sentence_transformers",
 ]
 ignore_missing_imports = true
42 changes: 41 additions & 1 deletion ragna/_utils.py
@@ -4,9 +4,23 @@
 import sys
 import threading
 from pathlib import Path
-from typing import Any, Callable, Optional, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Iterator,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 from urllib.parse import SplitResult, urlsplit, urlunsplit
 
+from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
+
+T = TypeVar("T")
+
 _LOCAL_ROOT = (
     Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve()
 )
@@ -125,3 +139,29 @@ def is_debugging() -> bool:
         if any(part.startswith(name) for part in parts):
             return True
     return False
+
+
+def as_awaitable(
+    fn: Union[Callable[..., T], Callable[..., Awaitable[T]]],
+    *args: Any,
+    **kwargs: Any,
+) -> Awaitable[T]:
+    if inspect.iscoroutinefunction(fn):
+        fn = cast(Callable[..., Awaitable[T]], fn)
+        return fn(*args, **kwargs)
+    else:
+        fn = cast(Callable[..., T], fn)
+        return run_in_threadpool(fn, *args, **kwargs)
+
+
+def as_async_iterator(
+    fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]],
+    *args: Any,
+    **kwargs: Any,
+) -> AsyncIterator[T]:
+    if inspect.isasyncgenfunction(fn):
+        fn = cast(Callable[..., AsyncIterator[T]], fn)
+        return fn(*args, **kwargs)
+    else:
+        fn = cast(Callable[..., Iterator[T]], fn)
+        return iterate_in_threadpool(fn(*args, **kwargs))
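
For orientation, a minimal sketch of how the two new helpers behave, assuming this revision of ragna is importable. The toy functions are illustrative and not part of the commit:

import asyncio
from typing import AsyncIterator, Iterator

from ragna._utils import as_async_iterator, as_awaitable


def add_one(x: int) -> int:  # plain function: run via run_in_threadpool
    return x + 1


async def add_one_async(x: int) -> int:  # coroutine function: awaited directly
    return x + 1


def count(n: int) -> Iterator[int]:  # sync generator: wrapped via iterate_in_threadpool
    yield from range(n)


async def count_async(n: int) -> AsyncIterator[int]:  # async generator: returned as-is
    for i in range(n):
        yield i


async def main() -> None:
    assert await as_awaitable(add_one, 1) == 2
    assert await as_awaitable(add_one_async, 1) == 2
    assert [i async for i in as_async_iterator(count, 3)] == [0, 1, 2]
    assert [i async for i in as_async_iterator(count_async, 3)] == [0, 1, 2]


asyncio.run(main())

Either way, the caller gets a uniform awaitable or async iterator and never blocks the event loop.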
3 changes: 3 additions & 0 deletions ragna/assistants/__init__.py
@@ -1,6 +1,8 @@
 __all__ = [
     "Claude",
     "ClaudeInstant",
+    "GeminiPro",
+    "GeminiUltra",
     "Gpt35Turbo16k",
     "Gpt4",
     "Mpt7bInstruct",
@@ -10,6 +12,7 @@

 from ._anthropic import Claude, ClaudeInstant
 from ._demo import RagnaDemoAssistant
+from ._google import GeminiPro, GeminiUltra
 from ._mosaicml import Mpt7bInstruct, Mpt30bInstruct
 from ._openai import Gpt4, Gpt35Turbo16k
2 changes: 1 addition & 1 deletion ragna/assistants/_anthropic.py
@@ -36,7 +36,7 @@ async def _call_api(
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/streaming
         async with httpx_sse.aconnect_sse(
-            self._client,
+            self._async_client,
             "POST",
             "https://api.anthropic.com/v1/complete",
             headers={
21 changes: 14 additions & 7 deletions ragna/assistants/_api.py
@@ -1,10 +1,11 @@
 import abc
 import os
-from typing import AsyncIterator
+from typing import Any, AsyncIterator, Iterator
 
 import httpx
 
 import ragna
+from ragna._utils import as_async_iterator
 from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
 
 
@@ -16,22 +17,28 @@ def requirements(cls) -> list[Requirement]:
         return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
 
     def __init__(self) -> None:
-        self._client = httpx.AsyncClient(
+        self._api_key = os.environ[self._API_KEY_ENV_VAR]
+
+        kwargs: dict[str, Any] = dict(
             headers={"User-Agent": f"{ragna.__version__}/{self}"},
             timeout=60,
         )
-        self._api_key = os.environ[self._API_KEY_ENV_VAR]
+        self._sync_client = httpx.Client(**kwargs)
+        self._async_client = httpx.AsyncClient(**kwargs)
 
     async def answer(
         self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        async for chunk in self._call_api(  # type: ignore[attr-defined, misc]
-            prompt, sources, max_new_tokens=max_new_tokens
+        async for chunk in as_async_iterator(
+            self._call_api,
+            prompt,
+            sources,
+            max_new_tokens=max_new_tokens,
        ):
             yield chunk
 
     @abc.abstractmethod
-    async def _call_api(
+    def _call_api(
         self, prompt: str, sources: list[Source], *, max_new_tokens: int
-    ) -> AsyncIterator[str]:
+    ) -> Iterator[str]:
         ...
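
With the abstract signature relaxed, a subclass may implement `_call_api` either as an async generator (as before) or as a plain generator that `answer()` drives through `as_async_iterator`. A hypothetical sketch of the sync variant; the class name, endpoint, and payload are invented for illustration:

from typing import Iterator

from ragna.assistants._api import ApiAssistant
from ragna.core import Source


class HypotheticalAssistant(ApiAssistant):
    _API_KEY_ENV_VAR = "HYPOTHETICAL_API_KEY"  # hypothetical
    # remaining abstract members (e.g. max_input_size) omitted from this sketch

    def _call_api(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int
    ) -> Iterator[str]:
        # Blocking HTTP via the new self._sync_client; answer() moves the
        # iteration into a thread pool, so the event loop is never blocked.
        with self._sync_client.stream(
            "POST",
            "https://api.example.com/v1/generate",  # hypothetical endpoint
            headers={"Authorization": f"Bearer {self._api_key}"},
            json={"prompt": prompt, "max_tokens": max_new_tokens},
        ) as response:
            yield from response.iter_text()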
95 changes: 95 additions & 0 deletions ragna/assistants/_google.py
@@ -0,0 +1,95 @@
+from typing import Iterator
+
+from ragna.core import PackageRequirement, Requirement, Source
+
+from ._api import ApiAssistant
+
+
+class GoogleApiAssistant(ApiAssistant):
+    _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
+    _MODEL: str
+    _CONTEXT_SIZE: int
+
+    @classmethod
+    def requirements(cls) -> list[Requirement]:
+        return [
+            *super().requirements(),
+            PackageRequirement("json-stream"),
+        ]
+
+    @classmethod
+    def display_name(cls) -> str:
+        return f"Google/{cls._MODEL}"
+
+    @property
+    def max_input_size(self) -> int:
+        return self._CONTEXT_SIZE
+
+    def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
+        # https://ai.google.dev/docs/prompt_best_practices#add-contextual-information
+        return "\n".join(
+            [
+                "Answer the prompt using only the pieces of context below.",
+                "If you don't know the answer, just say so. Don't try to make up additional context.",
+                f"Prompt: {prompt}",
+                *[f"\n{source.content}" for source in sources],
+            ]
+        )
+
+    def _call_api(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    ) -> Iterator[str]:
+        import json_stream.httpx
+
+        with self._sync_client.stream(
+            "POST",
+            f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
+            params={"key": self._api_key},
+            headers={"Content-Type": "application/json"},
+            json={
+                "contents": [
+                    {"parts": [{"text": self._instructize_prompt(prompt, sources)}]}
+                ],
+                # https://ai.google.dev/docs/safety_setting_gemini
+                "safetySettings": [
+                    {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
+                    for category in [
+                        "HARASSMENT",
+                        "HATE_SPEECH",
+                        "SEXUALLY_EXPLICIT",
+                        "DANGEROUS_CONTENT",
+                    ]
+                ],
+                # https://ai.google.dev/tutorials/rest_quickstart#configuration
+                "generationConfig": {
+                    "temperature": 0.0,
+                    "maxOutputTokens": max_new_tokens,
+                },
+            },
+        ) as response:
+            for chunk in json_stream.httpx.load(response, persistent=True):
+                yield chunk["candidates"][0]["content"]["parts"][0]["text"]
+
+
+class GeminiPro(GoogleApiAssistant):
+    """[Google Gemini Pro](https://ai.google.dev/models/gemini)
+
+    !!! info "Required environment variables"
+
+        - `GOOGLE_API_KEY`
+    """
+
+    _MODEL = "gemini-pro"
+    _CONTEXT_SIZE = 30_720
+
+
+class GeminiUltra(GoogleApiAssistant):
+    """[Google Gemini Ultra](https://ai.google.dev/models/gemini)
+
+    !!! info "Required environment variables"
+
+        - `GOOGLE_API_KEY`
+    """
+
+    _MODEL = "gemini-ultra"
+    _CONTEXT_SIZE = 30_720
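
A quick smoke test of the new assistant outside of a full chat, assuming `GOOGLE_API_KEY` is set in the environment. The empty source list is just to exercise the call; real sources come from a source storage:

import asyncio

from ragna.assistants import GeminiPro
from ragna.core import Source


async def main() -> None:
    assistant = GeminiPro()
    sources: list[Source] = []  # normally produced by a SourceStorage
    async for chunk in assistant.answer("What is Ragna?", sources):
        print(chunk, end="", flush=True)


asyncio.run(main())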
2 changes: 1 addition & 1 deletion ragna/assistants/_mosaicml.py
@@ -32,7 +32,7 @@ async def _call_api(
     ) -> AsyncIterator[str]:
         instruction = self._instructize_prompt(prompt, sources)
         # https://docs.mosaicml.com/en/latest/inference.html#text-completion-requests
-        response = await self._client.post(
+        response = await self._async_client.post(
             f"https://models.hosted-on.mosaicml.hosting/{self._MODEL}/v1/predict",
             headers={
                 "Authorization": f"{self._api_key}",
2 changes: 1 addition & 1 deletion ragna/assistants/_openai.py
@@ -36,7 +36,7 @@ async def _call_api(
         # See https://platform.openai.com/docs/api-reference/chat/create
         # and https://platform.openai.com/docs/api-reference/chat/streaming
         async with httpx_sse.aconnect_sse(
-            self._client,
+            self._async_client,
             "POST",
             "https://api.openai.com/v1/chat/completions",
             headers={
29 changes: 7 additions & 22 deletions ragna/core/_rag.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import datetime
-import inspect
 import uuid
 from typing import (
     Any,
@@ -19,7 +18,8 @@
 )
 
 import pydantic
-from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
+
+from ragna._utils import as_async_iterator, as_awaitable
 
 from ._components import Assistant, Component, Message, MessageRole, SourceStorage
 from ._document import Document, LocalDocument
@@ -256,34 +256,19 @@ def _unpack_chat_params(
             for fn, model in component_models.items()
         }
 
-    async def _run(
+    def _run(
         self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any
-    ) -> T:
+    ) -> Awaitable[T]:
         kwargs = self._unpacked_params[fn]
-        if inspect.iscoroutinefunction(fn):
-            fn = cast(Callable[..., Awaitable[T]], fn)
-            coro = fn(*args, **kwargs)
-        else:
-            fn = cast(Callable[..., T], fn)
-            coro = run_in_threadpool(fn, *args, **kwargs)
-
-        return await coro
+        return as_awaitable(fn, *args, **kwargs)
 
-    async def _run_gen(
+    def _run_gen(
         self,
         fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]],
         *args: Any,
     ) -> AsyncIterator[T]:
         kwargs = self._unpacked_params[fn]
-        if inspect.isasyncgenfunction(fn):
-            fn = cast(Callable[..., AsyncIterator[T]], fn)
-            async_gen = fn(*args, **kwargs)
-        else:
-            fn = cast(Callable[..., Iterator[T]], fn)
-            async_gen = iterate_in_threadpool(fn(*args, **kwargs))
-
-        async for item in async_gen:
-            yield item
+        return as_async_iterator(fn, *args, **kwargs)
 
     async def __aenter__(self) -> Chat:
         await self.prepare()
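
This refactor stays behavior-compatible: `_run` now returns the awaitable instead of awaiting it, and `_run_gen` returns the async iterator instead of re-yielding its items, so existing call sites keep working unchanged. A sketch of the call-site pattern, with attribute names shown for illustration of how this class uses the helpers:

# Inside Chat.answer(), unchanged by the refactor:
sources = await self._run(self.source_storage.retrieve, self.documents, prompt)

async for chunk in self._run_gen(self.assistant.answer, prompt, sources):
    ...  # stream chunks to the caller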
2 changes: 1 addition & 1 deletion ragna/deploy/_ui/api_wrapper.py
@@ -18,7 +18,7 @@ class ApiWrapper(param.Parameterized):
     auth_token = param.String(default=None)
 
     def __init__(self, api_url, **params):
-        self.client = httpx.AsyncClient(base_url=api_url)
+        self.client = httpx.AsyncClient(base_url=api_url, timeout=60)
 
         super().__init__(**params)
 
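
The added `timeout=60` matches the API assistants' clients above and replaces httpx's 5-second default, which streaming answers can easily exceed. For reference, httpx also accepts a granular `httpx.Timeout` object; a sketch of an equivalent but more explicit configuration (the split values and the base_url are illustrative, not part of this commit):

import httpx

# A single number applies to connect, read, write, and pool acquisition alike;
# httpx.Timeout allows overriding the phases individually.
timeout = httpx.Timeout(60, connect=5)
client = httpx.AsyncClient(base_url="http://127.0.0.1:31476", timeout=timeout)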
