Skip to content

LLM function composition #1722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/primary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ jobs:
include:
- os: ubuntu-latest
target: x86_64-unknown-linux-gnu
- os: macos-latest
- os: macos-latest
target: x86_64-apple-darwin
- os: windows-latest
target: x86_64-pc-windows-msvc
Expand All @@ -170,4 +170,4 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: baml-cli-${{ matrix.target }}
path: engine/target/release/baml-cli*
path: engine/target/release/baml-cli*
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ $RECYCLE.BIN/
**/dist
*.cab
*.env
*.exe
*.icloud
*.lcov
*.lnk
Expand Down Expand Up @@ -165,3 +164,5 @@ yarn-debug.log*
yarn-error.log*
yarn.lock
artifacts
.direnv
typescript/vscode-ext/packages/vscode/server/baml-cli
1 change: 1 addition & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

82 changes: 74 additions & 8 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::collections::HashSet;

use indexmap::IndexMap;
use internal_baml_diagnostics::Span;
use internal_baml_parser_database::walkers::ExprFnWalker;
use internal_baml_schema_ast::ast::{WithIdentifier, WithSpan};
use itertools::Itertools;

Expand All @@ -26,9 +27,10 @@ use baml_types::{
};
pub use to_baml_arg::ArgCoercer;

use super::repr;
use super::{repr, ExprFunctionNode};

pub type FunctionWalker<'a> = Walker<'a, &'a FunctionNode>;
pub type ExprFunctionWalker<'a> = Walker<'a, &'a ExprFunctionNode>;
pub type EnumWalker<'a> = Walker<'a, &'a Enum>;
pub type EnumValueWalker<'a> = Walker<'a, &'a EnumValue>;
pub type ClassWalker<'a> = Walker<'a, &'a Class>;
Expand All @@ -37,15 +39,22 @@ pub type TemplateStringWalker<'a> = Walker<'a, &'a TemplateString>;
pub type ClientWalker<'a> = Walker<'a, &'a Client>;
pub type RetryPolicyWalker<'a> = Walker<'a, &'a RetryPolicy>;
pub type TestCaseWalker<'a> = Walker<'a, (&'a FunctionNode, &'a TestCase)>;
pub type TestCaseExprWalker<'a> = Walker<'a, (&'a ExprFunctionNode, &'a TestCase)>;
pub type ClassFieldWalker<'a> = Walker<'a, &'a Field>;

pub trait IRHelper {
fn find_enum<'a>(&'a self, enum_name: &str) -> Result<EnumWalker<'a>>;
fn find_class<'a>(&'a self, class_name: &str) -> Result<ClassWalker<'a>>;
fn find_type_alias<'a>(&'a self, alias_name: &str) -> Result<TypeAliasWalker<'a>>;
fn find_expr_fn<'a>(&'a self, function_name: &str) -> Result<ExprFunctionWalker<'a>>;
fn find_function<'a>(&'a self, function_name: &str) -> Result<FunctionWalker<'a>>;
fn find_client<'a>(&'a self, client_name: &str) -> Result<ClientWalker<'a>>;
fn find_retry_policy<'a>(&'a self, retry_policy_name: &str) -> Result<RetryPolicyWalker<'a>>;
fn find_expr_fn_test<'a>(
&'a self,
function: &'a ExprFunctionWalker<'a>,
test_name: &str,
) -> Result<TestCaseExprWalker<'a>>;
fn find_template_string<'a>(
&'a self,
template_string_name: &str,
Expand All @@ -62,7 +71,7 @@ pub trait IRHelper {

fn check_function_params<'a>(
&'a self,
function: &'a FunctionWalker<'a>,
function_params: &Vec<(String, FieldType)>,
params: &BamlMap<String, BamlValue>,
coerce_settings: ArgCoercer,
) -> Result<BamlValue>;
Expand Down Expand Up @@ -118,13 +127,17 @@ pub trait IRHelperExtended: IRSemanticStreamingHelper {
.get_all_recursive_aliases(name)
.any(|target| self.is_subtype(base, target)),

(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(base_item), FieldType::Optional(other_item)) => {
self.is_subtype(base_item, other_item)
}
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(_, FieldType::Optional(t)) => self.is_subtype(base, t),
(FieldType::Optional(_), _) => false,

(FieldType::Primitive(p1), FieldType::Primitive(p2)) => p1 == p2,
(FieldType::Primitive(TypeValue::Null), _) => false,
(FieldType::Primitive(p1), _) => false,

// Handle types that nest other types.
(FieldType::List(base_item), FieldType::List(other_item)) => {
self.is_subtype(&base_item, other_item)
Expand Down Expand Up @@ -180,9 +193,22 @@ pub trait IRHelperExtended: IRSemanticStreamingHelper {
.all(|(base_item, other_item)| self.is_subtype(base_item, other_item))
}
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,

(FieldType::Arrow(arrow1), FieldType::Arrow(arrow2)) => {
let param_lengths_match = arrow1.param_types.len() == arrow2.param_types.len();
// N.B. Functions are covariant in their return type and contravariant in their arguments.
// This is why a and b are swapped in the parameters check, and no in the return type check.
let return_types_match = self.is_subtype(&arrow1.return_type, &arrow2.return_type);
let args_match = arrow1
.param_types
.iter()
.zip(arrow2.param_types.iter())
.all(|(a, b)| self.is_subtype(b, a));
param_lengths_match && return_types_match && args_match
}
(FieldType::Arrow(_), _) => false,
}
}

