Skip to content

Commit

Permalink
Add _call_api
Browse files Browse the repository at this point in the history
  • Loading branch information
smokestacklightnin committed Feb 1, 2024
1 parent 17264a8 commit 9e6363b
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum, unique
from typing import cast

from ragna.core import Source
from ragna.core import RagnaException, Source

from ._api import ApiAssistant

Expand Down Expand Up @@ -32,3 +33,34 @@ def _make_system_content(self, sources: list[Source]) -> str:
"Only use the sources below to generate the answer."
)
return instruction + "\n\n".join(source.content for source in sources)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
response = await self._client.post(
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {self._api_key}",
},
json={
"numResults": 1,
"temperature": 0.0,
"maxTokens": max_new_tokens,
"messages": [
{
"text": prompt,
"role": "user",
}
],
"system": self._make_system_content(sources),
},
)

if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)

return cast(str, response.json()["outputs"][0]["text"])

0 comments on commit 9e6363b

Please sign in to comment.