Skip to content

Commit

Permalink
Add a meta_optimization example
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackHC committed Mar 4, 2024
1 parent 474b952 commit fb014e5
Show file tree
Hide file tree
Showing 11 changed files with 555 additions and 9 deletions.
File renamed without changes.
File renamed without changes
File renamed without changes.
File renamed without changes
File renamed without changes
431 changes: 431 additions & 0 deletions examples/research/meta_optimization.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions llm_hyperparameters/track_execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from typing import List, Optional
from typing import Any, Dict, List, Optional

from langchain.schema import BaseMessage, ChatMessage, ChatResult, LLMResult
from langchain_core.language_models import BaseChatModel, BaseLanguageModel, BaseLLM
Expand Down Expand Up @@ -149,6 +149,9 @@ class TrackedChatModel(BaseChatModel):
def _llm_type(self) -> str:
    """Return the `_llm_type` reported by the wrapped chat model."""
    return self.chat_model._llm_type

def dict(self, **kwargs: Any) -> Dict:
    """Return the wrapped chat model's `dict(...)` output, forwarding all keyword arguments unchanged."""
    return self.chat_model.dict(**kwargs)

@trace_calls(name="TrackedChatModel", kind=TraceNodeKind.LLM, capture_args=True, capture_return=True)
def invoke(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs) -> BaseMessage:
response_message = self.chat_model.invoke(messages, stop, **kwargs)
Expand Down Expand Up @@ -193,7 +196,8 @@ def get_tracked_chats(chat_model_or_chat_chain: ChatChain | TrackedChatModel) ->
model = chat_model_or_chat_chain
else:
raise ValueError(f"Unknown language model type {type(chat_model_or_chat_chain)}")
return model.tracked_chats.build_compact_dict()["children"]
assert isinstance(model, TrackedChatModel)
return model.tracked_chats.build_compact_dict()


@trace_object_converter.register_converter()
Expand Down
48 changes: 48 additions & 0 deletions llm_strategy/cached_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Blackboard-PAGI - LLM Proto-AGI using the Blackboard Pattern
# Copyright (c) 2023. Andreas Kirsch
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import List, Optional

import langchain
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatResult,
Generation,
)


class CachedChatOpenAI(ChatOpenAI):
    """`ChatOpenAI` variant that consults `langchain.llm_cache` before generating.

    Entries are keyed on ``repr(messages)`` plus a model string. Only results with
    exactly one generation are supported, which the asserts below enforce.
    """

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
        """Return a cached chat result for ``messages`` or generate and cache a new one.

        Args:
            messages: The conversation to send to the model.
            stop: Optional stop sequences forwarded to the parent ``_generate``.

        Returns:
            The chat result, rebuilt from the cached generation on a hit or freshly
            generated (and stored, when a cache is configured) on a miss.
        """
        messages_prompt = repr(messages)
        # Stop sequences change what the model returns, so they must be part of the
        # cache key; otherwise two calls that differ only in `stop` share one entry.
        # Calls without stop sequences keep the plain model name as key, so entries
        # written with the previous key format are still found.
        llm_string = f"{self.model_name}|stop={stop!r}" if stop else self.model_name

        if langchain.llm_cache:
            results = langchain.llm_cache.lookup(messages_prompt, llm_string)
            if results:
                assert len(results) == 1
                result: Generation = results[0]
                return ChatResult(
                    generations=[ChatGeneration(message=AIMessage(content=result.text))],
                    llm_output=result.generation_info,
                )

        chat_result = super()._generate(messages, stop)
        if langchain.llm_cache:
            assert len(chat_result.generations) == 1
            # Only text and llm_output are stored; a hit is rebuilt as an AIMessage above.
            result = Generation(
                text=chat_result.generations[0].message.content,
                generation_info=chat_result.llm_output,
            )
            langchain.llm_cache.update(messages_prompt, llm_string, [result])
        return chat_result