Expand Down Expand Up @@ -559,6 +585,24 @@ impl IRHelper for IntermediateRepr {
}
}

fn find_expr_fn_test<'a>(
&'a self,
function: &'a ExprFunctionWalker<'a>,
test_name: &str,
) -> Result<TestCaseExprWalker<'a>> {
match function.find_test(test_name) {
Some(t) => Ok(t),
None => {
// Get best match.
let tests = function
.walk_tests()
.map(|t| t.item.1.elem.name.as_str())
.collect::<Vec<_>>();
error_not_found!("test", test_name, &tests)
}
}
}

fn find_enum(&self, enum_name: &str) -> Result<EnumWalker<'_>> {
match self.walk_enums().find(|e| e.name() == enum_name) {
Some(e) => Ok(e),
Expand Down Expand Up @@ -607,6 +651,28 @@ impl IRHelper for IntermediateRepr {
}
}

fn find_expr_fn<'a>(&'a self, function_name: &str) -> Result<ExprFunctionWalker<'a>> {
let expr_fn_names = self
.walk_expr_fns()
.map(|f| f.item.elem.name.clone())
.collect::<Vec<_>>();
match self
.walk_expr_fns()
.find(|f| f.item.elem.name == function_name)
{
Some(f) => Ok(f),

None => {
// Get best match.
let functions = self
.walk_expr_fns()
.map(|f| f.item.elem.name.clone())
.collect::<Vec<_>>();
error_not_found!("function", function_name, &functions)
}
}
}

fn find_client<'a>(&'a self, client_name: &str) -> Result<ClientWalker<'a>> {
match self.walk_clients().find(|c| c.name() == client_name) {
Some(c) => Ok(c),
Expand Down Expand Up @@ -856,12 +922,10 @@ impl IRHelper for IntermediateRepr {

fn check_function_params<'a>(
&'a self,
function: &'a FunctionWalker<'a>,
function_params: &Vec<(String, FieldType)>,
params: &BamlMap<String, BamlValue>,
coerce_settings: ArgCoercer,
) -> Result<BamlValue> {
let function_params = function.inputs();

// Now check that all required parameters are present.
let mut scope = ScopeStack::new();
let mut baml_arg_map = BamlMap::new();
Expand Down Expand Up @@ -1061,6 +1125,7 @@ pub fn item_type<'ir, 'a, T: std::fmt::Debug>(
}
}
FieldType::Tuple(_) => None,
FieldType::Arrow(_) => None,
FieldType::WithMetadata { base, .. } => item_type(ir, base, baml_child_values),
};
res
Expand Down Expand Up @@ -1097,6 +1162,7 @@ where
variant_map_types.next()
}
FieldType::Class(_) => None,
FieldType::Arrow(_) => None,
FieldType::WithMetadata { .. } => {
unreachable!("distribute_metadata never returns this variant")
}
Expand Down Expand Up @@ -1489,7 +1555,7 @@ mod tests {
span_path: None,
allow_implicit_cast_to_string: true,
};
let res = ir.check_function_params(&function, &params, arg_coercer);
let res = ir.check_function_params(&function.inputs(), &params, arg_coercer);
assert!(res.is_err());
}

Expand Down
4 changes: 4 additions & 0 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ impl ArgCoercer {
}
}
}
(FieldType::Arrow(_), _) => {
scope.push_error(format!("A json value may not be coerced into a function type"));
Err(())
}
(FieldType::WithMetadata { .. }, _) => {
unreachable!("The return value of distribute_constraints can never be FieldType::Constrainted");
}
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/src/ir/json_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ impl WithJsonSchema for FieldType {
}
}
FieldType::WithMetadata { base, .. } => base.json_schema(),
FieldType::Arrow(_) => json!({}), // TODO: Make this function partial - it should not return for Arrow.
}
}
}
3 changes: 2 additions & 1 deletion engine/baml-lib/baml-core/src/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod walker;

pub use ir_helpers::{
scope_diagnostics, ArgCoercer, ClassFieldWalker, ClassWalker, ClientWalker, EnumValueWalker,
EnumWalker, FunctionWalker, IRHelper, IRHelperExtended, IRSemanticStreamingHelper,
EnumWalker, ExprFunctionWalker, FunctionWalker, IRHelper, IRHelperExtended, IRSemanticStreamingHelper,
RetryPolicyWalker, TemplateStringWalker, TestCaseWalker, TypeAliasWalker,
};

Expand All @@ -21,6 +21,7 @@ pub type Field = repr::Node<repr::Field>;
pub type FieldType = baml_types::FieldType;
pub type TypeValue = baml_types::TypeValue;
pub type FunctionNode = repr::Node<repr::Function>;
pub type ExprFunctionNode = repr::Node<repr::ExprFunction>;
#[allow(dead_code)]
pub(super) type Function = repr::Function;
pub(super) type FunctionArgs = repr::FunctionArgs;
Expand Down
Loading
Loading