Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to annotate component parameters #218

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion ragna/assistants/_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from __future__ import annotations

import abc
import os
from typing import Annotated

import pydantic

import ragna
from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
Expand All @@ -22,7 +27,22 @@ def __init__(self) -> None:
self._api_key = os.environ[self._API_KEY_ENV_VAR]

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self,
prompt: str,
sources: list[Source],
*,
max_new_tokens: Annotated[
int,
pydantic.Field(
title="Maximum new tokens",
description=(
"Maximum number of new tokens to generate. "
"If you experience truncated answers, increase this value. "
"However, be aware that longer answers also incur a higher cost."
),
gt=0,
),
] = 256,
) -> str:
return await self._call_api(prompt, sources, max_new_tokens=max_new_tokens)

Expand Down
37 changes: 25 additions & 12 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import enum
import functools
import inspect
from typing import Type
from typing import Type, get_type_hints

import pydantic
import pydantic.utils

from ._document import Document
from ._utils import RequirementsMixin, merge_models
from ._utils import RagnaException, RequirementsMixin, merge_models


class Component(RequirementsMixin):
Expand Down Expand Up @@ -56,17 +56,30 @@ def _protocol_models(
).parameters
extra_param_names = concrete_params.keys() - protocol_params.keys()

# We can't rely on inspect.Param.annotation, since that might be str. This
# happens if the annotation was either explicitly provided as str, or there
# is a top-level from __future__ import annotations in the module.
annotations = get_type_hints(method, include_extras=True)

field_definitions = {}
for param_name in extra_param_names:
annotation = annotations.get(param_name)
if annotation is None:
# FIXME: can we live with typing.Any here?
raise RagnaException("Missing annotation")

default = concrete_params[param_name].default
if default is inspect.Parameter.empty:
default = ...

field_info = pydantic.fields.FieldInfo.from_annotated_attribute(
annotation, default
)

field_definitions[param_name] = (annotation, field_info)

models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload]
f"{cls.__name__}.{method_name}",
**{
(param := concrete_params[param_name]).name: (
param.annotation,
param.default
if param.default is not inspect.Parameter.empty
else ...,
)
for param_name in extra_param_names
},
f"{cls.__name__}.{method_name}", **field_definitions
)
return models

Expand Down
42 changes: 25 additions & 17 deletions ragna/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import packaging.requirements
import pydantic
import pydantic_core

from ragna._compat import importlib_metadata_package_distributions

Expand Down Expand Up @@ -140,24 +139,28 @@ def merge_models(
*models: Type[pydantic.BaseModel],
config: Optional[pydantic.ConfigDict] = None,
) -> Type[pydantic.BaseModel]:
raw_field_definitions = defaultdict(list)
field_infoss = defaultdict(list)
for model_cls in models:
for name, field in model_cls.model_fields.items():
type_ = field.annotation

default: Any
if field.is_required():
default = ...
elif field.default is pydantic_core.PydanticUndefined:
default = field.default_factory() # type: ignore[misc]
else:
default = field.default

raw_field_definitions[name].append((type_, default))
for name, field_info in model_cls.model_fields.items():
field_infoss[name].append(field_info)

field_definitions = {}
for name, definitions in raw_field_definitions.items():
types, defaults = zip(*definitions)
for name, field_infos in field_infoss.items():
# data = defaultdict(set)
# for field_info in field_infos:
# data["annotation"].add(field_info.annotation)
# data["default"].add(field_info.get_default(call_default_factory=True))
# data["title"].add()

types, defaults = zip(
*(
(
field_info.annotation,
field_info.get_default(call_default_factory=True),
)
for field_info in field_infos
)
)

types = set(types)
if len(types) > 1:
Expand All @@ -172,7 +175,12 @@ def merge_models(
else:
default = None

field_definitions[name] = (type_, default)
# FIXME: We need a way to also check / merge the additional metadata instead of
# just taking the first at face value.
field_info = field_infos[0]
field_info.default = default

field_definitions[name] = (type_, field_info)

return cast(
Type[pydantic.BaseModel],
Expand Down