2 changes: 1 addition & 1 deletion llm_strategy/chat_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def structured_query(
parser: PydanticOutputParser = PydanticOutputParser(pydantic_object=output_model)
question_and_formatting = question + "\n\n" + parser.get_format_instructions()

reply_content, chain = self.query(question_and_formatting, **self.enforce_json_response(model_args))
reply_content, chain = self.query(question_and_formatting, self.enforce_json_response(model_args))
parsed_reply: B = typing.cast(B, parser.parse(reply_content))

return parsed_reply, chain
Expand Down
47 changes: 41 additions & 6 deletions llm_strategy/llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,17 @@ def resolve_type(source_type: type, generic_type_map: dict[type, type]) -> type:
return source_type

@staticmethod
def resolve_generic_types(model: type[BaseModel], instance: BaseModel):
def resolve_generic_types(model: type[BaseModel], instance: BaseModel) -> dict:
"""
Resolves the generic types of a given model instance and returns a generic type map.
Args:
model (type[BaseModel]): The model type.
instance (BaseModel): The instance of the model.
Returns:
dict: The generic type map.
"""
generic_type_map: dict = {}

for field_name, attr_value in list(instance):
Expand All @@ -281,9 +291,12 @@ def resolve_generic_types(model: type[BaseModel], instance: BaseModel):
LLMStructuredPrompt.add_resolved_type(generic_type_map, annotation, type(attr_value))
# if the annotation is a generic type alias ignore
elif isinstance(annotation, types.GenericAlias):
# The generic type alias is not supported yet
# The problem is that GenericAlias types are elided: e.g. type(list[str](["hello"])) -> list and not list[str].
# But either way, we would need to resolve the types based on the actual elements and their mros.
continue
# if the annotation is a type, check if it is a generic type
elif issubclass(annotation, generics.GenericModel):
elif isinstance(annotation, type) and issubclass(annotation, generics.GenericModel):
# check if the type is in generics._assigned_parameters
generic_definition_type_map = LLMStructuredPrompt.get_generic_type_map(annotation)

Expand All @@ -299,6 +312,9 @@ def resolve_generic_types(model: type[BaseModel], instance: BaseModel):
continue
resolved_type = generic_instance_type_map[generic_parameter]
LLMStructuredPrompt.add_resolved_type(generic_type_map, generic_parameter_target, resolved_type)
else:
# Let Pydantic handle the rest
continue

return generic_type_map

Expand All @@ -319,6 +335,11 @@ def add_resolved_type(generic_type_map, source_type, resolved_type):

@staticmethod
def get_generic_type_map(generic_type, base_generic_type=None):
"""Build a generic type map for a generic type.
It maps the generic type variables to the actual types.
"""

if base_generic_type is None:
base_generic_type = LLMStructuredPrompt.get_base_generic_type(generic_type)

Expand All @@ -341,7 +362,19 @@ def get_generic_type_map(generic_type, base_generic_type=None):
return generic_parameter_type_map

@staticmethod
def get_base_generic_type(field_type) -> type[generics.GenericModel]:
def get_base_generic_type(field_type: type) -> type[generics.GenericModel]:
"""Determine the base generic type of a generic type. E.g. List[str] -> List.
Args:
field_type (type): The generic type.
Raises:
ValueError: If the base generic type cannot be found.
Returns:
type[generics.GenericModel]: The base generic type.
"""

# get the base class name from annotation (which is without [])
base_generic_name = field_type.__name__
if "[" in field_type.__name__:
Expand Down Expand Up @@ -535,12 +568,14 @@ def from_call(f: typing.Callable[P, T], args: P.args, kwargs: P.kwargs) -> "LLMB
"""

# get clean docstring
docstring = inspect.getdoc(f)
docstring = Hyperparameter("docstring") @ inspect.getdoc(f)
if docstring is None:
raise ValueError("The function must have a docstring.")

# get the type of the first argument
signature = inspect.signature(f, eval_str=True)
globals_from_f = f.__globals__ if hasattr(f, "__globals__") else None
locals_from_f = f.__dict__ if hasattr(f, "__dict__") else None
signature = inspect.signature(f, eval_str=True, globals=globals_from_f, locals=locals_from_f)

# get all parameters
parameters_items: list[tuple[str, inspect.Parameter]] = list(signature.parameters.items())
Expand Down Expand Up @@ -725,7 +760,7 @@ def __getattr__(self, item):
def explicit(self, language_model_or_chat_chain: BaseLanguageModel | ChatChain, input_object: BaseModel):
"""Call the function with explicit inputs."""

return self(language_model_or_chat_chain, **dict(input_object))
return track_hyperparameters(self)(language_model_or_chat_chain, **dict(input_object))

@trace_calls(kind=TraceNodeKind.CHAIN, capture_return=slicer[1:], capture_args=True)
def __call__(
Expand Down
28 changes: 28 additions & 0 deletions llm_strategy/tests/test_llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,22 @@ class FInputs(GenericModel, typing.Generic[T]):
assert llm_bound_signature.output_type.schema() == Output[dict[str, str]].schema()


def test_llm_bound_signature_from_call_generic_collection_2() -> None:
    """Pin the current `from_call` result for a `list[T]` parameter called with `["a"]`.

    The expected output schema still mentions the TypeVar (`dict[T, T]`), i.e. this
    test records that `T` is left unresolved for this input.
    """
    # TODO: I do not like the result but I don't want to implement a full generic type resolution system.
    T = typing.TypeVar("T")

    def f(llm: BaseLLM, a: list[T]) -> dict[T, T]:
        """Test docstring."""
        raise NotImplementedError

    # Reference model whose schema the derived input type is compared against.
    class FInputs(GenericModel, typing.Generic[T]):
        a: list[T]

    llm_bound_signature = LLMBoundSignature.from_call(f, (), dict(a=["a"]))
    assert llm_bound_signature.input_type.schema() == FInputs.schema()
    assert llm_bound_signature.output_type.schema() == Output[dict[T, T]].schema()


def test_llm_bound_signature_from_call_generic_function() -> None:
T = typing.TypeVar("T")
S = typing.TypeVar("S")
Expand All @@ -320,6 +336,18 @@ def f(llm: BaseLLM, a: T, b: S) -> dict[T, S]:
assert llm_bound_signature.output_type.schema() == Output[dict[int, str]].schema()


def test_llm_bound_signature_from_call_generic_function_with_primitive_types() -> None:
    """Check that `T` and `S` resolve to `int` and `str` when `f` also takes a plain `int` parameter."""
    T = typing.TypeVar("T")
    S = typing.TypeVar("S")

    def f(llm: BaseLLM, a: T, b: S, c: int) -> dict[T, S]:
        """Test docstring."""
        raise NotImplementedError

    # a=0 and b="" supply the concrete types for T and S; c is the non-generic parameter.
    llm_bound_signature = LLMBoundSignature.from_call(f, (), dict(a=0, b="", c=10))
    assert llm_bound_signature.output_type.schema() == Output[dict[int, str]].schema()


def test_llm_bound_signature_from_call_generic_input_outputs_full_remap() -> None:
T = typing.TypeVar("T")
S = typing.TypeVar("S")
Expand Down

0 comments on commit fb014e5

Please sign in to comment.