diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index 21be7557..207533ec 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import abc import os +from typing import Annotated + +import pydantic import ragna from ragna.core import Assistant, EnvVarRequirement, Requirement, Source @@ -22,7 +27,22 @@ def __init__(self) -> None: self._api_key = os.environ[self._API_KEY_ENV_VAR] async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + self, + prompt: str, + sources: list[Source], + *, + max_new_tokens: Annotated[ + int, + pydantic.Field( + title="Maximum new tokens", + description=( + "Maximum number of new tokens to generate. " + "If you experience truncated answers, increase this value. " + "However, be aware that longer answers also incur a higher cost." + ), + gt=0, + ), + ] = 256, ) -> str: return await self._call_api(prompt, sources, max_new_tokens=max_new_tokens) diff --git a/ragna/core/_components.py b/ragna/core/_components.py index f2196439..9ebbb8d8 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -4,13 +4,13 @@ import enum import functools import inspect -from typing import Type +from typing import Type, get_type_hints import pydantic import pydantic.utils from ._document import Document -from ._utils import RequirementsMixin, merge_models +from ._utils import RagnaException, RequirementsMixin, merge_models class Component(RequirementsMixin): @@ -56,17 +56,30 @@ def _protocol_models( ).parameters extra_param_names = concrete_params.keys() - protocol_params.keys() + # We can't rely on inspect.Param.annotation, since that might be str. This + # happens if the annotation was either explicitly provided as str, or there + # is a top-level from __future__ import annotations in the module. + annotations = get_type_hints(method, include_extras=True) + + field_definitions = {} + for param_name in extra_param_names: + annotation = annotations.get(param_name) + if annotation is None: + # FIXME: can we live with typing.Any here? + raise RagnaException("Missing annotation") + + default = concrete_params[param_name].default + if default is inspect.Parameter.empty: + default = ... + + field_info = pydantic.fields.FieldInfo.from_annotated_attribute( + annotation, default + ) + + field_definitions[param_name] = (annotation, field_info) + models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload] - f"{cls.__name__}.{method_name}", - **{ - (param := concrete_params[param_name]).name: ( - param.annotation, - param.default - if param.default is not inspect.Parameter.empty - else ..., - ) - for param_name in extra_param_names - }, + f"{cls.__name__}.{method_name}", **field_definitions ) return models diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 9c20184e..49414166 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -13,7 +13,6 @@ import packaging.requirements import pydantic -import pydantic_core from ragna._compat import importlib_metadata_package_distributions @@ -140,24 +139,28 @@ def merge_models( *models: Type[pydantic.BaseModel], config: Optional[pydantic.ConfigDict] = None, ) -> Type[pydantic.BaseModel]: - raw_field_definitions = defaultdict(list) + field_infoss = defaultdict(list) for model_cls in models: - for name, field in model_cls.model_fields.items(): - type_ = field.annotation - - default: Any - if field.is_required(): - default = ... - elif field.default is pydantic_core.PydanticUndefined: - default = field.default_factory() # type: ignore[misc] - else: - default = field.default - - raw_field_definitions[name].append((type_, default)) + for name, field_info in model_cls.model_fields.items(): + field_infoss[name].append(field_info) field_definitions = {} - for name, definitions in raw_field_definitions.items(): - types, defaults = zip(*definitions) + for name, field_infos in field_infoss.items(): + # data = defaultdict(set) + # for field_info in field_infos: + # data["annotation"].add(field_info.annotation) + # data["default"].add(field_info.get_default(call_default_factory=True)) + # data["title"].add() + + types, defaults = zip( + *( + ( + field_info.annotation, + field_info.get_default(call_default_factory=True), + ) + for field_info in field_infos + ) + ) types = set(types) if len(types) > 1: @@ -172,7 +175,12 @@ def merge_models( else: default = None - field_definitions[name] = (type_, default) + # FIXME: We need a way to also check / merge the additional metadata instead of + # just taking the first at face value. + field_info = field_infos[0] + field_info.default = default + + field_definitions[name] = (type_, field_info) return cast( Type[pydantic.BaseModel],