From 0b1fa017c1b4f7cf2d51781c55f6db39451706b4 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Wed, 9 Apr 2025 16:51:58 -0700 Subject: [PATCH] Add Expression Language --- .github/workflows/primary.yml | 4 +- .gitignore | 3 +- engine/Cargo.lock | 1 + .../baml-core/src/ir/ir_helpers/mod.rs | 82 ++- .../src/ir/ir_helpers/to_baml_arg.rs | 4 + .../baml-lib/baml-core/src/ir/json_schema.rs | 1 + engine/baml-lib/baml-core/src/ir/mod.rs | 3 +- engine/baml-lib/baml-core/src/ir/repr.rs | 664 +++++++++++++++++- engine/baml-lib/baml-core/src/ir/walker.rs | 70 +- engine/baml-lib/baml-core/src/validate/mod.rs | 2 +- .../src/validate/validation_pipeline.rs | 2 +- .../validation_pipeline/validations.rs | 6 + .../validations/configurations.rs | 22 +- .../validations/expr_fns.rs | 158 +++++ .../validations/expr_typecheck.rs | 500 +++++++++++++ engine/baml-lib/baml-types/Cargo.toml | 1 + engine/baml-lib/baml-types/src/baml_value.rs | 38 + engine/baml-lib/baml-types/src/expr.rs | 286 ++++++++ .../baml-lib/baml-types/src/field_type/mod.rs | 22 +- engine/baml-lib/baml-types/src/lib.rs | 3 +- engine/baml-lib/baml-types/src/value_expr.rs | 32 + .../class/generator_keywords1.baml | 2 - .../validation_files/expr/constructors.baml | 22 + .../tests/validation_files/expr/expr_fn.baml | 11 + .../validation_files/expr/expr_full.baml | 71 ++ .../validation_files/expr/expr_list.baml | 7 + .../validation_files/expr/expr_small.baml | 3 + .../expr/missing_return_value.baml | 12 + .../expr/missing_semicolons.baml | 20 + .../validation_files/expr/mixed_pipeline.baml | 16 + .../expr/top_level_binding.baml | 6 + .../validation_files/expr/unknown_name.baml | 33 + .../strings/unquoted_strings.baml | 7 - engine/baml-lib/diagnostics/src/lib.rs | 2 +- engine/baml-lib/diagnostics/src/span.rs | 31 +- .../jinja-runtime/src/output_format/types.rs | 26 +- .../src/deserializer/coercer/field_type.rs | 2 + .../src/deserializer/semantic_streaming.rs | 1 + .../parser-database/src/walkers/expr_fn.rs | 66 ++ 
.../parser-database/src/walkers/mod.rs | 28 +- engine/baml-lib/schema-ast/Cargo.toml | 1 - engine/baml-lib/schema-ast/src/ast.rs | 61 +- .../baml-lib/schema-ast/src/ast/argument.rs | 19 +- engine/baml-lib/schema-ast/src/ast/expr.rs | 25 + .../baml-lib/schema-ast/src/ast/expression.rs | 205 +++++- engine/baml-lib/schema-ast/src/ast/top.rs | 26 +- .../schema-ast/src/parser/datamodel.pest | 81 ++- engine/baml-lib/schema-ast/src/parser/mod.rs | 1 + .../schema-ast/src/parser/parse_expr.rs | 197 ++++++ .../schema-ast/src/parser/parse_expression.rs | 81 ++- .../schema-ast/src/parser/parse_field.rs | 9 +- .../src/parser/parse_named_args_list.rs | 11 +- .../schema-ast/src/parser/parse_schema.rs | 65 +- .../src/parser/parse_template_string.rs | 5 +- .../src/parser/parse_type_expression_block.rs | 5 +- .../parser/parse_value_expression_block.rs | 5 +- engine/baml-runtime/src/cli/serve/mod.rs | 37 +- engine/baml-runtime/src/eval_expr.rs | 472 +++++++++++++ .../src/internal/prompt_renderer/mod.rs | 22 +- .../prompt_renderer/render_output_format.rs | 1 + engine/baml-runtime/src/lib.rs | 225 +++++- .../src/runtime/runtime_interface.rs | 186 +++-- engine/baml-runtime/src/runtime_interface.rs | 6 + engine/baml-runtime/src/test_executor/mod.rs | 17 +- engine/baml-schema-wasm/Cargo.toml | 2 +- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 96 +++ engine/language_client_cffi/src/ctypes.rs | 15 +- .../src/go/generate_types.rs | 3 + engine/language_client_codegen/src/go/mod.rs | 2 + engine/language_client_codegen/src/openapi.rs | 1 + .../src/python/generate_types.rs | 3 + .../language_client_codegen/src/python/mod.rs | 3 + .../src/ruby/field_type.rs | 1 + .../src/ruby/generate_types.rs | 1 + .../src/typescript/mod.rs | 2 + flake.lock | 20 +- flake.nix | 124 ++-- integ-tests/typescript/package.json | 2 +- integ-tests/typescript/tests/logger.test.ts | 10 +- typescript/fiddle-frontend/next.config.mjs | 6 + typescript/nextjs-plugin/src/index.ts | 9 + .../src/baml_wasm_web/EventListener.tsx 
| 16 + .../prompt-preview/test-panel/test-runner.ts | 62 +- .../packages/language-server/src/server.ts | 2 +- .../packages/vscode/server/darwin/baml-cli | 1 + .../packages/vscode/server/linux/baml-cli | 1 + .../packages/vscode/src/extension.ts | 130 ++++ .../vscode/src/panels/WebviewPanelHost.ts | 11 + 88 files changed, 4264 insertions(+), 292 deletions(-) create mode 100644 engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_fns.rs create mode 100644 engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_typecheck.rs create mode 100644 engine/baml-lib/baml-types/src/expr.rs create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/constructors.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/expr_fn.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/expr_full.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/expr_list.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/expr_small.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/missing_return_value.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/missing_semicolons.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/mixed_pipeline.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/top_level_binding.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/expr/unknown_name.baml create mode 100644 engine/baml-lib/parser-database/src/walkers/expr_fn.rs create mode 100644 engine/baml-lib/schema-ast/src/ast/expr.rs create mode 100644 engine/baml-lib/schema-ast/src/parser/parse_expr.rs create mode 100644 engine/baml-runtime/src/eval_expr.rs create mode 120000 typescript/vscode-ext/packages/vscode/server/darwin/baml-cli create mode 120000 typescript/vscode-ext/packages/vscode/server/linux/baml-cli diff --git a/.github/workflows/primary.yml b/.github/workflows/primary.yml index 
968a49238..6bd17f3ac 100644 --- a/.github/workflows/primary.yml +++ b/.github/workflows/primary.yml @@ -149,7 +149,7 @@ jobs: include: - os: ubuntu-latest target: x86_64-unknown-linux-gnu - - os: macos-latest + - os: macos-latest target: x86_64-apple-darwin - os: windows-latest target: x86_64-pc-windows-msvc @@ -170,4 +170,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: baml-cli-${{ matrix.target }} - path: engine/target/release/baml-cli* \ No newline at end of file + path: engine/target/release/baml-cli* diff --git a/.gitignore b/.gitignore index 6f19bc98d..2549a0925 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,6 @@ $RECYCLE.BIN/ **/dist *.cab *.env -*.exe *.icloud *.lcov *.lnk @@ -165,3 +164,5 @@ yarn-debug.log* yarn-error.log* yarn.lock artifacts +.direnv +typescript/vscode-ext/packages/vscode/server/baml-cli diff --git a/engine/Cargo.lock b/engine/Cargo.lock index aa1ad0fb8..21d1d9559 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -1221,6 +1221,7 @@ dependencies = [ "clap", "derive_builder", "indexmap 2.8.0", + "internal-baml-diagnostics", "itertools 0.14.0", "log", "minijinja", diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 0c9cd4b18..8ac7b7907 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -6,6 +6,7 @@ use std::collections::HashSet; use indexmap::IndexMap; use internal_baml_diagnostics::Span; +use internal_baml_parser_database::walkers::ExprFnWalker; use internal_baml_schema_ast::ast::{WithIdentifier, WithSpan}; use itertools::Itertools; @@ -26,9 +27,10 @@ use baml_types::{ }; pub use to_baml_arg::ArgCoercer; -use super::repr; +use super::{repr, ExprFunctionNode}; pub type FunctionWalker<'a> = Walker<'a, &'a FunctionNode>; +pub type ExprFunctionWalker<'a> = Walker<'a, &'a ExprFunctionNode>; pub type EnumWalker<'a> = Walker<'a, &'a Enum>; pub type EnumValueWalker<'a> = Walker<'a, 
&'a EnumValue>; pub type ClassWalker<'a> = Walker<'a, &'a Class>; @@ -37,15 +39,22 @@ pub type TemplateStringWalker<'a> = Walker<'a, &'a TemplateString>; pub type ClientWalker<'a> = Walker<'a, &'a Client>; pub type RetryPolicyWalker<'a> = Walker<'a, &'a RetryPolicy>; pub type TestCaseWalker<'a> = Walker<'a, (&'a FunctionNode, &'a TestCase)>; +pub type TestCaseExprWalker<'a> = Walker<'a, (&'a ExprFunctionNode, &'a TestCase)>; pub type ClassFieldWalker<'a> = Walker<'a, &'a Field>; pub trait IRHelper { fn find_enum<'a>(&'a self, enum_name: &str) -> Result>; fn find_class<'a>(&'a self, class_name: &str) -> Result>; fn find_type_alias<'a>(&'a self, alias_name: &str) -> Result>; + fn find_expr_fn<'a>(&'a self, function_name: &str) -> Result>; fn find_function<'a>(&'a self, function_name: &str) -> Result>; fn find_client<'a>(&'a self, client_name: &str) -> Result>; fn find_retry_policy<'a>(&'a self, retry_policy_name: &str) -> Result>; + fn find_expr_fn_test<'a>( + &'a self, + function: &'a ExprFunctionWalker<'a>, + test_name: &str, + ) -> Result>; fn find_template_string<'a>( &'a self, template_string_name: &str, @@ -62,7 +71,7 @@ pub trait IRHelper { fn check_function_params<'a>( &'a self, - function: &'a FunctionWalker<'a>, + function_params: &Vec<(String, FieldType)>, params: &BamlMap, coerce_settings: ArgCoercer, ) -> Result; @@ -118,13 +127,17 @@ pub trait IRHelperExtended: IRSemanticStreamingHelper { .get_all_recursive_aliases(name) .any(|target| self.is_subtype(base, target)), - (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, (FieldType::Optional(base_item), FieldType::Optional(other_item)) => { self.is_subtype(base_item, other_item) } + (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, (_, FieldType::Optional(t)) => self.is_subtype(base, t), (FieldType::Optional(_), _) => false, + (FieldType::Primitive(p1), FieldType::Primitive(p2)) => p1 == p2, + (FieldType::Primitive(TypeValue::Null), _) => false, + 
(FieldType::Primitive(p1), _) => false, + // Handle types that nest other types. (FieldType::List(base_item), FieldType::List(other_item)) => { self.is_subtype(&base_item, other_item) @@ -180,9 +193,22 @@ pub trait IRHelperExtended: IRSemanticStreamingHelper { .all(|(base_item, other_item)| self.is_subtype(base_item, other_item)) } (FieldType::Tuple(_), _) => false, - (FieldType::Primitive(_), _) => false, (FieldType::Enum(_), _) => false, (FieldType::Class(_), _) => false, + + (FieldType::Arrow(arrow1), FieldType::Arrow(arrow2)) => { + let param_lengths_match = arrow1.param_types.len() == arrow2.param_types.len(); + // N.B. Functions are covariant in their return type and contravariant in their arguments. + // This is why a and b are swapped in the parameters check, and no in the return type check. + let return_types_match = self.is_subtype(&arrow1.return_type, &arrow2.return_type); + let args_match = arrow1 + .param_types + .iter() + .zip(arrow2.param_types.iter()) + .all(|(a, b)| self.is_subtype(b, a)); + param_lengths_match && return_types_match && args_match + } + (FieldType::Arrow(_), _) => false, } } @@ -559,6 +585,24 @@ impl IRHelper for IntermediateRepr { } } + fn find_expr_fn_test<'a>( + &'a self, + function: &'a ExprFunctionWalker<'a>, + test_name: &str, + ) -> Result> { + match function.find_test(test_name) { + Some(t) => Ok(t), + None => { + // Get best match. 
+ let tests = function + .walk_tests() + .map(|t| t.item.1.elem.name.as_str()) + .collect::>(); + error_not_found!("test", test_name, &tests) + } + } + } + fn find_enum(&self, enum_name: &str) -> Result> { match self.walk_enums().find(|e| e.name() == enum_name) { Some(e) => Ok(e), @@ -607,6 +651,28 @@ impl IRHelper for IntermediateRepr { } } + fn find_expr_fn<'a>(&'a self, function_name: &str) -> Result> { + let expr_fn_names = self + .walk_expr_fns() + .map(|f| f.item.elem.name.clone()) + .collect::>(); + match self + .walk_expr_fns() + .find(|f| f.item.elem.name == function_name) + { + Some(f) => Ok(f), + + None => { + // Get best match. + let functions = self + .walk_expr_fns() + .map(|f| f.item.elem.name.clone()) + .collect::>(); + error_not_found!("function", function_name, &functions) + } + } + } + fn find_client<'a>(&'a self, client_name: &str) -> Result> { match self.walk_clients().find(|c| c.name() == client_name) { Some(c) => Ok(c), @@ -856,12 +922,10 @@ impl IRHelper for IntermediateRepr { fn check_function_params<'a>( &'a self, - function: &'a FunctionWalker<'a>, + function_params: &Vec<(String, FieldType)>, params: &BamlMap, coerce_settings: ArgCoercer, ) -> Result { - let function_params = function.inputs(); - // Now check that all required parameters are present. let mut scope = ScopeStack::new(); let mut baml_arg_map = BamlMap::new(); @@ -1061,6 +1125,7 @@ pub fn item_type<'ir, 'a, T: std::fmt::Debug>( } } FieldType::Tuple(_) => None, + FieldType::Arrow(_) => None, FieldType::WithMetadata { base, .. } => item_type(ir, base, baml_child_values), }; res @@ -1097,6 +1162,7 @@ where variant_map_types.next() } FieldType::Class(_) => None, + FieldType::Arrow(_) => None, FieldType::WithMetadata { .. 
} => { unreachable!("distribute_metadata never returns this variant") } @@ -1489,7 +1555,7 @@ mod tests { span_path: None, allow_implicit_cast_to_string: true, }; - let res = ir.check_function_params(&function, ¶ms, arg_coercer); + let res = ir.check_function_params(&function.inputs(), ¶ms, arg_coercer); assert!(res.is_err()); } diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index ec4e0ddda..2fcfeb8d8 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -356,6 +356,10 @@ impl ArgCoercer { } } } + (FieldType::Arrow(_), _) => { + scope.push_error(format!("A json value may not be coerced into a function type")); + Err(()) + } (FieldType::WithMetadata { .. }, _) => { unreachable!("The return value of distribute_constraints can never be FieldType::Constrainted"); } diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index a7eb3a7e5..d5e8ee892 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -263,6 +263,7 @@ impl WithJsonSchema for FieldType { } } FieldType::WithMetadata { base, .. } => base.json_schema(), + FieldType::Arrow(_) => json!({}), // TODO: Make this function partial - it should not return for Arrow. 
} } } diff --git a/engine/baml-lib/baml-core/src/ir/mod.rs b/engine/baml-lib/baml-core/src/ir/mod.rs index a76ca5679..1e16cf532 100644 --- a/engine/baml-lib/baml-core/src/ir/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/mod.rs @@ -6,7 +6,7 @@ mod walker; pub use ir_helpers::{ scope_diagnostics, ArgCoercer, ClassFieldWalker, ClassWalker, ClientWalker, EnumValueWalker, - EnumWalker, FunctionWalker, IRHelper, IRHelperExtended, IRSemanticStreamingHelper, + EnumWalker, ExprFunctionWalker, FunctionWalker, IRHelper, IRHelperExtended, IRSemanticStreamingHelper, RetryPolicyWalker, TemplateStringWalker, TestCaseWalker, TypeAliasWalker, }; @@ -21,6 +21,7 @@ pub type Field = repr::Node; pub type FieldType = baml_types::FieldType; pub type TypeValue = baml_types::TypeValue; pub type FunctionNode = repr::Node; +pub type ExprFunctionNode = repr::Node; #[allow(dead_code)] pub(super) type Function = repr::Function; pub(super) type FunctionArgs = repr::FunctionArgs; diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index bc7c5ab45..978773bed 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -1,16 +1,21 @@ use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use anyhow::{anyhow, Result}; +use baml_types::BamlMap; use baml_types::{ - Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, - UnresolvedValue, + expr::{self, Expr, ExprMetadata, Name}, + Arrow, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, + StreamingBehavior, StringOr, TypeValue, UnresolvedValue, }; use either::Either; use indexmap::{IndexMap, IndexSet}; +use internal_baml_diagnostics::{Diagnostics, Span}; use internal_baml_parser_database::{ walkers::{ - ClassWalker, ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, FieldWalker, - FunctionWalker, TemplateStringWalker, TypeAliasWalker, Walker as AstWalker, + ClassWalker, 
ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, ExprFnWalker, + FieldWalker, FunctionWalker, TemplateStringWalker, TopLevelAssignmentWalker, + TypeAliasWalker, Walker as AstWalker, }, Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; @@ -22,6 +27,7 @@ use internal_baml_schema_ast::ast::{ use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; +use crate::validate::validation_pipeline::validations::expr_typecheck::infer_types_in_context; use crate::Configuration; /// This class represents the intermediate representation of the BAML AST. @@ -33,7 +39,9 @@ pub struct IntermediateRepr { enums: Vec>, classes: Vec>, type_aliases: Vec>, - functions: Vec>, + pub functions: Vec>, + pub expr_fns: Vec>, + pub toplevel_assignments: Vec>, clients: Vec>, retry_policies: Vec>, template_strings: Vec>, @@ -50,6 +58,314 @@ pub struct IntermediateRepr { configuration: Configuration, } +#[derive(Debug)] +pub struct TopLevelAssignment { + pub name: Node, + pub expr: Node>, +} + +#[derive(Clone, Debug)] +pub struct ClassConstructor { + pub class_name: Node, + pub fields: Vec>, +} + +#[derive(Clone, Debug)] +pub enum ClassConstructorField { + Named(Node, Node>), + Spread(Node>), +} + +impl WithRepr for TopLevelAssignmentWalker<'_> { + fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { + // TODO: Add attributes. 
+ NodeAttributes::default() + } + + fn repr(&self, db: &ParserDatabase) -> Result { + let name = self + .top_level_assignment() + .stmt + .identifier + .name() + .to_string(); + let expr = self.top_level_assignment().stmt.body.repr(db)?; + Ok(TopLevelAssignment { + name: Node { + elem: name, + attributes: NodeAttributes::default(), + }, + expr: Node { + elem: expr, + attributes: NodeAttributes::default(), + }, + }) + } +} + +impl WithRepr for ExprFnWalker<'_> { + fn repr(&self, db: &ParserDatabase) -> Result { + let body = convert_function_body(self.expr_fn().body.to_owned(), db)?; + let args: Vec<(String, FieldType)> = self + .expr_fn() + .args + .args + .iter() + .map(|(arg_name, arg_type)| { + arg_type + .field_type + .repr(db) + .map(|ty| (arg_name.to_string(), ty)) + }) + .collect::>>()?; + let arg_names = self + .expr_fn() + .args + .args + .iter() + .map(|(arg_name, _arg_type)| arg_name.to_string()) + .collect(); + let tests = self + .walk_tests() + .map(|e| e.node(db)) + .collect::>>()?; + let arg_types = args.iter().map(|(_, arg_type)| arg_type.clone()).collect(); + let return_type = self + .expr_fn() + .return_type + .clone() + .map(|ret| ret.repr(db)) + .transpose()? + .ok_or(anyhow::anyhow!( + "Expression functions must have a return type" + ))?; + let lambda_type = FieldType::Arrow(Box::new(Arrow { + param_types: arg_types, + return_type: return_type.clone(), + })); + let expr_fn = ExprFunction { + name: self.expr_fn().name.to_string(), + inputs: args, + output: return_type, + expr: Expr::Lambda( + arg_names, + Arc::new(body), + (self.expr_fn().span.clone(), Some(lambda_type)), + ), + tests, + }; + Ok(expr_fn) + } +} + +fn weird_default() -> FieldType { + FieldType::Primitive(TypeValue::Null) +} + +impl WithRepr for ExprFnWalker<'_> { + fn repr(&self, db: &ParserDatabase) -> Result { + // TODO: Drop weird default (replace by better validation). 
+ let body = convert_function_body(self.expr_fn().body.to_owned(), db)?; + let args = self + .expr_fn() + .args + .args + .iter() + .map(|(arg_name, arg_type)| { + let ty = arg_type.field_type.repr(db)?; + Ok((arg_name.to_string(), ty)) + }) + .collect::>()?; + let return_ty = self + .expr_fn() + .return_type + .as_ref() + .ok_or(anyhow::anyhow!( + "Expression functions must have return type." + ))? + .repr(db)?; + let function = Function { + name: self.expr_fn().name.to_string(), + inputs: args, + output: return_ty, + configs: vec![], + default_config: "".to_string(), + tests: vec![], + }; + Ok(function) + } +} + +/// Convert a function body to an expression. +/// +/// The function body is a list of statements, which are let bindings. +/// We fold the let bindings into a single expression. +/// { +/// let x = 1; +/// let y = x; +/// y +/// } +/// => +/// Let "x" 1 (Let "y" x (y)) +fn convert_function_body( + function_body: ast::ExpressionBlock, + db: &ParserDatabase, +) -> Result> { + function_body.expr.repr(db).map(|fn_body| { + let expr = function_body + .stmts + .iter() + .fold(fn_body, |acc, stmt| match stmt.body.repr(db) { + Ok(stmt_expr) => Expr::Let( + stmt.identifier.name().to_string(), + Arc::new(stmt_expr), + Arc::new(acc), + (stmt.body.span().clone(), None), + ), + Err(e) => acc, + }); + expr + }) +} + +impl WithRepr> for ast::Expression { + fn repr(&self, db: &ParserDatabase) -> Result> { + match self { + ast::Expression::BoolValue(val, span) => Ok(Expr::Atom(BamlValueWithMeta::Bool( + *val, + (span.clone(), Some(FieldType::Primitive(TypeValue::Bool))), + ))), + ast::Expression::NumericValue(val, span) => val + .parse::() + .map(|v| { + Expr::Atom(BamlValueWithMeta::Int( + v, + (span.clone(), Some(FieldType::Primitive(TypeValue::Int))), + )) + }) + .or_else(|_| { + val.parse::() + .map(|v| { + Expr::Atom(BamlValueWithMeta::Float( + v, + (span.clone(), Some(FieldType::Primitive(TypeValue::Float))), + )) + }) + .or_else(|_| Err(anyhow!("Invalid numeric 
value: {}", val))) + }), + ast::Expression::StringValue(val, span) => Ok(Expr::Atom(BamlValueWithMeta::String( + val.to_string(), + (span.clone(), Some(FieldType::Primitive(TypeValue::String))), + ))), + ast::Expression::RawStringValue(val) => Ok(Expr::Atom(BamlValueWithMeta::String( + val.value().to_string(), + ( + val.span().clone(), + Some(FieldType::Primitive(TypeValue::String)), + ), + ))), + ast::Expression::JinjaExpressionValue(val, span) => { + Ok(Expr::Atom(BamlValueWithMeta::String( + val.to_string(), + (span.clone(), Some(FieldType::Primitive(TypeValue::String))), + ))) + } + ast::Expression::Array(vals, span) => { + let new_items = vals + .iter() + .map(|v| v.repr(db)) + .collect::>>()?; + let mut item_types = new_items + .iter() + .filter_map(|v| v.meta().1.clone()) + .collect::>(); + item_types.dedup(); + let item_type = match item_types.len() { + 0 => None, + 1 => Some(item_types[0].clone()), + _ => Some(FieldType::Union(item_types)), + }; + let list_type = item_type.map(|t| FieldType::List(Box::new(t))); + Ok(Expr::List(new_items, (span.clone(), list_type))) + } + ast::Expression::Map(vals, span) => { + let new_items = vals + .iter() + .map(|(k, v)| v.repr(db).map(|v2| (k.to_string(), v2))) + .collect::>>()?; + let mut item_types = new_items + .iter() + .filter_map(|v| v.1.meta().1.clone()) + .collect::>(); + item_types.dedup(); + let item_type = match item_types.len() { + 0 => None, + 1 => Some(item_types[0].clone()), + _ => Some(FieldType::Union(item_types)), + }; + // TODO: Is this correct? 
+ let key_type = FieldType::Primitive(TypeValue::String); + let map_type = item_type.map(|t| FieldType::Map(Box::new(key_type), Box::new(t))); + Ok(Expr::Map(new_items, (span.clone(), map_type))) + } + ast::Expression::Identifier(id) => { + Ok(Expr::Var(id.name().to_string(), (id.span().clone(), None))) + } + + ast::Expression::Lambda(args, body, span) => { + let args = args + .arguments + .iter() + .filter_map(|arg| arg.value.as_string_value().map(|v| v.0.to_string())) + .collect(); + let body = convert_function_body(*body.to_owned(), db)?; + Ok(Expr::Lambda(args, Arc::new(body), (span.clone(), None))) + } + ast::Expression::FnApp(func, args, span) => { + let func = Expr::Var(func.name().to_string(), (func.span().clone(), None)); + let args = args.iter().map(|arg| arg.repr(db)).collect::>()?; + Ok(Expr::App( + Arc::new(func), + Arc::new(Expr::ArgsTuple(args, (span.clone(), None))), // TODO: We don't really have a span for the ArgsTuple, so we're using the one for the whole FnApp. + (span.clone(), None), + )) + } + ast::Expression::ClassConstructor( + ast::ClassConstructor { class_name, fields }, + span, + ) => { + let mut new_fields = BamlMap::new(); + let mut spread = None; + for f in fields { + match f { + ast::ClassConstructorField::Named(name, expr) => { + new_fields.insert(name.to_string(), expr.repr(db)?); + } + ast::ClassConstructorField::Spread(expr) => { + spread = Some(Box::new(expr.repr(db)?)); + } + } + } + Ok(Expr::ClassConstructor { + name: class_name.name().to_string(), + fields: new_fields, + spread, + meta: ( + span.clone(), + Some(FieldType::Class(class_name.name().to_string())), + ), + }) + } + ast::Expression::ExprBlock(block, span) => { + // We use "function_body" and "expr_block" interchangeably. + // This may need to be revisited? + let body = convert_function_body(block.clone(), db)?; + Ok(body) + } + } + } +} + /// A generic walker. Only walkers instantiated with a concrete ID type (`I`) are useful. 
#[derive(Clone, Copy)] pub struct Walker<'ir, I> { @@ -68,6 +384,8 @@ impl IntermediateRepr { finite_recursive_cycles: vec![], structural_recursive_alias_cycles: vec![], functions: vec![], + expr_fns: vec![], + toplevel_assignments: vec![], clients: vec![], retry_policies: vec![], template_strings: vec![], @@ -146,6 +464,31 @@ impl IntermediateRepr { self.functions.iter().map(|e| Walker { ir: self, item: e }) } + // TODO: This is a quick workaround in order to make expr_fns compatible + // with LLM functions for the purpose of listing functions and test + // cases in the playground. + pub fn expr_fns_as_functions(&self) -> Vec> { + self.expr_fns + .iter() + .map(|efn| Node { + elem: efn.elem.pretend_to_be_llm_function(), + attributes: efn.attributes.clone(), + }) + .collect::>() + } + + pub fn walk_toplevel_assignments( + &self, + ) -> impl ExactSizeIterator>> { + self.toplevel_assignments + .iter() + .map(|e| Walker { ir: self, item: e }) + } + + pub fn walk_expr_fns(&self) -> impl ExactSizeIterator>> { + self.expr_fns.iter().map(|e| Walker { ir: self, item: e }) + } + pub fn walk_tests( &self, ) -> impl Iterator, &Node)>> { @@ -223,6 +566,14 @@ impl IntermediateRepr { .walk_functions() .map(|e| e.node(db)) .collect::>>()?, + expr_fns: db + .walk_expr_fns() + .map(|e| e.node(db)) + .collect::>>()?, + toplevel_assignments: db + .walk_toplevel_assignments() + .map(|e| e.node(db)) + .collect::>>()?, clients: db .walk_clients() .map(|e| e.node(db)) @@ -345,7 +696,7 @@ impl IntermediateRepr { // [x] rename lockfile/mod.rs to ir/mod.rs // [x] wire Result<> type through, need this to be more sane -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct NodeAttributes { /// Map of attributes on the corresponding IR node. 
/// @@ -535,7 +886,7 @@ fn to_ir_attributes( } /// Nodes allow attaching metadata to a given IR entity: attributes, source location, etc -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Node { pub attributes: NodeAttributes, pub elem: T, @@ -721,7 +1072,12 @@ impl WithRepr for ast::FieldType { } } - None => return Err(anyhow!("Field type uses unresolvable local identifier")), + None => { + return Err(anyhow!( + "Field type uses unresolvable local identifier {}", + idn + )) + } }, arity, ), @@ -824,10 +1180,10 @@ impl WithRepr for TemplateStringWalker<'_> { } type EnumId = String; -#[derive(serde::Serialize, Debug)] +#[derive(Clone, serde::Serialize, Debug)] pub struct EnumValue(pub String); -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Enum { pub name: EnumId, pub values: Vec<(Node, Option)>, @@ -891,10 +1247,10 @@ impl WithRepr for EnumWalker<'_> { } } -#[derive(serde::Serialize, Debug)] +#[derive(Clone, serde::Serialize, Debug)] pub struct Docstring(pub String); -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Field { pub name: String, pub r#type: Node, @@ -936,7 +1292,7 @@ impl WithRepr for FieldWalker<'_> { type ClassId = String; /// A BAML Class. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Class { /// User defined class name. 
pub name: ClassId, @@ -1003,7 +1359,7 @@ impl Class { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct TypeAlias { pub name: String, pub r#type: Node, @@ -1121,11 +1477,169 @@ pub struct FunctionConfig { pub client: ClientSpec, } -// impl std::fmt::Display for ClientSpec { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// write!(f, "{}", self.as_str()) -// } -// } +#[derive(Clone, Debug)] +pub struct ExprFunction { + pub name: FunctionId, + pub inputs: Vec<(String, FieldType)>, + pub output: FieldType, + pub expr: Expr, + pub tests: Vec>, +} + +impl ExprFunction { + /// This is a temporary workaround for making expr_fns behave like llm_functions + /// for the purpose of listing functions and tests in the playground. + /// TODO: (Greg) handle different types of functions through separate paths. + pub fn pretend_to_be_llm_function(&self) -> Function { + Function { + name: self.name.clone(), + inputs: self.inputs.clone(), + output: self.output.clone(), + tests: self.tests.clone(), + configs: vec![], + default_config: "default_config".to_string(), + } + } + + pub fn inputs(&self) -> &Vec<(String, FieldType)> { + &self.inputs + } + + pub fn tests(&self) -> &Vec> { + &self.tests + } + + /// Traverse the function body adding type annotations to variables that + /// correspond to function parameters. + /// TODO: Make this capture-avoiding. + pub fn assign_param_types_to_body_variables(self) -> Self { + let new_expr = match &self.expr { + Expr::Lambda(params, body, meta) => { + let body = Arc::unwrap_or_clone(body.clone()); + let new_body = self.inputs.iter().fold(body, |body, (name, r#type)| { + annotate_variable(name, r#type.clone(), body) + }); + Expr::Lambda(params.clone(), Arc::new(new_body), meta.clone()) + } + // TODO: Handle other cases - traverse the tree. + // It seems like only Expr::Lambda is admissable as an ExprBody's expr field? 
+ _ => self.expr, + }; + ExprFunction { + expr: new_expr, + ..self + } + } +} + +/// For all variables under an expression, assign them the given type. +/// TODO: This ignores scope completely. Make it capture-avoiding. +pub fn annotate_variable( + name: &str, + r#type: FieldType, + expr: Expr, +) -> Expr { + match &expr { + Expr::Var(var_name, meta) => { + let mut new_expr = expr.clone(); + if name == var_name { + new_expr.meta_mut().1 = Some(r#type); + } + new_expr + } + Expr::Lambda(params, body, meta) => { + if !params.contains(&name.to_string()) { + let new_body = + annotate_variable(name, r#type.clone(), Arc::unwrap_or_clone(body.clone())); + // new_expr = annotate_variable(name, r#type.clone(), body); + Expr::Lambda(params.clone(), Arc::new(new_body), meta.clone()) + } else { + expr + } + } + Expr::App(f, args, meta) => { + let new_f = annotate_variable(name, r#type.clone(), Arc::unwrap_or_clone(f.clone())); + let new_args = annotate_variable(name, r#type, Arc::unwrap_or_clone(args.clone())); + Expr::App(Arc::new(new_f), Arc::new(new_args), meta.clone()) + } + Expr::Let(var_name, expr, body, meta) => { + let new_binding = + annotate_variable(name, r#type.clone(), Arc::unwrap_or_clone(expr.clone())); + let new_body = if var_name != name { + Arc::new(annotate_variable( + name, + r#type.clone(), + Arc::unwrap_or_clone(body.clone()), + )) + } else { + body.clone() + }; + Expr::Let( + var_name.clone(), + Arc::new(new_binding), + new_body, + meta.clone(), + ) + } + Expr::ArgsTuple(args, meta) => Expr::ArgsTuple( + args.iter() + .map(|arg| annotate_variable(name, r#type.clone(), arg.clone())) + .collect(), + meta.clone(), + ), + Expr::Atom(_) => expr, + Expr::LLMFunction(_, _, _) => expr, + Expr::ClassConstructor { + name, + fields, + spread, + meta, + } => { + let new_fields = fields + .iter() + .map(|(key, value)| { + ( + key.clone(), + annotate_variable(name, r#type.clone(), value.clone()), + ) + }) + .collect(); + let new_spread = match spread { + None => None, 
+ Some(expr) => Some(Box::new(annotate_variable( + name, + r#type.clone(), + expr.as_ref().clone(), + ))), + }; + Expr::ClassConstructor { + name: name.clone(), + fields: new_fields, + spread: new_spread, + meta: meta.clone(), + } + } + Expr::Map(entries, meta) => { + let new_entries = entries + .iter() + .map(|(key, value)| { + ( + key.clone(), + annotate_variable(name, r#type.clone(), value.clone()), + ) + }) + .collect(); + Expr::Map(new_entries, meta.clone()) + } + Expr::List(items, meta) => { + let new_items = items + .iter() + .map(|item| annotate_variable(name, r#type.clone(), item.clone())) + .collect(); + Expr::List(new_items, meta.clone()) + } + } +} fn process_field( overrides: &IndexMap<(String, String), IndexMap>>, // Adjust the type according to your actual field type @@ -1312,7 +1826,7 @@ impl WithRepr for ConfigurationWalker<'_> { } // TODO: #1343 Temporary solution until we implement scoping in the AST. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum TypeBuilderEntry { Enum(Node), Class(Node), @@ -1320,14 +1834,14 @@ pub enum TypeBuilderEntry { } // TODO: #1343 Temporary solution until we implement scoping in the AST. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct TestTypeBuilder { pub entries: Vec, pub recursive_classes: Vec>, pub recursive_aliases: Vec>, } -#[derive(serde::Serialize, Debug)] +#[derive(Clone, serde::Serialize, Debug)] pub struct TestCaseFunction(String); impl TestCaseFunction { @@ -1336,7 +1850,7 @@ impl TestCaseFunction { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct TestCase { pub name: String, pub functions: Vec>, @@ -1478,6 +1992,23 @@ impl WithRepr for PromptAst<'_> { /// Generate an IntermediateRepr from a single block of BAML source code. /// This is useful for generating IR test fixtures. 
pub fn make_test_ir(source_code: &str) -> anyhow::Result { + let (ir, diagnostics) = make_test_ir_and_diagnostics(source_code)?; + if diagnostics.has_errors() { + return Err(anyhow::anyhow!( + "Source code was invalid: \n{:?}", + diagnostics.errors() + )); + } else { + Ok(ir) + } +} + +/// Generate an IntermediateRepr from a single block of BAML source code. +/// This is useful for generating IR test fixtures. Also return the +/// `Diagnostics`. +pub fn make_test_ir_and_diagnostics( + source_code: &str, +) -> anyhow::Result<(IntermediateRepr, Diagnostics)> { use crate::validate; use crate::ValidatedSchema; use internal_baml_diagnostics::SourceFile; @@ -1486,18 +2017,12 @@ pub fn make_test_ir(source_code: &str) -> anyhow::Result { let path: PathBuf = "fake_file.baml".into(); let source_file: SourceFile = (path.clone(), source_code).into(); let validated_schema: ValidatedSchema = validate(&path, vec![source_file]); - let diagnostics = &validated_schema.diagnostics; - if diagnostics.has_errors() { - return Err(anyhow::anyhow!( - "Source code was invalid: \n{:?}", - diagnostics.errors() - )); - } + let diagnostics = validated_schema.diagnostics; let ir = IntermediateRepr::from_parser_database( &validated_schema.db, validated_schema.configuration, )?; - Ok(ir) + Ok((ir, diagnostics)) } /// Pull out `StreamingBehavior` from `NodeAttributes`. @@ -1514,6 +2039,59 @@ fn streaming_behavior_from_attributes(attributes: &NodeAttributes) -> StreamingB } } +/// Create a context from the expr_functions, top_level_assignments, and +/// functions in the IR. +/// This context is used in evaluating expressions. 
+pub fn initial_context(ir: &IntermediateRepr) -> HashMap> { + let mut ctx = HashMap::new(); + + for expr_fn in ir.expr_fns.iter() { + ctx.insert(expr_fn.elem.name.clone(), expr_fn.elem.expr.clone()); + } + for top_level_assignment in ir.toplevel_assignments.iter() { + ctx.insert( + top_level_assignment.elem.name.elem.clone(), + top_level_assignment.elem.expr.elem.clone(), + ); + } + for llm_function in ir.functions.iter() { + let params = llm_function + .elem + .inputs + .iter() + .map(|arg| arg.0.clone()) + .collect::>(); + let params_type: Vec = llm_function + .elem + .inputs + .iter() + .map(|arg| arg.1.clone()) + .collect::>(); + let body_type = llm_function.elem.output.clone(); + let lambda_type = FieldType::Arrow(Box::new(Arrow { + param_types: params_type, + return_type: body_type, + })); + ctx.insert( + llm_function.elem.name.clone(), + Expr::LLMFunction( + llm_function.elem.name.clone(), + params, + ( + llm_function + .attributes + .span + .as_ref() + .expect("LLM Functions have spans until we use dynamic types") + .clone(), + Some(lambda_type), + ), + ), + ); + } + ctx +} + #[cfg(test)] mod tests { use super::*; @@ -1721,4 +2299,28 @@ mod tests { assert_eq!(constraints[2].level, ConstraintLevel::Check); assert_eq!(constraints[2].label, Some("gt_ten".to_string())); } + + #[test] + fn test_expr_fn_tests() { + let ir = make_test_ir( + r##" + fn Foo(x: int) -> int { + x + } + + test FooTest { + functions [Foo] + args { + x 1 + } + } + "##, + ) + .unwrap(); + + let function = ir.find_expr_fn("Foo").unwrap(); + let test = ir.find_expr_fn_test(&function, "FooTest").unwrap(); + assert_eq!(test.item.1.elem.functions.len(), 1); + assert_eq!(test.item.1.elem.functions[0].elem.name(), "Foo"); + } } diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 83c7b0eb6..f6c4fb9bf 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -10,12 +10,40 @@ use 
internal_llm_client::ClientSpec; use std::collections::HashSet; use super::{ - repr::{self, FunctionConfig, TypeBuilderEntry, WithRepr}, - Class, Client, Enum, EnumValue, Field, FieldType, FunctionNode, IRHelper, Impl, RetryPolicy, - TemplateString, TestCase, TypeAlias, Walker, + repr::{self, ExprFunction, FunctionConfig, TypeBuilderEntry, WithRepr}, Class, Client, Enum, EnumValue, ExprFunctionNode, Field, FieldType, FunctionNode, IRHelper, Impl, RetryPolicy, TemplateString, TestCase, TypeAlias, Walker }; use crate::ir::jinja_helpers::render_expression; +impl<'a> Walker<'a, &'a ExprFunctionNode> { + pub fn name(&self) -> &'a str { + self.elem().name.as_str() + } + + pub fn inputs(&self) -> &'a Vec<(String, baml_types::FieldType)> { + self.elem().inputs() + } + + pub fn walk_tests( + &'a self, + ) -> impl Iterator> { + self.elem().tests().iter().map(|i| Walker { + ir: self.ir, + item: (self.item, i), + }) + } + + pub fn elem(&self) -> &'a repr::ExprFunction { + &self.item.elem + } + + pub fn find_test( + &'a self, + test_name: &str, + ) -> Option>{ + self.walk_tests().find(|t| t.item.1.elem.name == test_name) + } +} + impl<'a> Walker<'a, &'a FunctionNode> { pub fn name(&self) -> &'a str { self.elem().name() @@ -192,6 +220,42 @@ impl<'a> Walker<'a, (&'a FunctionNode, &'a Impl)> { pub fn elem(&self) -> &'a repr::Implementation { &self.item.1.elem } + +} + +impl<'a> Walker<'a, (&'a ExprFunctionNode, &'a TestCase )> { + pub fn matches(&self, function_name: &str, test_name: &str) -> bool { + self.item.0.elem.name == function_name && self.item.1.elem.name == test_name + } + + pub fn name(&self) -> String { + format!("{}::{}", self.item.0.elem.name, self.item.1.elem.name) + } + + pub fn args(&self) -> &IndexMap> { + &self.item.1.elem.args + } + + pub fn test_case(&self) -> &repr::TestCase { + &self.item.1.elem + } + + pub fn span(&self) -> Option<&crate::Span> { + self.item.1.attributes.span.as_ref() + } + + pub fn test_case_params( + &self, + ctx: &EvaluationContext<'_>, 
+ ) -> Result>> { + self.args() + .iter() + .map(|(k, v)| Ok((k.clone(), v.resolve_serde::(ctx)))) + .collect() + } + + + } impl<'a> Walker<'a, (&'a FunctionNode, &'a TestCase)> { diff --git a/engine/baml-lib/baml-core/src/validate/mod.rs b/engine/baml-lib/baml-core/src/validate/mod.rs index a77c317a3..4e910c2c2 100644 --- a/engine/baml-lib/baml-core/src/validate/mod.rs +++ b/engine/baml-lib/baml-core/src/validate/mod.rs @@ -1,4 +1,4 @@ pub(crate) mod generator_loader; -mod validation_pipeline; +pub(crate) mod validation_pipeline; pub(crate) use validation_pipeline::validate; diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline.rs index 6483e989d..888355009 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline.rs @@ -1,5 +1,5 @@ mod context; -mod validations; +pub mod validations; use crate::{internal_baml_diagnostics::Diagnostics, PreviewFeature}; use enumflags2::BitFlags; diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs index f504d5184..de4232ffe 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs @@ -3,6 +3,8 @@ mod clients; mod configurations; mod cycle; mod enums; +mod expr_fns; +pub mod expr_typecheck; mod functions; mod template_strings; mod tests; @@ -35,6 +37,10 @@ pub(super) fn validate(ctx: &mut Context<'_>) { .collect::>(); classes::assert_no_field_name_collisions(ctx, &codegen_targets); + expr_fns::validate_expr_fns(ctx); + + let _ = expr_typecheck::typecheck_exprs(ctx); + if !ctx.diagnostics.has_errors() { cycle::validate(ctx); } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/configurations.rs 
b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/configurations.rs index 9876e138f..73f62d05d 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/configurations.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/configurations.rs @@ -29,11 +29,23 @@ pub(super) fn validate(ctx: &mut Context<'_>) { // TODO: Check args. } None => { - ctx.push_warning(DatamodelWarning::new_type_not_found_error( - name, - ctx.db.valid_function_names(), - s.clone(), - )); + let expr_fns = ctx + .db + .walk_expr_fns() + .filter(|f| f.name() == name) + .collect::>(); + match ctx.db.find_expr_fn_by_name(name) { + Some(_f) => { + // TODO: Check args. + } + None => { + ctx.push_warning(DatamodelWarning::new_type_not_found_error( + name, + ctx.db.valid_function_names(), + s.clone(), + )); + } + } } }); } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_fns.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_fns.rs new file mode 100644 index 000000000..64ea391b6 --- /dev/null +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_fns.rs @@ -0,0 +1,158 @@ +use itertools::Itertools; +use std::collections::HashSet; + +use internal_baml_diagnostics::DatamodelError; +// use internal_baml_schema_ast::ast::expr; +use internal_baml_schema_ast::ast::{ClassConstructor, ClassConstructorField, Expression, Stmt}; +use internal_baml_schema_ast::ast::{WithName, WithSpan}; + +use crate::validate::validation_pipeline::context::Context; + +// An expr_fn is valid if: +// - Its arguments have valid types. +// - Its return type is valid. +// - Its body is a valid function body (series of statements ending in an +// expression). Bodies are valid if they refer only to variables defined +// in the argument list and in the current scope. +// - It does not share a name with any other expr_fn or LLM function. 
+pub(super) fn validate_expr_fns(ctx: &mut Context<'_>) { + let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default( + internal_baml_jinja_types::JinjaContext::Prompt, + ); + + let mut taken_names = std::collections::HashSet::new(); + ctx.db.walk_classes().for_each(|class| { + class.add_to_types(&mut defined_types); + taken_names.insert(class.name().to_owned()); + }); + ctx.db.walk_toplevel_assignments().for_each(|assignment| { + taken_names.insert(assignment.name().to_owned()); + }); + ctx.db.walk_functions().for_each(|function| { + taken_names.insert(function.name().to_owned()); + }); + + for expr_fn in ctx.db.walk_expr_fns() { + if taken_names.contains(expr_fn.name()) { + ctx.push_error(DatamodelError::new_validation_error( + "Expr function name must be unique", + expr_fn.name_span().clone(), + )); + } + taken_names.insert(expr_fn.name().to_owned()); + } + + for expr_fn in ctx.db.walk_expr_fns() { + let mut scope: HashSet = expr_fn + .expr_fn() + .args + .args + .iter() + .map(|(arg_name, _arg)| arg_name.to_string()) + .collect(); + + scope.extend(taken_names.iter().cloned()); + expr_fn.expr_fn().body.stmts.iter().for_each(|s| { + validate_stmt(ctx, s, &scope); + scope.insert(s.identifier.name().to_string()); + }); + validate_expression(ctx, &expr_fn.expr_fn().body.expr, &scope); + } + + for toplevel_assignment in ctx.db.walk_toplevel_assignments() { + let scope: HashSet = taken_names.clone(); + validate_stmt( + ctx, + &toplevel_assignment.top_level_assignment().stmt, + &scope, + ); + } +} + +fn validate_stmt(ctx: &mut Context<'_>, stmt: &Stmt, scope: &HashSet) { + validate_expression(ctx, &stmt.body, scope); +} + +fn validate_expression(ctx: &mut Context<'_>, expr: &Expression, scope: &HashSet) { + match &expr { + Expression::Identifier(identifier) => { + if !scope.contains(&identifier.to_string()) { + ctx.push_error(DatamodelError::new_anyhow_error( + anyhow::anyhow!("Unknown variable {}", &identifier.to_string()), + 
identifier.span().clone(), + )); + } + } + Expression::Lambda(_args, _body, _span) => {} + Expression::FnApp(fn_name, args, span) => { + // Validate the function name. + if !scope.contains(&fn_name.to_string()) { + ctx.push_error(DatamodelError::new_anyhow_error( + anyhow::anyhow!("Unknown function {}", &fn_name.to_string()), + span.clone(), + )); + } + for arg in args { + validate_expression(ctx, arg, scope); + } + } + Expression::Array(items, span) => { + for item in items { + validate_expression(ctx, item, scope); + } + } + Expression::Map(fields, span) => { + for (_key, value) in fields { + validate_expression(ctx, value, scope); + } + } + Expression::BoolValue(_, span) => {} + Expression::StringValue(_, _) => {} + Expression::NumericValue(_, _) => {} + Expression::RawStringValue(_) => {} + Expression::JinjaExpressionValue(_, _) => {} + Expression::ClassConstructor(cc, span) => { + let fields = cc.fields.clone(); + + if fields.iter().len() + != fields + .iter() + .map(|f| format!("{:?}", f)) + .dedup() + .collect::>() + .len() + { + ctx.push_error(DatamodelError::new_validation_error( + "Class constructor fields must be unique", + span.clone(), + )); + } + + let field_names = cc + .fields + .iter() + .filter_map(|field| match field { + ClassConstructorField::Named(name, _) => Some(name.to_string()), + ClassConstructorField::Spread(_) => None, + }) + .collect::>(); + + for field in cc.fields.iter() { + match field { + ClassConstructorField::Named(field_name, value) => {} + ClassConstructorField::Spread(expr) => { + validate_expression(ctx, expr, scope); + } + } + } + } + Expression::ExprBlock(block, span) => { + let mut scope = scope.clone(); + for stmt in block.stmts.iter() { + validate_stmt(ctx, stmt, &mut scope); + scope.insert(stmt.identifier.name().to_string()); + } + validate_expression(ctx, &block.expr, &scope); + } + } +} diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_typecheck.rs 
b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_typecheck.rs new file mode 100644 index 000000000..558591891 --- /dev/null +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/expr_typecheck.rs @@ -0,0 +1,500 @@ +use anyhow::Result; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::ir::IntermediateRepr; +use crate::ir::{repr::initial_context, IRHelper}; +use crate::validate::validation_pipeline::context::Context; +use crate::Configuration; +use baml_types::{ + expr::{Expr, ExprMetadata}, + Arrow, BamlValueWithMeta, FieldType, +}; +use internal_baml_diagnostics::{DatamodelError, Diagnostics, Span}; + +use crate::ir::IRHelperExtended; + +pub fn typecheck_exprs(ctx: &mut Context<'_>) -> Result<()> { + let null_configuration = Configuration::new(); + if let Ok(ir) = IntermediateRepr::from_parser_database(ctx.db, null_configuration) { + let mut typing_context: HashMap = ir + .expr_fns + .iter() + .map(|expr_fn| { + ( + expr_fn.elem.name.clone(), + FieldType::Arrow(Box::new(Arrow { + param_types: expr_fn.elem.inputs.iter().map(|(_, t)| t.clone()).collect(), + return_type: expr_fn.elem.output.clone(), + })), + ) + }) + .chain(ir.functions.iter().map(|llm_function| { + ( + llm_function.elem.name.clone(), + FieldType::Arrow(Box::new(Arrow { + param_types: llm_function + .elem + .inputs + .iter() + .map(|(_, t)| t.clone()) + .collect(), + return_type: llm_function.elem.output.clone(), + })), + ) + })) + .collect(); + + for expr_fn in ir.expr_fns.iter() { + let expr_fn_with_types = infer_types_in_context( + &mut typing_context, + Arc::new( + expr_fn + .elem + .clone() + .assign_param_types_to_body_variables() + .expr + .clone(), + ), + ); + typecheck_in_context( + &ir, + &mut ctx.diagnostics, + &typing_context, + &expr_fn_with_types, + )?; + } + } + Ok(()) +} + +pub fn typecheck_in_context( + ir: &IntermediateRepr, + diagnostics: &mut Diagnostics, + typing_context: &HashMap, + expr: &Expr, +) -> Result<()> 
{ + match expr { + Expr::Atom(atom) => { + // Atoms always typecheck. + Ok(()) + } + Expr::LLMFunction(llm_function, args, _) => { + // Bare functions always typecheck. + Ok(()) + } + Expr::Var(var, (var_span, maybe_type)) => { + if let Some(var_type) = maybe_type { + if let Some(ctx_type) = typing_context.get(var) { + if ir.is_subtype(&ctx_type, var_type) { + Ok(()) + } else { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch", + var_span.clone(), + )); + Ok(()) + } + } else { + Ok(()) + } + } else { + Ok(()) + } + } + Expr::Lambda(param_names, body, (span, maybe_type)) => { + // (\(x,y) -> x + y) : (Int,Int) -> Int + if let Some(FieldType::Arrow(arrow)) = maybe_type { + let mut inner_context = typing_context.clone(); + for (param_type, param_name) in arrow.param_types.iter().zip(param_names.iter()) { + inner_context.insert(param_name.to_string(), param_type.clone()); + } + if !compatible_as_subtype(ir, &body.meta().1, &Some(arrow.return_type.clone())) { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!( + "Type mismatch in lambda: {} vs {}", + body.meta() + .1 + .as_ref() + .map_or("?".to_string(), |t| t.to_string()), + arrow.return_type.to_string() + ), + body.meta().0.clone(), + )); + } + typecheck_in_context(ir, diagnostics, &inner_context, body)?; + } + Ok(()) + } + // (\[x,y] -> x + y) (1,2) + // ([Int,Int] -> Int) ([Int,Int] + Expr::App(f, xs, (span, maybe_app_type)) => { + // First check that the param types are compatible with the arguments. 
+ match (&f.meta().1, xs.as_ref()) { + (Some(FieldType::Arrow(arrow)), Expr::ArgsTuple(args, _)) => { + for (param_type, arg) in arrow.param_types.iter().zip(args.iter()) { + if !compatible_as_subtype(ir, &arg.meta().1, &Some(param_type.clone())) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in app", + span.clone(), + )); + } + } + } + x => { + eprintln!("TYPECHECKING APP: UNEXPECTED ARGS: {:?}", x); + } + } + Ok(()) + + // TODO: What was this? Bring it back? + // match (f.as_ref(), xs.as_ref(), maybe_app_type) { + // ( + // _, // Expr::Lambda(params, body, (lambda_span, _)), + // Expr::ArgsTuple(args, (args_span, args_type)), + // _, + // ) => { + // // First, check that the arguments are the right type + // // for the lambda. + // let maybe_lambda_type= &f.meta().1; + // eprintln!("LAMBDA_TYPE: {:?}", maybe_lambda_type); + // if let Some(lambda_type) = maybe_lambda_type { + // eprintln!("checking lambda_type: {:?}", lambda_type); + // match lambda_type { + // ExprType::Arrow(arrow) => { + // if let Some(app_type) = maybe_app_type { + + // if !compatible_as_subtype( + // ir, + // &maybe_app_type, + // &Some(arrow.body_type.clone()), + // ) { + // eprintln!( + // "C Type mismatch in app: {} vs {}", + // app_type.dump_str(), + // arrow.body_type.dump_str() + // ); + // diagnostics.push_error(DatamodelError::new_validation_error( + // &format!( + // "D Type mismatch in app: {} vs {}", + // app_type.dump_str(), + // arrow.body_type.dump_str() + // ), + // span.clone(), + // )); + // } + // } + // for (param_type, arg) in arrow.param_types.iter().zip(args.iter()) { + // eprintln!("TYPECHECKING APP COMPARING PARAMTYPE: {:?} vs ARG: {:?}", param_type, arg); + // if !compatible_as_subtype( + // ir, + // &arg.meta().1, + // &Some(param_type.clone()), + // ) { + // eprintln!( + // "E Type mismatch in app: {} vs {}", + // arg.meta() + // .1 + // .as_ref() + // .map_or("?".to_string(), |t| t.dump_str()), + // param_type.dump_str() + // ); + 
// diagnostics.push_error( + // DatamodelError::new_validation_error( + // &format!( + // "F Type mismatch in app: {} vs {}", + // arg.meta() + // .1 + // .as_ref() + // .map_or("?".to_string(), |t| t.dump_str()), + // param_type.dump_str() + // ), + // span.clone(), + // ), + // ); + // } + // } + // } + // ExprType::Atom(_) => { + // diagnostics.push_error(DatamodelError::new_validation_error( + // "Expected a function type", + // span.clone(), + // )); + // } + // } + // } + + // typecheck_in_context(ir, diagnostics, &inner_context, body)?; + + // Ok(()) + // } + // _ => Ok(()), + // } + // Applications typecheck if the function arguments + } + Expr::Let(let_expr, _, _, _) => Ok(()), + Expr::ArgsTuple(args, _) => Ok(()), + Expr::List(items, meta) => { + for item in items.iter() { + if let Some(item_type) = item.meta().1.as_ref() { + let item_list_type = FieldType::List(Box::new(item_type.clone())); + if !compatible_as_subtype(ir, &Some(item_list_type), &meta.1.clone()) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in list", + meta.0.clone(), + )); + } + } + typecheck_in_context(ir, diagnostics, typing_context, item)?; + } + Ok(()) + } + Expr::Map(items, meta) => { + if let Some(map_type) = meta.1.as_ref() { + if let Some((key_type, item_type)) = match map_type { + FieldType::Map(key_type, item_type) => Some((key_type, item_type)), + _ => None, + } { + for (_key, item) in items.iter() { + if let Some(item_type) = item.meta().1.as_ref() { + let item_map_type = + FieldType::Map(key_type.clone(), Box::new(item_type.clone())); + if !compatible_as_subtype(ir, &Some(item_map_type), &meta.1.clone()) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in map", + meta.0.clone(), + )); + } + } + typecheck_in_context(ir, diagnostics, typing_context, item)?; + } + } else { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in map", + meta.0.clone(), + )); + } + } + Ok(()) + } + 
Expr::ClassConstructor { + name, + fields, + spread, + meta, + } => { + if let Ok(class_walker) = ir.find_class(name) { + for (field_name, field_value) in fields.iter() { + let maybe_field_type = field_value.meta().1.clone(); + if let Some(field_type) = maybe_field_type { + if let Some(field_walker) = class_walker.find_field(field_name) { + if !compatible_as_subtype( + ir, + &Some(field_walker.r#type().clone()), + &Some(field_type), + ) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in class constructor", + meta.0.clone(), + )); + } + } + } + } + } + let spread_type = spread.as_ref().and_then(|s| s.meta().1.clone()); + if !compatible_as_subtype(ir, &meta.1, &spread_type) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Type mismatch in class constructor", + meta.0.clone(), + )); + } + Ok(()) + } + } +} + +// fn is_subtype(ir: &IntermediateRepr, a: &ExprType, b: &ExprType) -> bool { +// match (a, b) { +// (ExprType::Atom(a), ExprType::Atom(b)) => ir.is_subtype(a, b), +// (ExprType::Arrow(a), ExprType::Arrow(b)) => { +// let a_arrow = a.as_ref(); +// let b_arrow = b.as_ref(); +// let return_type_ok = is_subtype(ir, &a_arrow.body_type, &b_arrow.body_type); +// let arg_types_ok = a_arrow +// .param_types +// .iter() +// .zip(b_arrow.param_types.iter()) +// .all(|(a, b)| is_subtype(ir, b, a)); +// return_type_ok && arg_types_ok +// } +// _ => false, +// } +// } + +fn compatible_as_subtype( + ir: &IntermediateRepr, + a: &Option, + b: &Option, +) -> bool { + match (a, b) { + (Some(a), Some(b)) => ir.is_subtype(a, b), + _ => true, + } +} + +pub fn infer_types_in_context( + typing_context: &mut HashMap, + expr: Arc>, +) -> Arc> { + match expr.as_ref() { + Expr::Var(ref var_name, (span, maybe_type)) => { + // Assign variables from the context. 
+ if let Some(ctx_ty) = typing_context.get(var_name) { + Arc::new(Expr::Var( + var_name.clone(), + (span.clone(), Some(ctx_ty.clone())), + )) + } else { + // Otherwise, and if we know the type, add it to the context. + if let Some(var_ty) = &expr.meta().1 { + typing_context.insert(var_name.to_string(), var_ty.clone()); + } + expr.clone() + } + } + Expr::Atom(_) => { + // All atoms are typed during parsing, so we ignore them. + expr.clone() + } + Expr::Let(ref var_name, expr, ref body, _) => { + let new_expr = infer_types_in_context(typing_context, expr.clone()); + let new_body = infer_types_in_context(typing_context, body.clone()); + if let Some(ref expr_ty) = new_expr.meta().1 { + typing_context.insert(var_name.to_string(), expr_ty.clone()); + } + let new_meta = (expr.meta().0.clone(), new_body.meta().1.clone()); + Arc::new(Expr::Let(var_name.clone(), new_expr, new_body, new_meta)) + } + Expr::App(f, args, (span, maybe_app_type)) => { + // Infer the type of an App from the return type of the function, if + // it is a function with a known return type. 
+ let new_f = infer_types_in_context(typing_context, f.clone()); + let new_args = infer_types_in_context(typing_context, args.clone()); + let new_app_type = match &new_f.meta().1 { + Some(FieldType::Arrow(arrow)) => Some(arrow.return_type.clone()), + ty => None, + } + .or(maybe_app_type.clone()); + let new_meta = (span.clone(), new_app_type); + Arc::new(Expr::App(new_f, new_args, new_meta)) + } + Expr::ArgsTuple(ref args, _) => { + let new_args = args + .iter() + .map(|arg| { + Arc::unwrap_or_clone(infer_types_in_context( + typing_context, + Arc::new(arg.clone()), + )) + }) + .collect(); + Arc::new(Expr::ArgsTuple( + new_args, + (expr.meta().0.clone(), expr.meta().1.clone()), + )) + } + Expr::Lambda(param_names, body, (span, maybe_type)) => { + let mut local_typing_context = typing_context.clone(); + if let Some(FieldType::Arrow(arrow)) = maybe_type { + for (param_type, param_name) in arrow.param_types.iter().zip(param_names.iter()) { + local_typing_context.insert(param_name.to_string(), param_type.clone()); + } + } + let new_body = infer_types_in_context(&mut local_typing_context, body.clone()); + Arc::new(Expr::Lambda( + param_names.clone(), + new_body, + (span.clone(), maybe_type.clone()), + )) + } + _ => expr.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::repr::make_test_ir_and_diagnostics; + + #[test] + fn null_case() { + let (ir, diagnostics) = make_test_ir_and_diagnostics( + r##" + fn First(x: int, y: int) -> int { + x + } + "##, + ) + .expect("Valid source"); + assert!(!diagnostics.has_errors()); + } + + #[test] + fn param_body_mismatch() { + let (ir, diagnostics) = make_test_ir_and_diagnostics( + r##" + fn First(x: int, y: int) -> string { + x + } + "##, + ) + .expect("Valid source"); + assert!(diagnostics.has_errors()); + } + + #[test] + fn application_mismatch() { + let (ir, diagnostics) = make_test_ir_and_diagnostics( + r##" + fn First(x: int, y: int) -> int { + Inner(x) + } + + fn Inner(x: string) -> int { + 5 + } + "##, 
+ ) + .expect("Valid source"); + assert!(diagnostics.has_errors()); + } + + #[test] + fn multiple_calls() { + let (ir, diagnostics) = make_test_ir_and_diagnostics( + r##" + fn Compare(x: string, y: string) -> int { + 1 + } + + fn MkPoem1(x: string) -> string { + "Pretty" + } + + fn MkPoem2(x: string) -> string { + "Poem" + } + + fn Go(x: string) -> int { + let poem1 = MkPoem1(x); + let poem2 = MkPoem2(x); + Compare(poem1, poem2) + } + "##, + ) + .expect("Valid source"); + dbg!(&diagnostics); + assert!(!diagnostics.has_errors()); + } +} diff --git a/engine/baml-lib/baml-types/Cargo.toml b/engine/baml-lib/baml-types/Cargo.toml index 068a521e5..da1494561 100644 --- a/engine/baml-lib/baml-types/Cargo.toml +++ b/engine/baml-lib/baml-types/Cargo.toml @@ -19,6 +19,7 @@ anyhow.workspace = true clap.workspace = true derive_builder.workspace = true itertools = "0.14.0" +internal-baml-diagnostics = { path = "../diagnostics" } log.workspace = true minijinja.workspace = true once_cell.workspace = true diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index 5549bd28d..c97a5783c 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -549,6 +549,44 @@ impl BamlValueWithMeta { } } + /// Apply the same meta value to every node throughout a BamlValue. 
+ pub fn with_const_meta(value: &BamlValue, meta: T) -> BamlValueWithMeta + where + T: Clone, + { + use BamlValueWithMeta::*; + match value { + BamlValue::String(s) => String(s.clone(), meta), + BamlValue::Int(i) => Int(*i, meta), + BamlValue::Float(f) => Float(*f, meta), + BamlValue::Bool(b) => Bool(*b, meta), + BamlValue::Map(entries) => BamlValueWithMeta::Map( + entries + .iter() + .map(|(k, v)| (k.clone(), Self::with_const_meta(v, meta.clone()))) + .collect(), + meta, + ), + BamlValue::List(items) => List( + items + .iter() + .map(|i| Self::with_const_meta(i, meta.clone())) + .collect(), + meta, + ), + BamlValue::Media(m) => Media(m.clone(), meta), + BamlValue::Enum(n, v) => Enum(n.clone(), v.clone(), meta), + BamlValue::Class(_, items) => Map( + items + .iter() + .map(|(k, v)| (k.clone(), Self::with_const_meta(v, meta.clone()))) + .collect(), + meta, + ), + BamlValue::Null => Null(meta), + } + } + pub fn map_meta<'a, F, U>(&'a self, f: F) -> BamlValueWithMeta where F: Fn(&'a T) -> U + Copy, diff --git a/engine/baml-lib/baml-types/src/expr.rs b/engine/baml-lib/baml-types/src/expr.rs new file mode 100644 index 000000000..afd35fbd8 --- /dev/null +++ b/engine/baml-lib/baml-types/src/expr.rs @@ -0,0 +1,286 @@ +// use moniker::{Binder, BoundTerm, Scope, Var}; +use std::sync::Arc; + +use crate::{field_type::FieldType, BamlMap, BamlValueWithMeta}; +use internal_baml_diagnostics::Span; +use itertools::join; + +pub type Name = String; + +/// A BAML expression term. +/// T is the type of the metadata. +#[derive(Debug, Clone)] +pub enum Expr { + Atom(BamlValueWithMeta), + List(Vec>, T), + Map(BamlMap>, T), + ClassConstructor { + name: String, + fields: BamlMap>, + spread: Option>>, + meta: T, + }, + + LLMFunction(Name, Vec, T), + Var(Name, T), + Lambda(Vec, Arc>, T), + App(Arc>, Arc>, T), + Let(Name, Arc>, Arc>, T), // let name = expr in body + ArgsTuple(Vec>, T), +} + +/// The metadata used during parsing, typechecking and evaluation of BAML expressions. 
+pub type ExprMetadata = (Span, Option); + +impl Expr { + pub fn meta(&self) -> &T { + match self { + Expr::Atom(baml_value) => baml_value.meta(), + Expr::List(_, meta) => meta, + Expr::Map(_, meta) => meta, + Expr::ClassConstructor { meta, .. } => meta, + Expr::LLMFunction(_, _, meta) => meta, + Expr::Var(_, meta) => meta, + Expr::Lambda(_, _, meta) => meta, + Expr::App(_, _, meta) => meta, + Expr::ArgsTuple(_, meta) => meta, + Expr::Let(_, _, _, meta) => meta, + } + } + + pub fn meta_mut(&mut self) -> &mut T { + match self { + Expr::Atom(baml_value) => baml_value.meta_mut(), + Expr::List(_, meta) => meta, + Expr::Map(_, meta) => meta, + Expr::ClassConstructor { meta, .. } => meta, + Expr::LLMFunction(_, _, meta) => meta, + Expr::Var(_, meta) => meta, + Expr::Lambda(_, _, meta) => meta, + Expr::App(_, _, meta) => meta, + Expr::Let(_, _, _, meta) => meta, + Expr::ArgsTuple(_, meta) => meta, + } + } + + pub fn into_meta(self) -> T { + match self { + Expr::Atom(baml_value) => baml_value.meta().clone(), + Expr::List(_, meta) => meta, + Expr::Map(_, meta) => meta, + Expr::ClassConstructor { meta, .. } => meta, + Expr::LLMFunction(_, _, meta) => meta, + Expr::Var(_, meta) => meta, + Expr::Lambda(_, _, meta) => meta, + Expr::App(_, _, meta) => meta, + Expr::ArgsTuple(_, meta) => meta, + Expr::Let(_, _, _, meta) => meta, + } + } +} + +impl Expr { + /// A very rough pretty-printer for debugging expressions. 
+ pub fn dump_str(&self) -> String { + match self { + Expr::Atom(atom) => atom.clone().value().to_string(), + Expr::LLMFunction(name, _, _) => name.clone(), + Expr::Var(name, _) => name.clone(), + Expr::Lambda(args, body, _) => format!("\\{:?} -> {}", args, body.dump_str()), + Expr::App(func, args, _) => { + let args_str = match args.as_ref() { + Expr::ArgsTuple(args, _) => args + .iter() + .map(|arg| arg.dump_str()) + .collect::>() + .join(", "), + _ => format!("(NON_ARGS_TUPLE {})", args.dump_str()), + }; + let func_str = match func.as_ref() { + Expr::LLMFunction(name, _, _) => name.clone(), + Expr::Var(name, _) => name.clone(), + _ => format!("({})", func.dump_str()), + }; + format!("{}({})", func_str, args_str) + } + Expr::Let(name, expr, body, _) => { + format!("Let {} = {} in {}", name, expr.dump_str(), body.dump_str()) + } + Expr::ArgsTuple(args, _) => format!( + "ArgsTuple({:?})", + args.iter().map(|arg| arg.dump_str()).collect::>() + ), + Expr::List(items, _) => { + let items = join( + items.iter().map(|item| item.dump_str()).collect::>(), + ", ", + ); + format!("[{}]", items) + } + Expr::Map(entries, _) => { + let entries = entries + .iter() + .map(|(key, value)| format!("{}: {}", key, value.dump_str())) + .collect::>() + .join(", "); + format!("{{{}}}", entries) + } + Expr::ClassConstructor { + name, + fields, + spread, + .. + } => { + let fields = fields + .iter() + .map(|(key, value)| format!("{}: {}", key, value.dump_str())) + .collect::>() + .join(", "); + let spread = match spread { + Some(expr) => format!("..{}", expr.dump_str()), + None => String::new(), + }; + format!("Class({} {{ {}{} }}", name, fields, spread) + } + } + } + + /// This quick hack of a function checks whether two expressions are + /// equal in terms of reduction state. This test is used to detect + /// if the evaluation stepper is stuck. 
+ pub fn temporary_same_state(&self, other: &Expr) -> bool { + match (self, other) { + (Expr::Atom(a1), Expr::Atom(a2)) => a1.clone().value() == a2.clone().value(), + (Expr::Atom(_), _) => false, + + (Expr::LLMFunction(n1, _, _), Expr::LLMFunction(n2, _, _)) => n1 == n2, + (Expr::LLMFunction(_, _, _), _) => false, + + (Expr::Var(n1, _), Expr::Var(n2, _)) => n1 == n2, + (Expr::Var(_, _), _) => false, + + (Expr::Lambda(args1, body1, _), Expr::Lambda(args2, body2, _)) => { + args1 == args2 && body1.temporary_same_state(body2) + } + (Expr::Lambda(_, _, _), _) => false, + + (Expr::App(f1, x1, _), Expr::App(f2, x2, _)) => { + f1.temporary_same_state(f2) && x1.temporary_same_state(x2) + } + (Expr::App(_, _, _), _) => false, + + (Expr::Let(n1, e1, b1, _), Expr::Let(n2, e2, b2, _)) => { + n1 == n2 && e1.temporary_same_state(e2) && b1.temporary_same_state(b2) + } + (Expr::Let(_, _, _, _), _) => false, + + (Expr::ArgsTuple(args1, _), Expr::ArgsTuple(args2, _)) => { + args1.len() == args2.len() + && args1 + .iter() + .zip(args2.iter()) + .all(|(a1, a2)| a1.temporary_same_state(a2)) + } + (Expr::ArgsTuple(_, _), _) => false, + + ( + Expr::ClassConstructor { + name: n1, + fields: e1, + spread: s1, + .. + }, + Expr::ClassConstructor { + name: n2, + fields: e2, + spread: s2, + .. + }, + ) => { + n1 == n2 + && e1.len() == e2.len() + && e1 + .iter() + .zip(e2.iter()) + .all(|((_k1, v1), (_k2, v2))| v1.temporary_same_state(v2)) + && (match (s1, s2) { + (Some(s1), Some(s2)) => s1.temporary_same_state(s2), + (None, None) => true, + _ => false, + }) + } + (Expr::ClassConstructor { .. 
}, _) => false, + + (Expr::Map(e1, _), Expr::Map(e2, _)) => { + e1.len() == e2.len() + && e1 + .iter() + .zip(e2.iter()) + .all(|((_k1, v1), (_k2, v2))| v1.temporary_same_state(v2)) + } + (Expr::Map(_, _), _) => false, + + (Expr::List(e1, _), Expr::List(e2, _)) => { + e1.len() == e2.len() + && e1 + .iter() + .zip(e2.iter()) + .all(|(a1, a2)| a1.temporary_same_state(a2)) + } + (Expr::List(_, _), _) => false, + } + } +} + +/// Special methods for Exprs parameterized by the ExprMetadata type. +impl Expr { + /// Attempt to smoosh an expression that has been deeply evaluated into a BamlValue. + /// If it encounters any non-evaluated sub-expressions, it returns None. + pub fn as_atom(&self) -> Option> { + match self { + Expr::Atom(atom) => Some(atom.clone()), + Expr::List(items, meta) => { + let atom_items = items + .iter() + .map(|item| item.as_atom()) + .collect::>>()?; + Some(BamlValueWithMeta::List(atom_items, meta.clone())) + } + Expr::Map(entries, meta) => { + let atom_entries = entries + .iter() + .map(|(key, value)| { + let atom = value.as_atom()?; + Some((key.clone(), atom)) + }) + .collect::>>>()?; + Some(BamlValueWithMeta::Map(atom_entries, meta.clone())) + } + // A class constructor may not be evaluated into an atom if it still contains a spread. 
+ Expr::ClassConstructor { + name, + fields, + spread, + meta, + } => { + if spread.is_some() { + None + } else { + let atom_entries = fields + .iter() + .map(|(key, value)| { + let atom = value.as_atom()?; + Some((key.clone(), atom)) + }) + .collect::>>>()?; + Some(BamlValueWithMeta::Class( + name.clone(), + atom_entries, + meta.clone(), + )) + } + } + _ => None, + } + } +} diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index bf9f9b124..004dd3206 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -88,6 +88,7 @@ pub enum FieldType { Tuple(Vec), Optional(Box), RecursiveTypeAlias(String), + Arrow(Box), WithMetadata { base: Box, constraints: Vec, @@ -99,6 +100,12 @@ pub trait HasFieldType { fn field_type<'a>(&'a self) -> &'a FieldType; } +#[derive(serde::Serialize, Debug, Clone, PartialEq, Eq, Hash)] +pub struct Arrow { + pub param_types: Vec, + pub return_type: FieldType, +} + // Impl display for FieldType impl std::fmt::Display for FieldType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -133,6 +140,17 @@ impl std::fmt::Display for FieldType { FieldType::Map(k, v) => write!(f, "map<{k}, {v}>"), FieldType::List(t) => write!(f, "{t}[]"), FieldType::Optional(t) => write!(f, "{t}?"), + FieldType::Arrow(arrow) => write!( + f, + "({}) -> {}", + arrow + .param_types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", "), + arrow.return_type.to_string() + ), FieldType::WithMetadata { base, .. 
} => base.fmt(f), } } @@ -239,7 +257,8 @@ impl ToUnionName for FieldType { | FieldType::Enum(_) | FieldType::Literal(_) | FieldType::Class(_) - | FieldType::RecursiveTypeAlias(_) => IndexSet::new(), + | FieldType::RecursiveTypeAlias(_) + | FieldType::Arrow(_) => IndexSet::new(), FieldType::Tuple(inner) => inner.iter().flat_map(|t| t.find_union_types()).collect(), FieldType::Optional(inner) => inner.find_union_types(), FieldType::WithMetadata { base, .. } => base.find_union_types(), @@ -295,6 +314,7 @@ impl ToUnionName for FieldType { } FieldType::RecursiveTypeAlias(name) => name.to_string(), FieldType::WithMetadata { base, .. } => base.to_union_name(), + FieldType::Arrow(_) => "function".to_string(), } } } diff --git a/engine/baml-lib/baml-types/src/lib.rs b/engine/baml-lib/baml-types/src/lib.rs index 36e248ffb..b47565842 100644 --- a/engine/baml-lib/baml-types/src/lib.rs +++ b/engine/baml-lib/baml-types/src/lib.rs @@ -1,4 +1,5 @@ mod constraint; +pub mod expr; mod map; mod media; mod minijinja; @@ -13,7 +14,7 @@ mod value_expr; pub use baml_value::{BamlValue, BamlValueWithMeta, Completion, CompletionState}; pub use constraint::*; pub use field_type::{ - FieldType, HasFieldType, LiteralValue, StreamingBehavior, ToUnionName, TypeValue, + Arrow, FieldType, HasFieldType, LiteralValue, StreamingBehavior, ToUnionName, TypeValue, }; pub use generator::{GeneratorDefaultClientMode, GeneratorOutputType}; pub use map::Map as BamlMap; diff --git a/engine/baml-lib/baml-types/src/value_expr.rs b/engine/baml-lib/baml-types/src/value_expr.rs index 49bec20e9..5575bdfa8 100644 --- a/engine/baml-lib/baml-types/src/value_expr.rs +++ b/engine/baml-lib/baml-types/src/value_expr.rs @@ -18,6 +18,8 @@ pub enum Resolvable { Array(Vec>, Meta), // This includes key-value pairs for classes Map(IndexMap)>, Meta), + // The class name and list of fields as resolvable values. 
+ ClassConstructor(String, Vec<(String, Resolvable)>, Meta), Null(Meta), } @@ -108,6 +110,7 @@ impl Resolvable { Resolvable::Bool(_, meta) => meta, Resolvable::Array(_, meta) => meta, Resolvable::Map(_, meta) => meta, + Resolvable::ClassConstructor(_, _, meta) => meta, Resolvable::Null(meta) => meta, } } @@ -138,6 +141,7 @@ impl Resolvable { .join(",\n"); format!("{{\n{content}\n}}") } + Resolvable::ClassConstructor(class_name, _, _) => class_name.to_string(), Resolvable::Null(..) => String::from("null"), } } @@ -220,6 +224,14 @@ impl UnresolvedValue { .collect(), (), ), + Self::ClassConstructor(class_name, fields, ..) => Resolvable::ClassConstructor( + class_name.clone(), + fields + .iter() + .map(|(k, v)| (k.clone(), v.without_meta())) + .collect(), + (), + ), Self::Null(..) => Resolvable::Null(()), } } @@ -294,6 +306,9 @@ impl UnresolvedValue { Self::Bool(..) => anyhow::bail!("Expected a string, not a bool"), Self::Map(..) => anyhow::bail!("Expected a string, not a map"), Self::Null(..) => anyhow::bail!("Expected a string, not null"), + Self::ClassConstructor(..) => { + anyhow::bail!("Expected a string, not a class constructor") + } } } @@ -370,6 +385,17 @@ impl UnresolvedValue { .collect::>()?; Ok(ResolvedValue::Map(values, ())) } + Self::ClassConstructor(class_name, fields, _meta) => { + let new_fields = fields + .iter() + .map(|(k, v)| v.resolve(ctx).map(|res| (k.clone(), res))) + .collect::>>()?; + Ok(ResolvedValue::ClassConstructor( + class_name.clone(), + new_fields, + (), + )) + } Self::Null(..) => Ok(ResolvedValue::Null(())), } } @@ -418,6 +444,12 @@ impl TryFrom for serde_json::Value { .map(|(k, (_, v))| Ok((k.clone(), serde_json::Value::try_from(v)?))) .collect::>()?, ), + ResolvedValue::ClassConstructor(_class_name, fields, ..) => serde_json::Value::Object( + fields + .into_iter() + .map(|(k, v)| Ok((k, serde_json::Value::try_from(v)?))) + .collect::>()?, + ), ResolvedValue::Null(..) 
=> serde_json::Value::Null, }) } diff --git a/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml index 4547adcd9..9f8eb080c 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml @@ -18,12 +18,10 @@ class Foo { // | // 11 | class Foo { // 12 | if string -// 13 | ETA ETA? // | // error: Error validating field `ETA` in class `ETA`: When using the python/pydantic generator, a field name must not be exactly equal to the type name. Consider changing the field name and using an alias. // --> class/generator_keywords1.baml:13 // | // 12 | if string // 13 | ETA ETA? -// 14 | } // | diff --git a/engine/baml-lib/baml/tests/validation_files/expr/constructors.baml b/engine/baml-lib/baml/tests/validation_files/expr/constructors.baml new file mode 100644 index 000000000..d2e17519d --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/constructors.baml @@ -0,0 +1,22 @@ +class MyClass { + a int + b string +} + +let x = MyClass { a: 1, b: 2 }; + +let y = MyClass { a: 1, ..x }; + + +let default_person = Person { + name: "John Doe", + age: 20, + poem: "Never was there a man more plain." 
+}; + + +class Person { + name string + age int + poem string +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/expr_fn.baml b/engine/baml-lib/baml/tests/validation_files/expr/expr_fn.baml new file mode 100644 index 000000000..3aa27c741 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/expr_fn.baml @@ -0,0 +1,11 @@ +fn Foo(x: int) -> int { + let y = x; + x +} + +test TestBar { + functions [Foo] + args { + x 1 + } +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/expr_full.baml b/engine/baml-lib/baml/tests/validation_files/expr/expr_full.baml new file mode 100644 index 000000000..189df3b48 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/expr_full.baml @@ -0,0 +1,71 @@ + +function MakePoem(length: int) -> string { + client GPT4o + prompt #"Write a poem {{ length }} lines long."# +} + +function CombinePoems(poem1: string, poem2: string) -> string { + client GPT4o + prompt #"Combine the following two poems into one poem. 
+ + Poem 1: + {{ poem1 }} + + Poem 2: + {{ poem2 }} + "# +} + +let poem = MakePoem(10); + +let another = { + let x = MakePoem(10); + let y = MakePoem(5); + CombinePoems(x,y) +}; + +fn Pipeline() -> string { + let x = MakePoem(6); + let y = MakePoem(6); + let a = MakePoem(6); + let b = MakePoem(6); + let xy = CombinePoems(x,y); + let ab = CombinePoems(a,b); + CombinePoems(xy, ab) +} + +fn Pyramid() -> string { + CombinePoems( CombinePoems( MakePoem(10), MakePoem(10)), MakePoem(10)) +} + +fn OuterPyramid() -> string { + CombinePoems(poem, another) +} + +test TestPipeline() { + functions [Pipeline] + args { } +} + +test TestPyramid() { + functions [Pyramid] + args { } +} + +test OuterPyramid() { + functions [OuterPyramid] + args { } +} + +client GPT4o { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} + +test TestMakePoem() { + functions [MakePoem] + args { length 4 } +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/expr_list.baml b/engine/baml-lib/baml/tests/validation_files/expr/expr_list.baml new file mode 100644 index 000000000..983fe658f --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/expr_list.baml @@ -0,0 +1,7 @@ +fn Id(x: int) -> int { + x +} + +fn Go(x: int) -> int[] { + [x, Id(x), 3] +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/expr_small.baml b/engine/baml-lib/baml/tests/validation_files/expr/expr_small.baml new file mode 100644 index 000000000..84b1fe49d --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/expr_small.baml @@ -0,0 +1,3 @@ +fn A(x: int) -> int { + x +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/missing_return_value.baml b/engine/baml-lib/baml/tests/validation_files/expr/missing_return_value.baml new file mode 100644 index 000000000..b5163f707 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/missing_return_value.baml 
@@ -0,0 +1,12 @@ +fn NoRet(x: int) { + 1 +} + +// error: fn must have a return type: e.g. fn Foo() -> int +// --> expr/missing_return_value.baml:1 +// | +// | +// 1 | fn NoRet(x: int) { +// 2 | 1 +// 3 | } +// | diff --git a/engine/baml-lib/baml/tests/validation_files/expr/missing_semicolons.baml b/engine/baml-lib/baml/tests/validation_files/expr/missing_semicolons.baml new file mode 100644 index 000000000..a2a7b8773 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/missing_semicolons.baml @@ -0,0 +1,20 @@ +// Missing semicolons. +fn Foo(x: int) -> int { + let y = "hello" + x +} + +let x = 1 + +// error: Statement must end with a semicolon. +// --> expr/missing_semicolons.baml:3 +// | +// 2 | fn Foo(x: int) -> int { +// 3 | let y = "hello" +// | +// error: Statement must end with a semicolon. +// --> expr/missing_semicolons.baml:7 +// | +// 6 | +// 7 | let x = 1 +// | diff --git a/engine/baml-lib/baml/tests/validation_files/expr/mixed_pipeline.baml b/engine/baml-lib/baml/tests/validation_files/expr/mixed_pipeline.baml new file mode 100644 index 000000000..6076cca39 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/mixed_pipeline.baml @@ -0,0 +1,16 @@ +function LLM(x: int) -> int { + client GPT4o + prompt #"Return {{ x }} {{ ctx.output_format}}"# +} + +fn UseLLM(x: int) -> int { + LLM(x) +} + +client GPT4o { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/top_level_binding.baml b/engine/baml-lib/baml/tests/validation_files/expr/top_level_binding.baml new file mode 100644 index 000000000..f99466d21 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/top_level_binding.baml @@ -0,0 +1,6 @@ +let x = 1; + +let y = { + let b = 2; + [1,2,3] +}; \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/expr/unknown_name.baml 
b/engine/baml-lib/baml/tests/validation_files/expr/unknown_name.baml new file mode 100644 index 000000000..2bfca8a74 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/expr/unknown_name.baml @@ -0,0 +1,33 @@ +fn Go(x: int) -> int { + let y = a; + y +} + +fn Go2(x: int) -> int { + let z = x; + let a = z; + Go(z,a,b) +} + +fn Go3(x:int) -> int { + Unknown(x) +} + +// error: Unknown variable a +// --> expr/unknown_name.baml:2 +// | +// 1 | fn Go(x: int) -> int { +// 2 | let y = a; +// | +// error: Unknown variable b +// --> expr/unknown_name.baml:9 +// | +// 8 | let a = z; +// 9 | Go(z,a,b) +// | +// error: Unknown function Unknown +// --> expr/unknown_name.baml:13 +// | +// 12 | fn Go3(x:int) -> int { +// 13 | Unknown(x) +// | diff --git a/engine/baml-lib/baml/tests/validation_files/strings/unquoted_strings.baml b/engine/baml-lib/baml/tests/validation_files/strings/unquoted_strings.baml index 0f682d3d8..d5906e45d 100644 --- a/engine/baml-lib/baml/tests/validation_files/strings/unquoted_strings.baml +++ b/engine/baml-lib/baml/tests/validation_files/strings/unquoted_strings.baml @@ -47,13 +47,6 @@ client Hello { // 6 | banned2 #helloworld // 7 | banned3 hello(world) // | -// error: Error validating: This line is not a valid field or attribute definition. A valid property may look like: 'myProperty "some value"' for example, with no colons. -// --> strings/unquoted_strings.baml:7 -// | -// 6 | banned2 #helloworld -// 7 | banned3 hello(world) -// 8 | } -// | // error: Error validating: This line is invalid. It does not start with any known Baml schema keyword. 
// --> strings/unquoted_strings.baml:9 // | diff --git a/engine/baml-lib/diagnostics/src/lib.rs b/engine/baml-lib/diagnostics/src/lib.rs index 526170656..227bd2566 100644 --- a/engine/baml-lib/diagnostics/src/lib.rs +++ b/engine/baml-lib/diagnostics/src/lib.rs @@ -8,5 +8,5 @@ mod warning; pub use collection::Diagnostics; pub use error::DatamodelError; pub use source_file::SourceFile; -pub use span::Span; +pub use span::{Span, SerializedSpan}; pub use warning::DatamodelWarning; diff --git a/engine/baml-lib/diagnostics/src/span.rs b/engine/baml-lib/diagnostics/src/span.rs index ed0b40568..aa01a7cec 100644 --- a/engine/baml-lib/diagnostics/src/span.rs +++ b/engine/baml-lib/diagnostics/src/span.rs @@ -58,11 +58,13 @@ impl Span { } } - match (start, end) { + let res = match (start, end) { (Some(start), Some(end)) => (start, end), (Some(start), None) => (start, (line, column)), _ => ((0, 0), (0, 0)), - } + }; + log::info!("Span line and column: {:?} => {:?}", self, res); + res } /// Create a fake span. Useful when generating test data that requires @@ -82,3 +84,28 @@ impl From<(SourceFile, pest::Span<'_>)> for Span { } } } + +/// A special-purpose span used for communicating with the JS playground. +/// Currently its only job is indicating the span of a currently-active +/// LLM Function. 
+#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)] +pub struct SerializedSpan { + pub file_path: String, + pub start_line: usize, + pub start: usize, + pub end_line: usize, + pub end: usize, +} + +impl SerializedSpan { + pub fn serialize(span: &Span) -> Self { + let (start, end) = span.line_and_column(); + SerializedSpan { + file_path: span.file.path().to_string(), + start_line: start.0, + start: start.1, + end_line: end.0, + end: end.1, + } + } +} diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index aadcb9aa0..5077055dd 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -58,8 +58,8 @@ pub struct Class { pub struct OutputFormatContent { pub enums: Arc>, pub classes: Arc>, - recursive_classes: Arc>, - structural_recursive_aliases: Arc>, + pub recursive_classes: Arc>, + pub structural_recursive_aliases: Arc>, pub target: FieldType, } @@ -320,6 +320,21 @@ impl OutputFormatContent { Builder::new(target) } + /// A temporary OutputFormatContent constructor used by Expression functions. + /// Expression Functions have no prompt and no client, so OutputFormatContent + /// is not applicable to them. However one is needed for generating a + /// PromptRenderer, which is technically needed in order to call the + /// function-calling methods of BamlRuntime. + pub fn mk_fake() -> OutputFormatContent { + OutputFormatContent { + enums: Arc::new(IndexMap::new()), + classes: Arc::new(IndexMap::new()), + recursive_classes: Arc::new(IndexSet::new()), + structural_recursive_aliases: Arc::new(IndexMap::new()), + target: FieldType::Primitive(TypeValue::String), + } + } + fn prefix(&self, options: &RenderOptions) -> Option { fn auto_prefix( ft: &FieldType, @@ -369,6 +384,7 @@ impl OutputFormatContent { FieldType::WithMetadata { base, .. 
} => { auto_prefix(base, options, output_format_content) } + FieldType::Arrow(_) => None, // TODO: Error? Arrow shouldn't appear here. } } @@ -574,6 +590,12 @@ impl OutputFormatContent { )?, } .to_string(), + FieldType::Arrow(_) => { + return Err(minijinja::Error::new( + minijinja::ErrorKind::BadSerialization, + "Arrow type is not supported in LLM function outputs", + )) + } }) } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index f6f846f47..a05fac571 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -98,6 +98,7 @@ impl TypeCoercer for FieldType { } FieldType::Map(_, _) => coerce_map(ctx, self, value).map(|v| v.with_target(target)), FieldType::Tuple(_) => Err(ctx.error_internal("Tuple not supported")), + FieldType::Arrow(_) => Err(ctx.error_internal("Arrow type not supported")), FieldType::WithMetadata { base, .. } => { let mut coerced_value = base.coerce(ctx, target, value)?; let constraint_results = run_user_checks(&coerced_value.clone().into(), self) @@ -210,6 +211,7 @@ impl DefaultValue for FieldType { } } FieldType::Primitive(_) => None, + FieldType::Arrow(_) => None, // If it has constraints, we can't assume our defaults meet them. FieldType::WithMetadata { .. } => None, } diff --git a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs index d991256b6..e47dfd290 100644 --- a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs +++ b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs @@ -298,6 +298,7 @@ fn required_done(ir: &impl IRHelperExtended, field_type: &FieldType) -> bool { // TODO: This rule is pretty aggressive. For example in the case of // Class | Enum it would not allow classes to be streamed. 
FieldType::Union(options) => options.iter().any(|option| required_done(ir, option)), + FieldType::Arrow(_) => false, // TODO: Error? Arrow shouldn't appear here. FieldType::WithMetadata { .. } => { unreachable!("distribute_metadata always consumes `WithMetadata`.") } diff --git a/engine/baml-lib/parser-database/src/walkers/expr_fn.rs b/engine/baml-lib/parser-database/src/walkers/expr_fn.rs new file mode 100644 index 000000000..12eb9fafe --- /dev/null +++ b/engine/baml-lib/parser-database/src/walkers/expr_fn.rs @@ -0,0 +1,66 @@ +use baml_types::expr::Expr; +use internal_baml_diagnostics::Span; +use internal_baml_schema_ast::ast::{self, WithSpan}; +use internal_baml_schema_ast::ast::{ExprFn, TopLevelAssignment, WithName}; + +use super::{ConfigurationWalker, Walker}; + +/// Walker for top level assignments. +pub type TopLevelAssignmentWalker<'db> = Walker<'db, ast::TopLevelAssignmentId>; + +impl<'db> TopLevelAssignmentWalker<'db> { + /// Returns the name of the top level assignment. + pub fn name(&self) -> &str { + self.db.ast[self.id].stmt.identifier.name() + } + + /// Return the AST node for the top level assignment. + pub fn top_level_assignment(&self) -> &ast::TopLevelAssignment { + &self.db.ast[self.id] + } + + /// Returns the expression of the top level assignment. + pub fn expr(&self) -> &ast::Expression { + &self.db.ast[self.id].stmt.body + } +} + +/// Walker for expression functions. +pub type ExprFnWalker<'db> = Walker<'db, ast::ExprFnId>; + +impl<'db> ExprFnWalker<'db> { + /// Return the name of the function. + pub fn name(&self) -> &str { + self.db.ast[self.id].name.name() + } + + /// Return the span of the name of the function. + pub fn name_span(&self) -> &Span { + self.db.ast[self.id].name.span() + } + + /// Return the AST node for the function. + pub fn expr_fn(&self) -> &ast::ExprFn { + &self.db.ast[self.id] + } + + /// Return the arguments of the function. 
+ pub fn args(&self) -> &ast::BlockArgs { + &self.db.ast[self.id].args + } + + /// All the test cases for this function. + pub fn walk_tests(self) -> impl ExactSizeIterator> { + let mut tests = self + .db + .walk_test_cases() + .filter(|w| w.test_case().functions.iter().any(|f| f.0 == self.name())) + .collect::>(); + + // log::debug!("Found {} tests for function {}", tests.len(), self.name()); + + tests.sort_by(|a, b| a.name().cmp(b.name())); + + tests.into_iter() + } +} diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index dd166949b..a77118225 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -11,6 +11,7 @@ mod r#class; mod client; mod configuration; mod r#enum; +mod expr_fn; mod field; mod function; mod template_string; @@ -23,10 +24,11 @@ use either::Either; pub use field::*; pub use function::FunctionWalker; use internal_baml_schema_ast::ast::{ - FieldType, Identifier, TopId, TypeAliasId, TypeExpId, WithName, + FieldType, Identifier, SchemaAst, TopId, TypeAliasId, TypeExpId, WithName }; pub use r#class::*; pub use r#enum::*; +pub use expr_fn::{ExprFnWalker, TopLevelAssignmentWalker}; pub use template_string::TemplateStringWalker; /// A generic walker. Only walkers intantiated with a concrete ID type (`I`) are useful. @@ -38,7 +40,8 @@ pub struct Walker<'db, I> { pub id: I, } -impl<'db, I> Walker<'db, I> { +impl<'db, I> Walker<'db, I> +{ /// Traverse something else in the same schema. pub fn walk(self, other: J) -> Walker<'db, J> { self.db.walk(other) @@ -131,6 +134,11 @@ impl<'db> crate::ParserDatabase { .map(|function_id| self.walk(function_id)) } + /// Find a function by name. + pub fn find_expr_fn_by_name(&'db self, name: &str) -> Option> { + self.walk_expr_fns().find(|expr_fn| expr_fn.name() == name) + } + /// Find a function by name. 
pub fn find_retry_policy(&'db self, name: &str) -> Option> { self.interner @@ -244,6 +252,22 @@ impl<'db> crate::ParserDatabase { }) } + /// Walk all toplevel assignments in the schema. + pub fn walk_toplevel_assignments(&self) -> impl Iterator> { + self.ast() + .iter_tops() + .filter_map(|(top_id, _)| top_id.as_toplevel_assignment_id()) + .map(move |top_id| Walker { db: self, id: top_id }) + } + + /// Walk all expr functions in the schema. + pub fn walk_expr_fns(&self) -> impl Iterator> { + self.ast() + .iter_tops() + .filter_map(|(top_id, _)| top_id.as_expr_fn_id()) + .map(move |top_id| Walker { db: self, id: top_id }) + } + /// Walk all functions in the schema. pub fn walk_functions(&self) -> impl Iterator> { self.ast() diff --git a/engine/baml-lib/schema-ast/Cargo.toml b/engine/baml-lib/schema-ast/Cargo.toml index 430abed44..1a01c55f1 100644 --- a/engine/baml-lib/schema-ast/Cargo.toml +++ b/engine/baml-lib/schema-ast/Cargo.toml @@ -10,7 +10,6 @@ license-file.workspace = true dead_code = "deny" elided_named_lifetimes = "deny" unused_imports = "allow" -unused_variables = "deny" [dependencies] internal-baml-diagnostics = { path = "../diagnostics" } diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index a350937ae..291eba578 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -5,6 +5,7 @@ mod attribute; mod comment; mod config; +pub mod expr; mod expression; mod field; @@ -24,7 +25,10 @@ pub use argument::{Argument, ArgumentId, ArgumentsList}; pub use assignment::Assignment; pub use attribute::{Attribute, AttributeContainer, AttributeId}; pub use config::ConfigBlockProperty; -pub use expression::{Expression, RawString}; +pub use expr::{ExprFn, TopLevelAssignment}; +pub use expression::{ + ClassConstructor, ClassConstructorField, Expression, ExpressionBlock, RawString, Stmt, +}; pub use field::{Field, FieldArity, FieldType}; pub use identifier::{Identifier, RefIdentifier}; pub use 
indentation_type::IndentationType; @@ -118,6 +122,34 @@ impl std::ops::Index for SchemaAst { } } +/// An opaque identifier for a top-level assignment in a schema AST. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TopLevelAssignmentId(u32); + +impl std::ops::Index for SchemaAst { + type Output = TopLevelAssignment; + + fn index(&self, index: TopLevelAssignmentId) -> &Self::Output { + self.tops[index.0 as usize] + .as_top_level_assignment() + .expect("expected top level assignment") + } +} + +/// An opaque identifier for an expression function in a schema AST. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ExprFnId(u32); + +impl std::ops::Index for SchemaAst { + type Output = ExprFn; + + fn index(&self, index: ExprFnId) -> &Self::Output { + self.tops[index.0 as usize] + .as_expr_fn() + .expect("expected expression function") + } +} + /// An opaque identifier for a model in a schema AST. Use the /// `schema[model_id]` syntax to resolve the id to an `ast::Model`. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -174,6 +206,12 @@ pub enum TopId { TestCase(ValExpId), RetryPolicy(ValExpId), + + /// A top-level assignment. + TopLevelAssignment(TopLevelAssignmentId), + + /// A function declaration. 
+ ExprFn(ExprFnId), } impl TopId { @@ -223,6 +261,20 @@ impl TopId { } } + pub fn as_toplevel_assignment_id(self) -> Option { + match self { + TopId::TopLevelAssignment(id) => Some(id), + _ => None, + } + } + + pub fn as_expr_fn_id(self) -> Option { + match self { + TopId::ExprFn(id) => Some(id), + _ => None, + } + } + pub fn as_retry_policy_id(self) -> Option { match self { TopId::RetryPolicy(id) => Some(id), @@ -237,7 +289,6 @@ impl TopId { } } } - impl std::ops::Index for SchemaAst { type Output = Top; @@ -252,6 +303,8 @@ impl std::ops::Index for SchemaAst { TopId::Generator(ValExpId(idx)) => idx, TopId::TestCase(ValExpId(idx)) => idx, TopId::RetryPolicy(ValExpId(idx)) => idx, + TopId::TopLevelAssignment(TopLevelAssignmentId(idx)) => idx, + TopId::ExprFn(ExprFnId(idx)) => idx, }; &self.tops[idx as usize] @@ -269,5 +322,9 @@ fn top_idx_to_top_id(top_idx: usize, top: &Top) -> TopId { Top::Generator(_) => TopId::Generator(ValExpId(top_idx as u32)), Top::TestCase(_) => TopId::TestCase(ValExpId(top_idx as u32)), Top::RetryPolicy(_) => TopId::RetryPolicy(ValExpId(top_idx as u32)), + Top::TopLevelAssignment(_) => { + TopId::TopLevelAssignment(TopLevelAssignmentId(top_idx as u32)) + } + Top::ExprFn(_) => TopId::ExprFn(ExprFnId(top_idx as u32)), } } diff --git a/engine/baml-lib/schema-ast/src/ast/argument.rs b/engine/baml-lib/schema-ast/src/ast/argument.rs index 5fc04d019..0675cbd98 100644 --- a/engine/baml-lib/schema-ast/src/ast/argument.rs +++ b/engine/baml-lib/schema-ast/src/ast/argument.rs @@ -1,5 +1,5 @@ use super::{Expression, Span, WithSpan}; -use std::fmt::{Display, Formatter}; +use std::fmt::{self, Display, Formatter}; /// An opaque identifier for a value in an AST enum. Use the /// `r#enum[enum_value_id]` syntax to resolve the id to an `ast::EnumValue`. 
@@ -33,6 +33,23 @@ pub struct ArgumentsList { pub arguments: Vec, } +impl Display for ArgumentsList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(")?; + write!( + f, + "{}", + self.arguments + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", ") + )?; + write!(f, ")")?; + Ok(()) + } +} + impl ArgumentsList { pub fn iter(&self) -> impl ExactSizeIterator { self.arguments diff --git a/engine/baml-lib/schema-ast/src/ast/expr.rs b/engine/baml-lib/schema-ast/src/ast/expr.rs new file mode 100644 index 000000000..f17ddaf69 --- /dev/null +++ b/engine/baml-lib/schema-ast/src/ast/expr.rs @@ -0,0 +1,25 @@ +/// Types for the concrete syntax of compound expressions, +/// top-level assignments, and non-llm functions. +use baml_types::{TypeValue, UnresolvedValue}; +use internal_baml_diagnostics::Diagnostics; + +use crate::ast::{ + ArgumentsList, BlockArgs, Expression, ExpressionBlock, FieldType, Identifier, Span, Stmt, +}; + +/// A function definition. +#[derive(Debug, Clone)] +pub struct ExprFn { + pub name: Identifier, + pub args: BlockArgs, + pub return_type: Option, + pub body: ExpressionBlock, + pub span: Span, +} + +/// A top-level binding. +/// E.g. (at top-level in source file) `let x = 1;` +#[derive(Debug, Clone)] +pub struct TopLevelAssignment { + pub stmt: Stmt, +} diff --git a/engine/baml-lib/schema-ast/src/ast/expression.rs b/engine/baml-lib/schema-ast/src/ast/expression.rs index e4ae57614..e36805068 100644 --- a/engine/baml-lib/schema-ast/src/ast/expression.rs +++ b/engine/baml-lib/schema-ast/src/ast/expression.rs @@ -7,7 +7,7 @@ use crate::ast::Span; use bstd::dedent; use std::fmt; -use super::{Identifier, WithName, WithSpan}; +use super::{ArgumentsList, Identifier, WithName, WithSpan}; use baml_types::JinjaExpression; #[derive(Debug, Clone)] @@ -109,6 +109,15 @@ pub enum Expression { Map(Vec<(Expression, Expression)>, Span), /// A JinjaExpression. e.g. "this|length > 5". 
JinjaExpressionValue(JinjaExpression, Span), + /// Function abstraction. + Lambda(ArgumentsList, Box, Span), + /// Function Application + /// TODO: Function should be an Expression, not an Identifier. + FnApp(Identifier, Vec, Span), + /// A class constructor, e.g. `MyClass { x = 1, y = 2 }`. + ClassConstructor(ClassConstructor, Span), + /// An expression block, e.g. `{ let x = 1; x + 2 }`. + ExprBlock(ExpressionBlock, Span), } impl fmt::Display for Expression { @@ -138,6 +147,39 @@ impl fmt::Display for Expression { .join(","); write!(f, "{{{vals}}}") } + Expression::ClassConstructor(cc, ..) => { + write!(f, "{} {{", cc.class_name)?; + for field in &cc.fields { + match field { + ClassConstructorField::Named(name, expr) => { + write!(f, " {name}: {expr};")?; + } + ClassConstructorField::Spread(expr) => { + write!(f, " ..{expr};")?; + } + } + } + write!(f, "}}") + } + Expression::Lambda(args, body, _span) => { + write!(f, "{} => {}", args, body) + } + Expression::FnApp(name, args, _span) => { + write!(f, "{name}(")?; + for arg in args { + write!(f, "{arg},")?; // TODO: Drop the comma for the last argument. + } + write!(f, ")")?; + Ok(()) + } + Expression::ExprBlock(block, _span) => { + write!(f, "{{")?; + for stmt in &block.stmts { + write!(f, "{stmt};")?; + } + write!(f, "{}", block.expr)?; + write!(f, "}}") + } } } } @@ -247,6 +289,10 @@ impl Expression { Self::Identifier(id) => id.span(), Self::Map(_, span) => span, Self::Array(_, span) => span, + Self::ClassConstructor(_, span) => span, + Self::Lambda(_, _, span) => span, + Self::FnApp(_, _, span) => span, + Self::ExprBlock(_, span) => span, } } @@ -255,7 +301,7 @@ impl Expression { } /// Creates a friendly readable representation for a value's type. 
- pub fn describe_value_type(&self) -> &'static str { + pub fn describe_value_type(&self) -> &str { match self { Expression::BoolValue(_, _) => "boolean", Expression::NumericValue(_, _) => "numeric", @@ -271,6 +317,10 @@ impl Expression { }, Expression::Map(_, _) => "map", Expression::Array(_, _) => "array", + Expression::ClassConstructor(cc, _) => cc.class_name.name(), + Expression::Lambda(_, _, _) => "function", + Expression::FnApp(_, _, _) => "function_application", + Expression::ExprBlock(_, _) => "expression_block", } } @@ -325,6 +375,33 @@ impl Expression { }); } (Map(_, _), _) => panic!("Types do not match: {self:?} and {other:?}"), + (ClassConstructor(cc1, _), ClassConstructor(cc2, _)) => { + cc1.assert_eq_up_to_span(cc2); + } + (ClassConstructor(_, _), _) => panic!("Types do not match: {self:?} and {other:?}"), + (Lambda(args1, body1, _), Lambda(args2, body2, _)) => { + assert_eq!(args1.arguments.len(), args2.arguments.len()); + for (arg1, arg2) in args1.arguments.iter().zip(args2.arguments.iter()) { + arg1.assert_eq_up_to_span(arg2); + } + body1.assert_eq_up_to_span(body2); + } + (Lambda(_, _, _), _) => panic!("Types do not match: {self:?} and {other:?}"), + (FnApp(name1, args1, _), FnApp(name2, args2, _)) => { + name1.assert_eq_up_to_span(name2); + assert_eq!(args1.len(), args2.len()); + for (arg1, arg2) in args1.iter().zip(args2.iter()) { + arg1.assert_eq_up_to_span(arg2); + } + } + (FnApp(_, _, _), _) => panic!("Types do not match: {self:?} and {other:?}"), + (ExprBlock(block1, _), ExprBlock(block2, _)) => { + for (stmt1, stmt2) in block1.stmts.iter().zip(block2.stmts.iter()) { + stmt1.assert_eq_up_to_span(stmt2); + } + block1.expr.assert_eq_up_to_span(&block2.expr); + } + (ExprBlock(_, _), _) => panic!("Types do not match: {self:?} and {other:?}"), } } @@ -400,6 +477,130 @@ impl Expression { span.clone(), )) } + Expression::ClassConstructor(cc, span) => { + let fields = cc + .fields + .iter() + .filter_map(|f| f.to_unresolved_value(_diagnostics)) + 
.collect::>(); + Some(UnresolvedValue::ClassConstructor( + cc.class_name.name().to_string(), + fields, + span.clone(), + )) + } + Expression::Lambda(_arg_names, _body, _span) => todo!(), + Expression::FnApp(_, _, _) => None, // Is this right? + Expression::ExprBlock(_, _) => None, // Is this right? + } + } +} + +#[derive(Debug, Clone)] +pub struct ClassConstructor { + pub class_name: Identifier, + pub fields: Vec, +} + +#[derive(Debug, Clone)] +pub enum ClassConstructorField { + Named(Identifier, Expression), + Spread(Expression), +} + +impl ClassConstructor { + pub fn assert_eq_up_to_span(&self, other: &ClassConstructor) { + assert_eq!(self.class_name, other.class_name); + assert_eq!(self.fields.len(), other.fields.len()); + self.fields + .iter() + .zip(other.fields.iter()) + .for_each(|(a, b)| { + a.assert_eq_up_to_span(b); + }); + } +} + +impl ClassConstructorField { + pub fn assert_eq_up_to_span(&self, other: &ClassConstructorField) { + use ClassConstructorField::*; + match (self, other) { + (Named(name1, expr1), Named(name2, expr2)) => { + name1.assert_eq_up_to_span(name2); + expr1.assert_eq_up_to_span(expr2); + } + (Spread(expr1), Spread(expr2)) => { + expr1.assert_eq_up_to_span(expr2); + } + (Named(_, _), _) => panic!("Types do not match: {self:?} and {other:?}"), + (Spread(_expr), _) => panic!("Types do not match: {self:?} and {other:?}"), + } + } + + // TODO: This is weird. Figure out what should happen with UnresolvedValue on spreads. + pub fn to_unresolved_value( + &self, + _diagnostics: &mut internal_baml_diagnostics::Diagnostics, + ) -> Option<(String, UnresolvedValue)> { + match self { + ClassConstructorField::Named(name, expr) => Some(( + name.name().to_string(), + expr.to_unresolved_value(_diagnostics)?, + )), + ClassConstructorField::Spread(_expr) => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct ExpressionBlock { + pub stmts: Vec, + pub expr: Box, +} + +// TODO: How do we indent the inner statements? 
+impl fmt::Display for ExpressionBlock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{{")?; + for stmt in &self.stmts { + write!(f, "{stmt}")?; } + write!(f, "{}", self.expr)?; + write!(f, "}}") + } +} + +impl ExpressionBlock { + pub fn assert_eq_up_to_span(&self, other: &ExpressionBlock) { + self.stmts + .iter() + .zip(other.stmts.iter()) + .for_each(|(a, b)| { + a.assert_eq_up_to_span(b); + }); + self.expr.assert_eq_up_to_span(&other.expr); + } +} + +// TODO: Stmt statements have the form `let x = some_expr`. +// When we add more statements, `Stmt` will become an enum. +#[derive(Debug, Clone)] +pub struct Stmt { + pub identifier: Identifier, + pub body: Expression, + pub span: Span, +} + +impl fmt::Display for Stmt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "let {} = {}", self.identifier, self.body)?; + Ok(()) + } +} + +impl Stmt { + pub fn assert_eq_up_to_span(&self, other: &Stmt) { + self.identifier.assert_eq_up_to_span(&other.identifier); + self.body.assert_eq_up_to_span(&other.body); } } diff --git a/engine/baml-lib/schema-ast/src/ast/top.rs b/engine/baml-lib/schema-ast/src/ast/top.rs index 53fb603cf..dd0b19742 100644 --- a/engine/baml-lib/schema-ast/src/ast/top.rs +++ b/engine/baml-lib/schema-ast/src/ast/top.rs @@ -1,6 +1,7 @@ use super::{ assignment::Assignment, traits::WithSpan, Identifier, Span, TemplateString, TypeExpressionBlock, ValueExprBlock, WithIdentifier, + expr::{TopLevelAssignment, ExprFn} }; /// Enum for distinguishing between top-level entries @@ -26,6 +27,10 @@ pub enum Top { TestCase(ValueExprBlock), RetryPolicy(ValueExprBlock), + + TopLevelAssignment(TopLevelAssignment), + + ExprFn(ExprFn), } impl Top { @@ -41,6 +46,8 @@ impl Top { Top::Generator(_) => "generator", Top::TestCase(_) => "test_case", Top::RetryPolicy(_) => "retry_policy", + Top::TopLevelAssignment(_) => "assignment", + Top::ExprFn(_) => "function", } } @@ -77,8 +84,21 @@ impl Top { _ => None, } } -} + pub fn 
as_top_level_assignment(&self) -> Option<&TopLevelAssignment> { + match self { + Top::TopLevelAssignment(assignment) => Some(assignment), + _ => None, + } + } + + pub fn as_expr_fn(&self) -> Option<&ExprFn> { + match self { + Top::ExprFn(expr_fn) => Some(expr_fn), + _ => None, + } + } +} impl WithIdentifier for Top { /// The name of the item. fn identifier(&self) -> &Identifier { @@ -93,6 +113,8 @@ impl WithIdentifier for Top { Top::Generator(x) => x.identifier(), Top::TestCase(x) => x.identifier(), Top::RetryPolicy(x) => x.identifier(), + Top::TopLevelAssignment(x) => &x.stmt.identifier, + Top::ExprFn(x) => &x.name, } } } @@ -109,6 +131,8 @@ impl WithSpan for Top { Top::Generator(gen) => gen.span(), Top::TestCase(test) => test.span(), Top::RetryPolicy(retry) => retry.span(), + Top::TopLevelAssignment(asmnt) => &asmnt.stmt.span, + Top::ExprFn(function) => &function.span, } } } diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index 2845d5dcf..9f8028a8d 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -1,11 +1,11 @@ schema = { - SOI ~ (value_expression_block | type_expression_block | template_declaration | type_alias | comment_block | raw_string_literal | empty_lines | CATCH_ALL)* ~ EOI + SOI ~ (expr_fn | top_level_assignment | value_expression_block | type_expression_block | template_declaration | type_alias | comment_block | raw_string_literal | empty_lines | CATCH_ALL)* ~ EOI } // ###################################### // Unified Block for Class and Enum // ###################################### -type_expression_block = { identifier ~ identifier ~ named_argument_list? ~ BLOCK_OPEN ~ type_expression_contents ~ BLOCK_CLOSE } +type_expression_block = { identifier ~ identifier ~ named_argument_list? 
~ BLOCK_OPEN ~ type_expression_contents ~ BLOCK_CLOSE } // Dynamic declarations start with the dynamic keyword followed by a normal type expression. dynamic_type_expression_block = { identifier ~ type_expression_block } @@ -35,7 +35,7 @@ value_expression = { identifier ~ expression? ~ (NEWLINE? ~ field_attri // Type builder // ###################################### -type_builder_block = { +type_builder_block = { TYPE_BUILDER_KEYWORD ~ BLOCK_OPEN ~ type_builder_contents ~ BLOCK_CLOSE } type_builder_contents = { (dynamic_type_expression_block | type_expression_block | type_alias | comment_block | empty_lines | BLOCK_LEVEL_CATCH_ALL)* } @@ -84,10 +84,10 @@ non_union = { array_notation | map | identifier | group | tuple | literal_type } parenthesized_type = { openParan ~ field_type_with_attr ~ closeParan } -path_identifier = { single_word ~ ("." ~ single_word)+ } -identifier = { path_identifier | namespaced_identifier | single_word } -single_word = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | "-")* } -namespaced_identifier = { single_word ~ ("::" ~ single_word)+ } +path_identifier = { single_word ~ ("." 
~ single_word)+ } +identifier = { path_identifier | namespaced_identifier | single_word } +single_word = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | "-")* } +namespaced_identifier = { single_word ~ ("::" ~ single_word)+ } // ###################################### // Type Alias @@ -112,7 +112,18 @@ jinja_block_open = _{ "{{" } jinja_block_close = _{ "}}" } jinja_body = { (!(jinja_block_open | jinja_block_close) ~ ANY)* } jinja_expression = { jinja_block_open ~ jinja_body ~ jinja_block_close } -expression = { jinja_expression | map_expression | array_expression | numeric_literal | string_literal | identifier } +expression = { + fn_app | + lambda | + jinja_expression | + map_expression | + expr_block | + array_expression | + numeric_literal | + class_constructor | + string_literal | + identifier +} ARRAY_CATCH_ALL = { !"]" ~ CATCH_ALL } ENTRY_CATCH_ALL = { field_attribute | BLOCK_LEVEL_CATCH_ALL } // ###################################### @@ -123,7 +134,7 @@ numeric_literal = @{ ("-")? ~ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? } // ###################################### // String literals. These behave specially in BAML. 
// ###################################### -banned_chars = @{ "#" | "@" | "{" | "(" | "[" | "<" | "}" | ")" | "]" | ">" | "," | "'" | " //" | NEWLINE } +banned_chars = @{ "#" | "@" | "{" | "(" | "[" | "<" | "}" | ")" | "]" | ">" | "," | "'" | ";" | " //" | NEWLINE } banned_start_chars = { WHITESPACE | banned_chars } banned_end_chars = { WHITESPACE | banned_chars } unquoted_string_literal = @{ (!banned_start_chars ~ ANY) ~ (!banned_chars ~ !"\"" ~ ANY)* ~ (!banned_end_chars ~ ANY)* } @@ -176,7 +187,7 @@ doc_content = @{ (!NEWLINE ~ ANY)* } // Shared Building Blocks and Comments // ###################################### WHITESPACE = _{ " " | "\t" } -NEWLINE = { "\n" | "\r\n" | "\r" } +NEWLINE = { "\n" | "\r\n" | "\r" } empty_lines = @{ (WHITESPACE* ~ NEWLINE)+ } // ###################################### @@ -198,3 +209,53 @@ TYPE_BUILDER_KEYWORD = { "type_builder" } CLIENT_KEYWORD = { "client" | "client" } GENERATOR_KEYWORD = { "generator" } RETRY_POLICY_KEYWORD = { "retry_policy" } +SEMICOLON = { ";" } +COLON = { ":" } +COMMA = { "," } + +// ################################################# +// BAML Expressions +// ################################################# + +// Constant applicative form. +// e.g. top-level +// - `let x = go(1,2)` +// - `let z = { let b = 2; [b, 3] }` +top_level_assignment = { stmt } + +// Regular function. +// e.g.: +// fn foo(x:int, y: bool?) -> string { +// go(x,y) +// } +expr_fn = { "fn" ~ identifier ~ named_argument_list ~ ARROW? ~ field_type_chain? ~ expr_block } + +// Body of a function (including curly brackets). +expr_block = { BLOCK_OPEN ~ NEWLINE? ~ (stmt ~ NEWLINE)* ~ expression ~ NEWLINE? ~ BLOCK_CLOSE } + +// Statement. +// Currently the only statement is a let-binding. +stmt = { let_expr ~ SEMICOLON? } + +// Let-binding statement. +let_expr = { "let" ~ identifier ~ "=" ~ expression } + +// Function application. +fn_app = { identifier ~ "(" ~ expression? ~ ("," ~ expression)* ~ ")" } + +// Anonymous function. 
+lambda = { named_argument_list ~ "=>" ~ expression } + +// Class constructors. +// e.g. `MyClass { x: 1, y: 2 }`. +// +class_constructor = { identifier ~ "{" ~ NEWLINE? ~ (class_field_value_pair ~ COMMA? ~ NEWLINE?)* ~ NEWLINE? ~ "}" } + +// A single field in a class constructor. +class_field_value_pair = { (identifier ~ COLON ~ expression) | struct_spread } + +// struct spread a.k.a struct update syntax. +// e.g. `..other_struct` +// Used in constructors to initialize fields of a new struct from some other struct. +// e.g. `MyClass { a: 1, b: 2, ..other_struct }` +struct_spread = { ".." ~ expression } \ No newline at end of file diff --git a/engine/baml-lib/schema-ast/src/parser/mod.rs b/engine/baml-lib/schema-ast/src/parser/mod.rs index 5f018bb6e..d1c12ed09 100644 --- a/engine/baml-lib/schema-ast/src/parser/mod.rs +++ b/engine/baml-lib/schema-ast/src/parser/mod.rs @@ -3,6 +3,7 @@ mod parse_arguments; mod parse_assignment; mod parse_attribute; mod parse_comments; +pub mod parse_expr; mod parse_expression; mod parse_field; mod parse_identifier; diff --git a/engine/baml-lib/schema-ast/src/parser/parse_expr.rs b/engine/baml-lib/schema-ast/src/parser/parse_expr.rs new file mode 100644 index 000000000..cdc880f6b --- /dev/null +++ b/engine/baml-lib/schema-ast/src/parser/parse_expr.rs @@ -0,0 +1,197 @@ +use super::{ + helpers::{parsing_catch_all, Pair}, + parse_identifier::parse_identifier, + Rule, +}; +use crate::ast::ArgumentsList; +use crate::parser::{ + parse_expression::parse_expression, parse_identifier, + parse_named_args_list::parse_named_argument_list, +}; +use crate::{ + assert_correct_parser, + ast::{expr::ExprFn, ExpressionBlock, *}, + parser::parse_arguments::parse_arguments_list, + unreachable_rule, +}; +use crate::{ + ast::{self, Expression, Stmt, TopLevelAssignment}, + parser::{parse_field::parse_field_type_chain, parse_types::parse_field_type}, +}; +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; + +pub fn parse_expr_fn(token: 
Pair<'_>, diagnostics: &mut Diagnostics) -> Option { + assert_correct_parser!(token, Rule::expr_fn); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + let name = parse_identifier(tokens.next()?, diagnostics); + let args = parse_named_argument_list(tokens.next()?, diagnostics); + let arrow_or_body = tokens.next()?; + + // We may or may not have an arrow and a return type. + // If the args list is immediately followed by an arrow, we have an arrow and a return type. + // Otherwise, we have just a body. + let (maybe_return_type, maybe_body) = if matches!(arrow_or_body.as_rule(), Rule::ARROW) { + let return_type = parse_field_type_chain(tokens.next()?, diagnostics); + let function_body = parse_function_body(tokens.next()?, diagnostics); + (Some(return_type), function_body) + } else { + diagnostics.push_error(DatamodelError::new_static( + "fn must have a return type: e.g. fn Foo() -> int", + span.clone(), + )); + let function_body = parse_function_body(arrow_or_body, diagnostics); + (None, function_body) + }; + match (maybe_return_type, maybe_body) { + (Some(return_type), Some(body)) => Some(ExprFn { + name, + args, + return_type, + body, + span, + }), + _ => None, + } +} + +pub fn parse_top_level_assignment( + token: Pair<'_>, + diagnostics: &mut Diagnostics, +) -> Option { + assert_correct_parser!(token, Rule::top_level_assignment); + let mut tokens = token.into_inner(); + let stmt = parse_statement(tokens.next()?, diagnostics)?; + Some(TopLevelAssignment { stmt }) +} + +pub fn parse_statement(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option { + assert_correct_parser!(token, Rule::stmt); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + // Our only statements are let bindings, so: + let let_binding_token = tokens.next()?; + assert_correct_parser!(let_binding_token, Rule::let_expr); + let mut let_binding_tokens = let_binding_token.into_inner(); + let identifier = 
parse_identifier(let_binding_tokens.next()?, diagnostics); + + let rhs = let_binding_tokens.next()?; + let rhs_span = diagnostics.span(rhs.as_span()); + let maybe_body = match rhs.as_rule() { + Rule::expr_block => { + let block_span = diagnostics.span(rhs.as_span()); + eprintln!("parsing expr_block"); + let maybe_expr_block = parse_expr_block(rhs, diagnostics); + maybe_expr_block.map(|expr_block| Expression::ExprBlock(expr_block, block_span)) + } + Rule::expression => { + eprintln!("parsing expr"); + let maybe_expr = parse_expression(rhs, diagnostics); + maybe_expr + } + _ => { + diagnostics.push_error(DatamodelError::new_static( + "Parser only allows expr_block and expr here", + rhs_span, + )); + None + } + }; + let maybe_semicolon = tokens.next(); + match maybe_semicolon { + Some(p) if p.as_str() == ";" => {} + _ => { + diagnostics.push_error(DatamodelError::new_static( + "Statement must end with a semicolon.", + span.clone(), + )); + } + } + maybe_body.map(|body| Stmt { + identifier, + body, + span, + }) +} + +pub fn parse_expr_block(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option { + assert_correct_parser!(token, Rule::expr_block); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + let mut stmts = Vec::new(); + let mut expr = None; + let _open_bracket = tokens.next()?; + for item in tokens { + match item.as_rule() { + Rule::stmt => { + let maybe_stmt = parse_statement(item, diagnostics); + if let Some(stmt) = maybe_stmt { + stmts.push(stmt); + } + } + Rule::expression => { + let maybe_expr = parse_expression(item, diagnostics); + if let Some(parsed_expr) = maybe_expr { + expr = Some(parsed_expr); + break; + } + } + Rule::BLOCK_CLOSE => { + if expr.is_none() { + diagnostics.push_error(DatamodelError::new_static( + "Function must end in an expression.", + span.clone(), + )); + } + break; + } + Rule::NEWLINE => { + continue; + } + _ => { + diagnostics.push_error(DatamodelError::new_static( + "Internal Error: Parser 
only allows statements and expressions in function body.", + span.clone() + )); + } + } + } + expr.map(|e| ExpressionBlock { + stmts, + expr: Box::new(e), + }) +} + +pub fn parse_fn_app(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option { + assert_correct_parser!(token, Rule::fn_app); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + let fn_name = parse_identifier(tokens.next()?, diagnostics); + let mut args = Vec::new(); + for item in tokens { + let maybe_arg = parse_expression(item, diagnostics); + if let Some(arg) = maybe_arg { + args.push(arg); + } + } + Some(Expression::FnApp(fn_name, args, span)) +} + +pub fn parse_lambda(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Option { + assert_correct_parser!(token, Rule::lambda); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + let mut args = ArgumentsList { + arguments: Vec::new(), + }; + parse_arguments_list(tokens.next()?, &mut args, &None, diagnostics); + let maybe_body = parse_function_body(tokens.next()?, diagnostics); + maybe_body.map(|body| Expression::Lambda(args, Box::new(body), span)) +} + +pub fn parse_function_body( + token: Pair<'_>, + diagnostics: &mut Diagnostics, +) -> Option { + parse_expr_block(token, diagnostics) +} diff --git a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs index 190ba66fa..bf43e7369 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs @@ -1,21 +1,33 @@ use super::{ helpers::{parsing_catch_all, Pair}, + parse_expr::{parse_expr_block, parse_fn_app, parse_lambda}, parse_identifier::parse_identifier, Rule, }; use crate::{assert_correct_parser, ast::*, unreachable_rule}; use baml_types::JinjaExpression; -use internal_baml_diagnostics::Diagnostics; +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; pub(crate) fn parse_expression( 
token: Pair<'_>, diagnostics: &mut internal_baml_diagnostics::Diagnostics, ) -> Option { - let first_child = token.into_inner().next().unwrap(); + let first_child = token.into_inner().next()?; let span = diagnostics.span(first_child.as_span()); match first_child.as_rule() { Rule::numeric_literal => Some(Expression::NumericValue(first_child.as_str().into(), span)), Rule::string_literal => Some(parse_string_literal(first_child, diagnostics)), + Rule::raw_string_literal => Some(Expression::RawStringValue(parse_raw_string( + first_child, + diagnostics, + ))), + Rule::quoted_string_literal => { + let contents = first_child.into_inner().next().unwrap(); + Some(Expression::StringValue( + unescape_string(contents.as_str()), + span, + )) + } Rule::map_expression => Some(parse_map(first_child, diagnostics)), Rule::array_expression => Some(parse_array(first_child, diagnostics)), Rule::jinja_expression => Some(parse_jinja_expression(first_child, diagnostics)), @@ -24,6 +36,16 @@ pub(crate) fn parse_expression( first_child, diagnostics, ))), + Rule::class_constructor => Some(parse_class_constructor(first_child, diagnostics)), + Rule::fn_app => parse_fn_app(first_child, diagnostics), + Rule::lambda => parse_lambda(first_child, diagnostics), + Rule::expr_block => { + eprintln!("About to parse_expr_block on {first_child:?}"); + let res = parse_expr_block(first_child, diagnostics); + eprintln!("parse_expr_block result: {res:?}"); + res + } + .map(|block| Expression::ExprBlock(block, span)), Rule::BLOCK_LEVEL_CATCH_ALL => { diagnostics.push_error( @@ -286,6 +308,61 @@ pub fn parse_jinja_expression(token: Pair<'_>, diagnostics: &mut Diagnostics) -> } } +pub fn parse_class_constructor(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Expression { + assert_correct_parser!(token, Rule::class_constructor); + let span = diagnostics.span(token.as_span()); + let mut tokens = token.into_inner(); + let class_name = parse_identifier( + tokens.next().expect("Guaranteed by the grammar"), + 
diagnostics, + ); + let mut fields = Vec::new(); + while let Some(field_or_close_bracket) = tokens.next() { + if field_or_close_bracket.as_str() == "}" { + break; + } else if field_or_close_bracket.as_str() == "," { + continue; + } else if field_or_close_bracket.as_rule() == Rule::NEWLINE { + continue; + } else { + assert_correct_parser!(field_or_close_bracket, Rule::class_field_value_pair); + let mut field_tokens = field_or_close_bracket.into_inner(); + let identifier_or_spread = field_tokens.next().expect("Guaranteed by the grammar"); + match identifier_or_spread.as_rule() { + Rule::struct_spread => { + let mut struct_spread_tokens = identifier_or_spread.into_inner(); + let maybe_expr = parse_expression( + struct_spread_tokens + .next() + .expect("Guaranteed by the grammar"), + diagnostics, + ); + if let Some(expr) = maybe_expr { + fields.push(ClassConstructorField::Spread(expr)); + } + } + Rule::identifier => { + let field_name = parse_identifier(identifier_or_spread, diagnostics); + + let _colon = field_tokens.next(); + let maybe_expr = parse_expression( + field_tokens.next().expect("Guaranteed by the grammar"), + diagnostics, + ); + if let Some(expr) = maybe_expr { + fields.push(ClassConstructorField::Named(field_name, expr)); + } + } + _ => unreachable_rule!(identifier_or_spread, Rule::class_field_value_pair), + } + let _maybe_comma = tokens.next(); + } + } + let class_constructor = ClassConstructor { class_name, fields }; + + Expression::ClassConstructor(class_constructor, span) +} + #[cfg(test)] mod tests { use super::super::{BAMLParser, Rule}; diff --git a/engine/baml-lib/schema-ast/src/parser/parse_field.rs b/engine/baml-lib/schema-ast/src/parser/parse_field.rs index b1eb5cbce..f7f5569d2 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_field.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_field.rs @@ -160,7 +160,14 @@ pub fn parse_field_type_chain(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> } } Rule::field_operator => 
operators.push(current.as_str().to_string()), - _ => parsing_catch_all(current, "field_type_chain"), + _ => { + diagnostics.push_error(DatamodelError::new_model_validation_error( + "Unexpected token in field type chain", + "field_type_chain", + "", + diagnostics.span(current.as_span()), + )); + } } } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_named_args_list.rs b/engine/baml-lib/schema-ast/src/parser/parse_named_args_list.rs index 66e58a15b..e58f00289 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_named_args_list.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_named_args_list.rs @@ -16,7 +16,7 @@ use super::helpers::Pair; pub(crate) fn parse_named_argument_list( pair: Pair<'_>, diagnostics: &mut Diagnostics, -) -> Result { +) -> BlockArgs { assert!( pair.as_rule() == Rule::named_argument_list, "parse_named_argument_list called on the wrong rule: {:?}", @@ -48,7 +48,10 @@ pub(crate) fn parse_named_argument_list( } Rule::colon => {} Rule::field_type | Rule::field_type_chain => { - r#type = Some(parse_function_arg(arg, diagnostics)?); + match parse_function_arg(arg, diagnostics) { + Ok(t) => r#type = Some(t), + Err(e) => diagnostics.push_error(e), + } } _ => parsing_catch_all(arg, "named_argument_list"), } @@ -69,11 +72,11 @@ pub(crate) fn parse_named_argument_list( } } - Ok(BlockArgs { + BlockArgs { documentation: None, args, span, - }) + } } pub fn parse_function_arg( diff --git a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs index d359b82d7..c77493666 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs @@ -1,9 +1,12 @@ use std::path::{Path, PathBuf}; use super::{ - parse_assignment::parse_assignment, parse_template_string::parse_template_string, + parse_assignment::parse_assignment, + parse_expr::{parse_expr_fn, parse_top_level_assignment}, + parse_template_string::parse_template_string, 
parse_type_expression_block::parse_type_expression_block, - parse_value_expression_block::parse_value_expression_block, BAMLParser, Rule, + parse_value_expression_block::parse_value_expression_block, + BAMLParser, Rule, }; use crate::ast::*; use internal_baml_diagnostics::{DatamodelError, Diagnostics, SourceFile}; @@ -60,6 +63,19 @@ pub fn parse_schema( while let Some(current) = pairs.next() { match current.as_rule() { + Rule::top_level_assignment => { + parse_top_level_assignment(current, &mut diagnostics).map( + |top_level_assignment| { + top_level_definitions + .push(Top::TopLevelAssignment(top_level_assignment)); + }, + ); + } + Rule::expr_fn => { + parse_expr_fn(current, &mut diagnostics).map(|expr_fn| { + top_level_definitions.push(Top::ExprFn(expr_fn)); + }); + } Rule::type_expression_block => { let type_expr = parse_type_expression_block( current, @@ -188,7 +204,7 @@ mod tests { use super::parse_schema; use crate::ast::*; - use baml_types::TypeValue; + use baml_types::{expr::Expr, TypeValue}; // Add this line to import the ast module use internal_baml_diagnostics::SourceFile; @@ -397,4 +413,47 @@ mod tests { assert_eq!(alias.to_string(), "One"); } + + #[test] + fn test_top_level_assignment() { + let input = "let x = 1;"; + let path = "example_file.baml"; + let source = SourceFile::new_static(path.into(), input); + let (ast, _) = parse_schema(&Path::new(path), &source).unwrap(); + match ast.tops.as_slice() { + [Top::TopLevelAssignment(x)] => { + assert_eq!(x.stmt.identifier.name(), "x"); + } + _ => panic!("Expected a single top level assignment."), + } + } + + #[test] + fn test_top_level_block_assignment() { + let input = r#" + let x = { + let y = 10; + go(y, 20) + }; + "#; + let path = "example_file.baml"; + let source = SourceFile::new_static(path.into(), input); + let (ast, _) = parse_schema(&Path::new(path), &source).unwrap(); + match ast.tops.as_slice() { + [Top::TopLevelAssignment(x)] => { + dbg!(&x); + dbg!(&x.stmt); + 
assert_eq!(x.stmt.identifier.name(), "x"); + match &x.stmt.body { + Expression::ExprBlock(ExpressionBlock { stmts, expr }, _) => { + assert_eq!(stmts.len(), 1); + assert_eq!(stmts[0].identifier.name(), "y"); + assert!(matches!(expr.as_ref(), Expression::FnApp(_, _, _))); + } + _ => panic!("Expected ExpressionBlock"), + } + } + _ => panic!("Expected a single top level assignment."), + } + } } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_template_string.rs b/engine/baml-lib/schema-ast/src/parser/parse_template_string.rs index f60009ea6..4218c681d 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_template_string.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_template_string.rs @@ -25,10 +25,7 @@ pub(crate) fn parse_template_string( Rule::TEMPLATE_KEYWORD => {} Rule::identifier => name = Some(parse_identifier(current, diagnostics)), Rule::assignment => {} - Rule::named_argument_list => match parse_named_argument_list(current, diagnostics) { - Ok(arg) => input = Some(arg), - Err(err) => diagnostics.push_error(err), - }, + Rule::named_argument_list => { input = Some(parse_named_argument_list(current, diagnostics))}, Rule::raw_string_literal => { value = Some(Expression::RawStringValue(parse_raw_string( current, diff --git a/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs b/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs index 7d830993c..dcbf1dbb4 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs @@ -75,10 +75,7 @@ pub(crate) fn parse_type_expression_block( } Rule::BLOCK_OPEN | Rule::BLOCK_CLOSE => {} - Rule::named_argument_list => match parse_named_argument_list(current, diagnostics) { - Ok(arg) => input = Some(arg), - Err(err) => diagnostics.push_error(err), - }, + Rule::named_argument_list => { input = Some(parse_named_argument_list(current, diagnostics))}, Rule::type_expression_contents => { let 
mut pending_field_comment: Option> = None; diff --git a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs index d8e3f1761..113075e16 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs @@ -39,10 +39,7 @@ pub(crate) fn parse_value_expression_block( }, Rule::ARROW => has_arrow = true, Rule::identifier => name = Some(parse_identifier(current, diagnostics)), - Rule::named_argument_list => match parse_named_argument_list(current, diagnostics) { - Ok(arg) => input = Some(arg), - Err(err) => diagnostics.push_error(err), - }, + Rule::named_argument_list => { input = Some(parse_named_argument_list(current, diagnostics))}, Rule::field_type | Rule::field_type_chain => { match parse_function_arg(current, diagnostics) { Ok(arg) => output = Some(arg), diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index 6ced94968..ef180103f 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -23,7 +23,9 @@ use axum_extra::{ headers::{self, authorization::Basic, Authorization, Header}, TypedHeader, }; -use baml_types::{BamlValue, GeneratorDefaultClientMode, GeneratorOutputType}; +use baml_types::{ + expr::Expr, expr::ExprMetadata, BamlValue, GeneratorDefaultClientMode, GeneratorOutputType, +}; use core::pin::Pin; use futures::Stream; use jsonish::ResponseBamlValue; @@ -206,15 +208,12 @@ impl Server { "Failed to bind to port {}; try using --port PORT to specify a different port.", port ))?; - + let baml_runtime = BamlRuntime::from_directory(&src_dir, std::env::vars().collect())?; Ok(( Arc::new(Self { src_dir: src_dir.clone(), port, - b: Arc::new(RwLock::new(BamlRuntime::from_directory( - &src_dir, - std::env::vars().collect(), - )?)), + b: Arc::new(RwLock::new(baml_runtime)), }), tcp_listener, )) @@ 
-285,13 +284,23 @@ impl Server { let s = self.clone(); let app = app.route( "/call/:msg", - post(move |b_fn, b_args| s.clone().baml_call_axum(b_fn, b_args)), + post( + move |extract::Path(b_fn): extract::Path, + extract::Json(b_args): extract::Json| async move { + s.clone().baml_call_axum(b_fn, b_args).await + }, + ), ); let s = self.clone(); let app = app.route( "/stream/:msg", - post(move |b_fn, b_args| s.clone().baml_stream_axum2(b_fn, b_args)), + post( + move |extract::Path(b_fn): extract::Path, + extract::Json(b_args): extract::Json| async move { + s.clone().baml_stream_axum2(b_fn, b_args).await + }, + ), ); let s = self.clone(); let app = app.route("/docs", get(move || s.clone().docs_handler())); @@ -405,11 +414,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` } } - async fn baml_call_axum( - self: Arc, - extract::Path(b_fn): extract::Path, - extract::Json(b_args): extract::Json, - ) -> Response { + async fn baml_call_axum(self: Arc, b_fn: String, b_args: serde_json::Value) -> Response { let mut b_options = None; if let Some(options_value) = b_args.get("__baml_options__") { match BamlOptions::deserialize(options_value) { @@ -541,11 +546,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` } // newline-delimited can be implemented using axum_streams::StreamBodyAs::json_nl(self.baml_stream(path, body)) - async fn baml_stream_axum2( - self: Arc, - extract::Path(path): extract::Path, - extract::Json(body): extract::Json, - ) -> Response { + async fn baml_stream_axum2(self: Arc, path: String, body: serde_json::Value) -> Response { let mut b_options = None; if let Some(options_value) = body.get("__baml_options__") { match BamlOptions::deserialize(options_value) { diff --git a/engine/baml-runtime/src/eval_expr.rs b/engine/baml-runtime/src/eval_expr.rs new file mode 100644 index 000000000..14800da47 --- /dev/null +++ b/engine/baml-runtime/src/eval_expr.rs @@ -0,0 +1,472 @@ +use anyhow::Context; +use 
futures::channel::mpsc; +use futures::stream::{self as stream, StreamExt}; +use internal_baml_core::internal_baml_diagnostics::SerializedSpan; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::{BamlRuntime, FunctionResult}; +use baml_types::expr::{Expr, ExprMetadata, Name}; +use baml_types::Arrow; +use baml_types::{BamlMap, BamlValue, BamlValueWithMeta}; +use internal_baml_core::ir::repr::IntermediateRepr; + +const MAX_STEPS: usize = 1000; + +pub struct EvalEnv<'a> { + pub context: HashMap>, + pub runtime: &'a BamlRuntime, + pub expr_tx: Option>>, +} + +impl<'a> EvalEnv<'a> { + pub fn dump_ctx(&self) -> String { + self.context + .iter() + .map(|(k, v)| format!("{}: {}", k, v.dump_str())) + .collect::>() + .join("\n") + } +} + +fn subst2<'a>( + expr: &Expr, + var_name: &Name, + val: &Expr, + env: &EvalEnv<'a>, +) -> anyhow::Result> { + let res: anyhow::Result> = match expr { + Expr::Var(expr_var_name, _) => { + if expr_var_name == var_name { + Ok(val.clone()) + } else { + if let Some(expr_fn) = env.context.get(expr_var_name) { + Ok(expr_fn.clone()) + } else { + Ok(expr.clone()) + } + } + } + Expr::Atom(_) => Ok(expr.clone()), + Expr::App(f, x, meta) => { + let f2 = subst2(f, var_name, val, env)?; + let x2 = subst2(x, var_name, val, env)?; + Ok(Expr::App(Arc::new(f2), Arc::new(x2), meta.clone())) + } + Expr::Lambda(params, body, meta) => Ok(Expr::Lambda( + params.clone(), + Arc::new(subst2(body, var_name, val, env)?), + meta.clone(), + )), + Expr::ArgsTuple(args, meta) => { + let mut new_args = Vec::new(); + for arg in args { + new_args.push(subst2(arg, var_name, val, env)?); + } + Ok(Expr::ArgsTuple(new_args, meta.clone())) + } + Expr::LLMFunction(_, _, _) => Ok(expr.clone()), + Expr::Let(name, value, body, meta) => { + if name == var_name { + // Skip substitution if the let binding shadows the variable. 
+ Ok(expr.clone()) + } else { + let new_value = subst2(value, var_name, val, env)?; + let new_body = subst2(body, var_name, val, env)?; + Ok(Expr::Let( + name.clone(), + Arc::new(new_value), + Arc::new(new_body), + meta.clone(), + )) + } + } + Expr::List(items, meta) => { + let new_items = items + .iter() + .map(|item| subst2(item, var_name, val, env)) + .collect::>>()?; + Ok(Expr::List(new_items, meta.clone())) + } + Expr::Map(items, meta) => { + let new_items = items + .iter() + .map(|(key, value)| { + let new_value = subst2(value, var_name, val, env)?; + Ok((key.clone(), new_value)) + }) + .collect::>>()?; + Ok(Expr::Map(new_items, meta.clone())) + } + Expr::ClassConstructor { + name, + fields, + spread, + meta, + } => { + let new_fields = fields + .iter() + .map(|(key, value)| { + let new_value = subst2(value, var_name, val, env)?; + Ok((key.clone(), new_value)) + }) + .collect::>>()?; + let new_spread = spread + .as_ref() + .map(|spread| { + subst2(spread, var_name, val, env).map(|spread| Box::new(spread.clone())) + }) + .transpose()?; + Ok(Expr::ClassConstructor { + name: name.clone(), + fields: new_fields, + spread: new_spread, + meta: meta.clone(), + }) + } + }; + let res = res?; + Ok(res) +} + +/// Perform a single beta reduction. Note that we ignore env.context +/// here. Only use env for the runtime. +async fn beta_reduce<'a>( + env: &EvalEnv<'a>, + expr: &Expr, +) -> anyhow::Result> { + match expr { + Expr::Atom(_) => Ok(expr.clone()), + Expr::Let(name, value, body, meta) => { + // Rewrite the let binding as an application. + // e.g. 
(let x = y in f) => (\x y => f) + let lambda = Expr::Lambda(vec![name.clone()], body.clone(), meta.clone()); + let app = Expr::App(Arc::new(lambda), value.clone(), meta.clone()); + Box::pin(beta_reduce(env, &app)).await + } + Expr::App(f, x, meta) => { + match (f.as_ref(), x.as_ref()) { + (Expr::Lambda(params, body, _), Expr::ArgsTuple(args, _)) => { + let pairs = params + .iter() + .cloned() + .zip(args.iter().cloned()) + .collect::>(); + let new_body = pairs + .iter() + .fold(body.as_ref().clone(), |acc, (param, arg)| { + subst2(&acc, ¶m, &arg, env).as_ref().unwrap().clone() + }); + Box::pin(beta_reduce(env, &new_body)).await + } + (Expr::Lambda(params, body, _), arg) => { + if params.len() != 1 { + return Err(anyhow::anyhow!( + "Lambda takes exactly one argument: {:?}", + expr + )); + } + let new_body = subst2(body, ¶ms[0], &arg, env) + .as_ref() + .unwrap() + .clone(); + Box::pin(beta_reduce(env, &new_body)).await + } + (Expr::LLMFunction(name, arg_names, _), Expr::ArgsTuple(args, _)) => { + let mut evaluated_args: Vec = Vec::new(); + for arg in args { + let val = eval_to_value(env, arg).await; + evaluated_args.push(val.unwrap().unwrap().clone().value()); + } + + let params = evaluated_args + .into_iter() + .zip(arg_names.iter()) + .map(|(arg, name)| (name.clone(), arg)) + .collect::>(); + let args_map = BamlMap::from_iter(params.into_iter()); + let ctx = env + .runtime + .create_ctx_manager(BamlValue::String("none".to_string()), None); + + let app_span = SerializedSpan::serialize(&expr.meta().0); + if let Some(tx) = &env.expr_tx { + tx.unbounded_send(vec![app_span]).unwrap(); + } + // if let Some(tx) = &env.expr_tx { + // tx.unbounded_send(vec![]).unwrap(); + // } + let res: anyhow::Result = env + .runtime + .call_function(name.clone(), &args_map, &ctx, None, None, None) + .await + .0; + + if let Some(tx) = &env.expr_tx { + tx.unbounded_send(vec![]).unwrap(); + } + let val = res? 
+ .parsed() + .as_ref() + .ok_or(anyhow::anyhow!( + "Impossible case - empty value in parsed result." + ))? + .as_ref() + .map_err(|e| anyhow::anyhow!("{e}"))? + .clone() + .0 + .map_meta(|_| ()); + Ok(Expr::Atom(val.map_meta(|_| meta.clone()))) + } + (Expr::Var(name, _), _) => { + let var_lookup = env + .context + .get(name) + .context(format!("Variable not found: {:?}", name))?; + let new_app = Expr::App(Arc::new(var_lookup.clone()), x.clone(), meta.clone()); + let res = Box::pin(beta_reduce(env, &new_app)).await?; + Ok(res) + } + _ => Err(anyhow::anyhow!("Not a function: {:?}", f)), + } + } + Expr::Var(name, _) => { + let var_lookup = env + .context + .get(name) + .context(format!("Variable not found: {:?}", name))?; + Ok(var_lookup.clone()) + } + Expr::List(_, _) => Ok(expr.clone()), + Expr::Map(_, _) => Ok(expr.clone()), + Expr::ClassConstructor { .. } => Ok(expr.clone()), + _ => Err(anyhow::anyhow!("Not an application: {:?}", expr)), + } +} + +/// Fully evaluate an expression to a value. +pub async fn eval_to_value<'a>( + env: &EvalEnv<'a>, + expr: &Expr, +) -> anyhow::Result>> { + let mut current_expr = expr.clone(); + + for steps in 0..MAX_STEPS { + match current_expr { + Expr::Atom(value) => return Ok(Some(value.clone().map_meta(|_| ()))), + Expr::List(items, meta) => { + let mut new_items = Vec::new(); + for item in items { + let val = Box::pin(eval_to_value(env, &item)) + .await? + .context("Evaluated value to None")?; + new_items.push(val); + } + let val = BamlValueWithMeta::List(new_items, ()); + return Ok(Some(val)); + } + Expr::Map(items, meta) => { + let mut new_items = BamlMap::new(); + for (key, value) in items { + let val = Box::pin(eval_to_value(env, &value)) + .await? 
+ .context("Evaluated value to None")?; + new_items.insert(key.clone(), val); + } + return Ok(Some(BamlValueWithMeta::Map(new_items, ()))); + } + Expr::ClassConstructor { + name, + fields, + spread, + meta, + } => { + let mut new_fields = BamlMap::new(); + for (key, value) in fields { + let val = Box::pin(eval_to_value(env, &value)) + .await? + .context("Evaluated value to None")?; + new_fields.insert(key.clone(), val); + } + let mut spread_fields = match spread { + Some(spread) => { + let res = Box::pin(eval_to_value(env, spread.as_ref())).await?; + match res { + Some(BamlValueWithMeta::Class(spread_class_name, spread_fields, _)) => { + if name != spread_class_name { + return Err(anyhow::anyhow!("Class constructor name mismatch")); + } + spread_fields.clone() + } + _ => { + return Err(anyhow::anyhow!("Spread is not a class")); + } + } + } + None => BamlMap::new(), + }; + + spread_fields.extend(new_fields); + let val = BamlValueWithMeta::Class(name.clone(), spread_fields, ()); + return Ok(Some(val)); + } + other => { + // let new_expr = step(env, &other).await?; + let new_expr = Box::pin(beta_reduce(env, &other)).await?; + + if new_expr.temporary_same_state(expr) { + return Err(anyhow::anyhow!("Failed to make progress.")); + } + current_expr = new_expr; + } + } + } + Err(anyhow::anyhow!("Max steps reached.")) +} + +#[cfg(test)] +mod tests { + use crate::internal_baml_diagnostics::Span; + use baml_types::{BamlMap, BamlValue}; + use futures::channel::mpsc; + use internal_baml_core::ir::repr::make_test_ir; + use internal_baml_core::ir::IRHelper; + + use super::*; + use crate::BamlRuntime; + + // Make a testing runtime. It assumes the presence of + // OPENAI_API_KEY environment variable. 
+ fn runtime(content: &str) -> BamlRuntime { + let openai_api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set."); + BamlRuntime::from_file_content( + ".", + &HashMap::from([("main.baml", content)]), + HashMap::from([("OPENAI_API_KEY", openai_api_key.as_str())]), + ) + .unwrap() + } + + // #[tokio::test] // Uncomment to run. + async fn test_eval_expr() { + let rt = runtime( + r##" +function MakePoem(length: int) -> string { + client GPT4o + prompt #"Write a poem {{ length }} lines long."# +} + +function CombinePoems(poem1: string, poem2: string) -> string { + client GPT4o + prompt #"Combine the following two poems into one poem. + + Poem 1: + {{ poem1 }} + + Poem 2: + {{ poem2 }} + "# +} + +let poem = MakePoem(10); + +let another = { + let x = MakePoem(10); + let y = MakePoem(5); + CombinePoems(x,y) +}; + +fn Pipeline() -> string { + let x = MakePoem(6); + let y = MakePoem(6); + let a = MakePoem(6); + let b = MakePoem(6); + let xy = CombinePoems(x,y); + let ab = CombinePoems(a,b); + CombinePoems(xy, ab) +} + +fn Pyramid() -> string { + CombinePoems( CombinePoems( MakePoem(10), MakePoem(10)), MakePoem(10)) +} + +let default_person = Person { + name: "John Doe", + age: 20, + poem: "Never was there a man more plain." 
+}; + +class Person { + name string + age int + poem string +} + +fn MakePerson() -> Person { + Person { name: "Greg", poem: "Hello, world!", ..default_person } +} + +fn OuterPyramid() -> string { + CombinePoems(poem, another) +} + +fn ExprList() -> string[] { + [ MakePoem(10), MakePoem(2) ] +} + +test TestPipeline() { + functions [Pipeline] + args { } +} + +test TestPyramid() { + functions [Pyramid] + args { } +} + +test OuterPyramid() { + functions [OuterPyramid] + args { } +} + +client GPT4o { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} + +test TestMakePoem() { + functions [MakePoem] + args { length 4 } +} + +test TestExprList() { + functions [ExprList] + args { } +} + +test TestMakePerson() { + functions [MakePerson] + args { } +} + "##, + ); + // dbg!(&rt.inner.ir.find_function("OuterPyramid").unwrap().item); + let ctx = rt.create_ctx_manager(BamlValue::String("test".to_string()), None); + + let on_event = |res: FunctionResult| { + eprintln!("on_event: {:?}", res); + }; + let (res, _) = rt + // .run_test("Second", "TestSecond", &ctx, Some(on_event)) + .run_test("OuterPyramid", "OuterPyramid", &ctx, Some(on_event), None) + // .run_test("MakePerson", "TestMakePerson", &ctx, Some(on_event), None) + // .run_test("CompareHaikus", "Test", &ctx, Some(on_event)) + // .run_test("LlmParseInt", "TestParse", &ctx, Some(on_event)) + .await; + dbg!(res); + assert!(false); + } +} diff --git a/engine/baml-runtime/src/internal/prompt_renderer/mod.rs b/engine/baml-runtime/src/internal/prompt_renderer/mod.rs index 49c5d2743..63423e7a7 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/mod.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/mod.rs @@ -5,7 +5,7 @@ use jsonish::{BamlValueWithFlags, ResponseBamlValue}; use render_output_format::render_output_format; use anyhow::Result; -use baml_types::{BamlValue, FieldType, StreamingBehavior}; +use baml_types::{BamlValue, FieldType, StreamingBehavior, TypeValue}; use 
internal_baml_core::{ error_unsupported, ir::{ @@ -25,10 +25,10 @@ use super::llm_client::parsed_value_to_response; #[derive(Debug)] pub struct PromptRenderer { - function_name: String, - client_spec: ClientSpec, - output_defs: OutputFormatContent, - output_type: FieldType, + pub function_name: String, + pub client_spec: ClientSpec, + pub output_defs: OutputFormatContent, + pub output_type: FieldType, } impl PromptRenderer { @@ -53,6 +53,18 @@ impl PromptRenderer { }) } + /// A temporary function used to generate a fake prompt renderer, for cases + /// when we call BamlRuntime's `call` API with Expression fns, which + /// don't have a prompt. + pub fn mk_fake() -> PromptRenderer { + PromptRenderer { + function_name: "fake".into(), + client_spec: ClientSpec::Named("fake".into()), + output_defs: OutputFormatContent::mk_fake(), + output_type: FieldType::Primitive(TypeValue::String), + } + } + pub fn client_spec(&self) -> &ClientSpec { &self.client_spec } diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 7e2d19341..eb1ce8eda 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -428,6 +428,7 @@ fn relevant_data_models<'a>( } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _) => {} + (FieldType::Arrow(_), _) => {} (FieldType::WithMetadata { .. 
}, _) => { unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") } diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index 410e2f6f3..3fd0ba8f0 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -9,6 +9,7 @@ pub(crate) mod internal; pub mod cli; pub mod client_registry; pub mod errors; +pub mod eval_expr; pub mod request; mod runtime; pub mod runtime_interface; @@ -27,16 +28,26 @@ use std::sync::Arc; use anyhow::Context; use anyhow::Result; - +use futures::channel::mpsc; +use internal_baml_core::ast::Span; +use internal_baml_core::ir::repr::initial_context; +use jsonish::ResponseValueMeta; +use tokio::sync::Mutex; + +use crate::internal::llm_client::LLMCompleteResponse; +use baml_types::expr::{Expr, ExprMetadata}; use baml_types::tracing::events::FunctionId; use baml_types::tracing::events::HTTPBody; use baml_types::tracing::events::HTTPRequest; use baml_types::tracing::events::HttpRequestId; use baml_types::BamlMap; use baml_types::BamlValue; +use baml_types::BamlValueWithMeta; +use baml_types::Completion; use baml_types::Constraint; use cfg_if::cfg_if; use client_registry::ClientRegistry; +use eval_expr::EvalEnv; use futures::future::join; use futures::future::join_all; use indexmap::IndexMap; @@ -52,6 +63,7 @@ use internal_baml_core::configuration::CodegenGenerator; use internal_baml_core::configuration::Generator; use internal_baml_core::configuration::GeneratorOutputType; use internal_baml_core::ir::FunctionWalker; +use internal_baml_core::ir::IRHelperExtended; use internal_llm_client::AllowedRoleMetadata; use internal_llm_client::ClientSpec; use jsonish::ResponseBamlValue; @@ -62,7 +74,9 @@ use serde_json::json; use std::sync::OnceLock; use tracingv2::storage::storage::Collector; use tracingv2::storage::storage::BAML_TRACER; +use web_time::SystemTime; +use crate::internal::llm_client::LLMCompleteResponseMetadata; #[cfg(not(target_arch = 
"wasm32"))] pub use cli::RuntimeCliDefaults; pub use runtime_context::{ @@ -74,6 +88,7 @@ use runtime_interface::RuntimeInterface; use tracing::{BamlTracer, TracingSpan}; use type_builder::TypeBuilder; pub use types::*; +use web_time::Duration; #[cfg(feature = "internal")] pub use internal_baml_jinja::{ChatMessagePart, RenderedPrompt}; @@ -87,7 +102,10 @@ pub(crate) use runtime_interface::InternalRuntimeInterface; pub use internal_baml_core::internal_baml_diagnostics; pub use internal_baml_core::internal_baml_diagnostics::Diagnostics as DiagnosticsError; -pub use internal_baml_core::ir::{scope_diagnostics, FieldType, IRHelper, TypeValue}; +pub use internal_baml_core::internal_baml_diagnostics::SerializedSpan; +pub use internal_baml_core::ir::{ + ir_helpers::infer_type, scope_diagnostics, FieldType, IRHelper, TypeValue, +}; use crate::internal::llm_client::LLMResponse; use crate::test_constraints::{evaluate_test_constraints, TestConstraintsResult}; @@ -268,7 +286,9 @@ impl BamlRuntime { .get_test_params(function_name, test_name, ctx, strict)?; let constraints = self .inner - .get_test_constraints(function_name, test_name, ctx)?; + .get_test_constraints(function_name, test_name, ctx) + .unwrap_or(vec![]); // TODO: Fix this. 
+ // .get_test_constraints(function_name, test_name, ctx)?; Ok((params, constraints)) } @@ -283,12 +303,13 @@ impl BamlRuntime { .get_test_params(function_name, test_name, ctx, strict) } - pub async fn run_test( + pub async fn run_test_with_expr_events( &self, function_name: &str, test_name: &str, ctx: &RuntimeContextManager, on_event: Option, + expr_tx: Option>>, collector: Option>, ) -> (Result, Option) where @@ -296,10 +317,42 @@ impl BamlRuntime { { let span = self.tracer.start_span(test_name, ctx, &Default::default()); - let type_builder = self - .inner - .get_test_type_builder(function_name, test_name, ctx) - .unwrap(); + let expr_fn = self.inner.ir().find_expr_fn(function_name); + let is_expr_fn = expr_fn.is_ok(); + + if is_expr_fn { + // let type_builder = self + // .inner + // .get_test_type_builder(function_name, test_name, ctx) + // .ok_or(None); + let rctx = ctx + .create_ctx(None, None, span.clone().map(|s| s.span_id)) + .unwrap(); + let (params, _constraints) = self + .get_test_params_and_constraints(function_name, test_name, &rctx, true) + .unwrap(); + + // Call the runtime synchronously. + let (response_res, span_uuid) = self + .call_function_with_expr_events( + function_name.into(), + ¶ms, + &ctx, + None, // TODO: Test with TypeBuilder. + None, // TODO: Create callback. + None, // TODO: Use Collectors? 
+ expr_tx, + ) + .await; + + log::info!("** response_res: {:#?}", response_res); + let test_response = TestResponse { + function_response: response_res.unwrap(), + function_span: span_uuid, + constraints_result: TestConstraintsResult::empty(), + }; + return (Ok(test_response), None); + } if let Some(span) = span.clone() { if let Some(collector) = collector { @@ -308,6 +361,11 @@ impl BamlRuntime { } let run_to_response = || async { + let type_builder = self + .inner + .get_test_type_builder(function_name, test_name, ctx) + .unwrap(); + let rctx = ctx.create_ctx(type_builder.as_ref(), None, span.clone().map(|s| s.span_id))?; let (params, constraints) = @@ -350,7 +408,7 @@ impl BamlRuntime { evaluate_test_constraints( ¶ms, &value_with_constraints, - complete_resp, + &complete_resp, constraints, ) } @@ -384,6 +442,30 @@ impl BamlRuntime { (response, target_id) } + pub async fn run_test( + &self, + function_name: &str, + test_name: &str, + ctx: &RuntimeContextManager, + on_event: Option, + collector: Option>, + ) -> (Result, Option) + where + F: Fn(FunctionResult), + { + let res = self + .run_test_with_expr_events::( + function_name, + test_name, + ctx, + on_event, + None, + collector, + ) + .await; + res + } + #[cfg(not(target_arch = "wasm32"))] pub fn call_function_sync( &self, @@ -406,6 +488,22 @@ impl BamlRuntime { tb: Option<&TypeBuilder>, cb: Option<&ClientRegistry>, collectors: Option>>, + ) -> (Result, Option) { + let res = self + .call_function_with_expr_events(function_name, params, ctx, tb, cb, collectors, None) + .await; + res + } + + pub async fn call_function_with_expr_events( + &self, + function_name: String, + params: &BamlMap, + ctx: &RuntimeContextManager, + tb: Option<&TypeBuilder>, + cb: Option<&ClientRegistry>, + collectors: Option>>, + expr_tx: Option>>, ) -> (Result, Option) { log::trace!("Calling function: {}", function_name); let span = self.tracer.start_span(&function_name, ctx, params); @@ -418,11 +516,114 @@ impl BamlRuntime { } } + let 
fake_syntax_span = Span::fake(); let response = match ctx.create_ctx(tb, cb, span.clone().map(|s| s.span_id)) { Ok(rctx) => { - self.inner - .call_function_impl(function_name.clone(), params, rctx) - .await + let is_expr_fn = self + .inner + .ir() + .expr_fns + .iter() + .find(|f| f.elem.name == function_name) + .is_some(); + if !is_expr_fn { + self.inner + .call_function_impl(function_name, params, rctx) + .await + } else { + // TODO: This code path is ugly. Calling a function heavily assumes that the + // function is an LLM function. Find a way to make function-calling API more + // hospitable to Expression Fns, or create new APIs for calling Expr Fns. + let expr_fn = &self + .inner + .ir() + .expr_fns + .iter() + .find(|f| f.elem.name == function_name) + .expect("We checked earlier that this function is an expr_fn") + .elem; + let fn_expr = expr_fn.expr.clone(); + let context = initial_context(&self.inner.ir()); + let env = EvalEnv { + context, + runtime: self, + expr_tx: expr_tx.clone(), + }; + let param_baml_values = params + .iter() + .map(|(k, v)| { + let arg_type = infer_type(v); + let baml_value_with_meta: BamlValueWithMeta = + match arg_type { + None => Ok::<_, anyhow::Error>( + BamlValueWithMeta::with_const_meta(v, (Span::fake(), None)), + ), + Some(arg_type) => { + let value_unit_meta: BamlValueWithMeta<()> = + BamlValueWithMeta::with_const_meta(v, ()); + let baml_value = self + .inner + .ir() + .distribute_type_with_meta(value_unit_meta, arg_type)?; + let baml_value_with_meta = + baml_value.map_meta_owned(|(_, field_type)| { + (Span::fake(), Some(field_type)) + }); + + Ok(baml_value_with_meta) + } + }?; + Ok(Expr::Atom(baml_value_with_meta)) + }) + .collect::>() + .unwrap_or(vec![]); //TODO: Is it acceptable to swallow errors here? 
+ + let params_expr: Expr = + Expr::ArgsTuple(param_baml_values, (fake_syntax_span.clone(), None)); + let result_type = expr_fn.output.clone(); + let fn_call_expr = Expr::App( + Arc::new(fn_expr), + Arc::new(params_expr), + (fake_syntax_span.clone(), Some(result_type.clone())), + ); + let res = eval_expr::eval_to_value(&env, &fn_call_expr) + .await + .map(|v| { + v.map(|v| { + ResponseBamlValue(v.map_meta(|_| { + ResponseValueMeta( + vec![], + vec![], + Completion::default(), + result_type.clone(), + ) + })) + }) + }) + .transpose(); + + let llm_response = LLMResponse::Success(LLMCompleteResponse { + client: "openai".to_string(), + model: "gpt-3.5-turbo".to_string(), + prompt: RenderedPrompt::Completion("Sample raw response".to_string()), + request_options: BamlMap::new(), + content: "Sample raw response".to_string(), + start_time: SystemTime::now(), + latency: Duration::from_millis(2025), + metadata: LLMCompleteResponseMetadata { + baml_is_complete: true, + finish_reason: Some("stop".to_string()), + prompt_tokens: Some(50), + output_tokens: Some(50), + total_tokens: Some(100), + }, + }); + Ok(FunctionResult::new( + OrchestrationScope { scope: vec![] }, + llm_response, + res, + )) + } } Err(e) => Err(e), }; diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index 47267fd69..b1c75555f 100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -33,10 +33,11 @@ use baml_types::tracing::events::{ }; use baml_types::{BamlMap, BamlValue, Constraint, EvaluationContext}; -use internal_baml_core::ir::repr::TypeBuilderEntry; +use internal_baml_core::ir::repr::{Node, TypeBuilderEntry}; +use internal_baml_core::ir::TestCase; use internal_baml_core::{ internal_baml_diagnostics::SourceFile, - ir::{repr::IntermediateRepr, ArgCoercer, FunctionWalker, IRHelper}, + ir::{repr::IntermediateRepr, ArgCoercer, ExprFunctionWalker, FunctionWalker, 
IRHelper}, validate, }; use internal_baml_jinja::RenderedPrompt; @@ -152,8 +153,9 @@ impl InternalRuntimeInterface for InternalBamlRuntime { node_index: Option, ) -> Result<(RenderedPrompt, OrchestrationScope, AllowedRoleMetadata)> { let func = self.get_function(function_name, ctx)?; + let function_params = func.inputs(); let baml_args = self.ir().check_function_params( - &func, + &function_params, params, ArgCoercer { span_path: None, @@ -229,6 +231,15 @@ impl InternalRuntimeInterface for InternalBamlRuntime { Ok(walker) } + fn get_expr_function<'ir>( + &'ir self, + function_name: &str, + _ctx: &RuntimeContext, + ) -> Result> { + let walker = self.ir().find_expr_fn(function_name)?; + Ok(walker) + } + fn ir(&self) -> &IntermediateRepr { use std::ops::Deref; self.ir.deref() @@ -241,13 +252,26 @@ impl InternalRuntimeInterface for InternalBamlRuntime { ctx: &RuntimeContext, strict: bool, ) -> Result> { - let func = self.get_function(function_name, ctx)?; - let test = self.ir().find_test(&func, test_name)?; + let maybe_test_and_params = self.get_function(function_name, ctx).and_then(|func| { + let test = self.ir().find_test(&func, test_name)?; + let test_case_params = test.test_case_params(&ctx.eval_ctx(strict))?; + let inputs = func.inputs().clone(); + Ok((test_case_params, inputs)) + }); + let maybe_expr_test_and_params = + self.get_expr_function(function_name, ctx).and_then(|func| { + let test = self.ir().find_expr_fn_test(&func, test_name)?; + let test_case_params = test.test_case_params(&ctx.eval_ctx(strict))?; + let inputs = func.inputs().clone(); + Ok((test_case_params, inputs)) + }); + + let maybe_params = maybe_test_and_params.or(maybe_expr_test_and_params); let eval_ctx = ctx.eval_ctx(strict); - match test.test_case_params(&eval_ctx) { - Ok(params) => { + match maybe_params { + Ok((params, function_params)) => { // Collect all errors and return them as a single error. 
let mut errors = Vec::new(); let params = params @@ -269,10 +293,10 @@ impl InternalRuntimeInterface for InternalBamlRuntime { } let baml_args = self.ir().check_function_params( - &func, + &function_params, ¶ms, ArgCoercer { - span_path: test.span().map(|s| s.file.path_buf().clone()), + span_path: None, allow_implicit_cast_to_string: true, }, )?; @@ -404,23 +428,23 @@ impl RuntimeInterface for InternalBamlRuntime { ); } - let future = async { - let func = match self.get_function(&function_name, &ctx) { - Ok(func) => func, - Err(e) => { - return Ok(FunctionResult::new( - OrchestrationScope::default(), - LLMResponse::UserFailure(format!( - "BAML function {function_name} does not exist in baml_src/ (did you typo it?): {:?}", - e - )), - None, - )) - } - }; + let func = match self.get_function(&function_name, &ctx) { + Ok(func) => func, + Err(e) => { + return Ok(FunctionResult::new( + OrchestrationScope::default(), + LLMResponse::UserFailure(format!( + "BAML function {function_name} does not exist in baml_src/ (did you typo it?): {:?}", + e + )), + None, + )) + } + }; + let future = async { let baml_args = self.ir().check_function_params( - &func, + &func.inputs(), params, ArgCoercer { span_path: None, @@ -440,6 +464,34 @@ impl RuntimeInterface for InternalBamlRuntime { FunctionResult::new_chain(history) }; + let baml_args = self.ir().check_function_params( + func.inputs(), + params, + ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: false, + }, + )?; + let baml_args = match self.ir().check_function_params( + &func.inputs(), + ¶ms, + ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: false, + }, + ) { + Ok(args) => args, + Err(e) => { + return Ok(FunctionResult::new( + OrchestrationScope::default(), + LLMResponse::UserFailure(format!( + "Failed while validating args for {function_name}: {:?}", + e + )), + None, + )) + } + }; let result = future.await; @@ -480,33 +532,63 @@ impl RuntimeInterface for InternalBamlRuntime { 
#[cfg(not(target_arch = "wasm32"))] tokio_runtime: Arc, collectors: Vec>, ) -> Result { - let func = self.get_function(&function_name, &ctx)?; - let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?; - let orchestrator = self.orchestration_graph(renderer.client_spec(), &ctx)?; - let Some(baml_args) = self - .ir - .check_function_params( - &func, - params, - ArgCoercer { - span_path: None, - allow_implicit_cast_to_string: false, - }, - )? - .as_map_owned() - else { - anyhow::bail!("Expected parameters to be a map for: {}", function_name); - }; - Ok(FunctionResultStream { - function_name, - ir: self.ir.clone(), - params: baml_args, - orchestrator, - tracer, - renderer, - #[cfg(not(target_arch = "wasm32"))] - tokio_runtime, - collectors, - }) + let is_expr_fn = self.get_expr_function(&function_name, &ctx).is_ok(); + if is_expr_fn { + let func = self.get_expr_function(&function_name, &ctx)?; + let renderer = PromptRenderer::mk_fake(); + let orchestrator = vec![]; + let baml_args = self + .ir + .check_function_params( + &func.inputs(), + params, + ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: false, + }, + )? + .as_map_owned() + .ok_or(anyhow::anyhow!("Failed to check function params."))?; + Ok(FunctionResultStream { + function_name, + ir: self.ir.clone(), + params: baml_args, + orchestrator, + tracer, + renderer, + #[cfg(not(target_arch = "wasm32"))] + tokio_runtime, + collectors, + }) + } else { + let func = self.get_function(&function_name, &ctx)?; + let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?; + let orchestrator = self.orchestration_graph(renderer.client_spec(), &ctx)?; + let Some(baml_args) = self + .ir + .check_function_params( + &func.inputs(), + params, + ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: false, + }, + )? 
+ .as_map_owned() + else { + anyhow::bail!("Expected parameters to be a map for: {}", function_name); + }; + Ok(FunctionResultStream { + function_name, + ir: self.ir.clone(), + params: baml_args, + orchestrator, + tracer, + renderer, + #[cfg(not(target_arch = "wasm32"))] + tokio_runtime, + collectors, + }) + } } } diff --git a/engine/baml-runtime/src/runtime_interface.rs b/engine/baml-runtime/src/runtime_interface.rs index 9ec8cd1f0..61560f24c 100644 --- a/engine/baml-runtime/src/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime_interface.rs @@ -1,6 +1,7 @@ use anyhow::Result; use baml_types::{BamlMap, BamlValue, Constraint}; use internal_baml_core::internal_baml_diagnostics::Diagnostics; +use internal_baml_core::ir::ExprFunctionWalker; use internal_baml_core::ir::{repr::IntermediateRepr, FunctionWalker}; use internal_baml_jinja::RenderedPrompt; use internal_llm_client::{AllowedRoleMetadata, ClientSpec}; @@ -133,6 +134,11 @@ pub trait InternalRuntimeInterface { function_name: &str, ctx: &RuntimeContext, ) -> Result>; + fn get_expr_function<'ir>( + &'ir self, + function_name: &str, + ctx: &RuntimeContext, + ) -> Result>; #[allow(async_fn_in_trait)] async fn render_prompt( diff --git a/engine/baml-runtime/src/test_executor/mod.rs b/engine/baml-runtime/src/test_executor/mod.rs index 4078fd10e..ac869718d 100644 --- a/engine/baml-runtime/src/test_executor/mod.rs +++ b/engine/baml-runtime/src/test_executor/mod.rs @@ -9,6 +9,7 @@ pub use test_execution_args::TestFilter; use std::{ collections::{BTreeMap, BTreeSet}, ops::Deref, + sync::Arc, time::Instant, }; @@ -216,8 +217,8 @@ impl TestExecutor for BamlRuntime { let tx = tx.clone(); // Clone the Arc pointer for self here. 
let runtime = self.clone(); - let function_name = fn_name.clone(); - let test_name = tt_name.clone(); + let function_name = fn_name.to_string(); + let test_name = tt_name.to_string(); let fut = tokio::spawn(async move { let _permit = semaphore.acquire().await.unwrap(); let ctx_manager = runtime.create_ctx_manager( @@ -232,18 +233,12 @@ impl TestExecutor for BamlRuntime { TestExecutionStatus::Running, )); let (result, _) = runtime - .run_test( - function_name.as_str(), - test_name.as_str(), - &ctx_manager, - Some(|_| {}), - None, - ) + .run_test(&function_name, &test_name, &ctx_manager, Some(|_| {}), None) .await; let duration = start_instant.elapsed(); let _ = tx.send(( - function_name.clone(), - test_name.clone(), + function_name, + test_name, TestExecutionStatus::Finished(result, duration), )); }); diff --git a/engine/baml-schema-wasm/Cargo.toml b/engine/baml-schema-wasm/Cargo.toml index 0727e2c21..0ebc5adcc 100644 --- a/engine/baml-schema-wasm/Cargo.toml +++ b/engine/baml-schema-wasm/Cargo.toml @@ -19,7 +19,6 @@ unused_variables = "deny" [dependencies] reqwest.workspace = true anyhow.workspace = true -futures.workspace = true baml-runtime = { path = "../baml-runtime", features = [ "internal", ], default-features = false } @@ -27,6 +26,7 @@ baml-types = { path = "../baml-lib/baml-types" } cfg-if.workspace = true console_error_panic_hook = "0.1.7" console_log = { version = "1", features = ["color"] } +futures.workspace = true getrandom = { version = "0.2.15", features = ["js"] } indexmap.workspace = true internal-baml-codegen.workspace = true diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index 57abd9c12..e15eb79ba 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -9,8 +9,10 @@ use baml_runtime::internal::llm_client::orchestrator::OrchestrationScope; use baml_runtime::internal::llm_client::orchestrator::OrchestratorNode; use 
baml_runtime::internal::prompt_renderer::PromptRenderer; use baml_runtime::BamlSrcReader; +use baml_runtime::FunctionResult; use baml_runtime::InternalRuntimeInterface; use baml_runtime::RenderCurlSettings; +use baml_runtime::SerializedSpan; use baml_runtime::{ internal::llm_client::LLMResponse, BamlRuntime, DiagnosticsError, IRHelper, RenderedPrompt, }; @@ -19,7 +21,11 @@ use baml_types::{BamlMediaType, BamlValue, GeneratorOutputType, TypeValue}; use indexmap::IndexMap; use internal_baml_codegen::version_check::GeneratorType; use internal_baml_codegen::version_check::{check_version, VersionCheckMode}; +use internal_baml_core::ir::repr::Walker; use internal_llm_client::AllowedRoleMetadata; + +use futures::channel::mpsc; +use futures::StreamExt; use itertools::join; use js_sys::Promise; use js_sys::Uint8Array; @@ -909,6 +915,7 @@ fn get_dummy_value( Some(format!("({},)", dummy)) } baml_runtime::FieldType::Optional(_) => None, + baml_runtime::FieldType::Arrow(_) => None, baml_runtime::FieldType::WithMetadata { base, .. 
} => { get_dummy_value(indent, allow_multiline, base) } @@ -976,6 +983,17 @@ impl WasmRuntime { .internal() .ir() .walk_functions() + .chain( + self.runtime + .internal() + .ir() + .expr_fns_as_functions() + .iter() + .map(|f| Walker { + ir: &self.runtime.internal().ir(), + item: f, + }), + ) .map(|f| { let snippet = format!( r#"test TestName {{ @@ -1805,6 +1823,84 @@ impl WasmFunction { .map_err(|e| wasm_bindgen::JsError::new(format!("{e:?}").as_str())) } + #[wasm_bindgen] + pub async fn run_test_with_expr_events( + &self, + rt: &mut WasmRuntime, + test_name: String, + env_vars: JsValue, + on_partial_response: js_sys::Function, + get_baml_src_cb: js_sys::Function, + load_aws_creds_cb: js_sys::Function, + on_expr_event: js_sys::Function, + ) -> Result { + log::info!("TEST LOGGING"); + let rt = &rt.runtime; + let function_name = self.name.clone(); + + let function_name_for_test_pair = function_name.clone(); + let test_name_for_test_pair = test_name.clone(); + + // Create the closure to handle partial responses: + let cb = Box::new(move |r: FunctionResult| { + let this = JsValue::NULL; + let res = WasmFunctionResponse { + function_response: r, + func_test_pair: WasmFunctionTestPair { + function_name: function_name_for_test_pair.clone(), + test_name: test_name_for_test_pair.clone(), + }, + } + .into(); + on_partial_response.call1(&this, &res).unwrap(); + }); + + // Create the channel for expression events + let (tx, mut rx) = mpsc::unbounded::>(); + + // Spawn a task to handle expression events + let on_expr_event_clone = on_expr_event.clone(); + wasm_bindgen_futures::spawn_local(async move { + while let Some(spans) = rx.next().await { + let this = JsValue::NULL; + match serde_wasm_bindgen::to_value(&spans) { + Ok(res) => { + on_expr_event_clone.call1(&this, &res).expect("TODO"); + } + Err(e) => { + log::error!("Error serializing spans: {e}"); + } + } + } + }); + + // Create your evaluation context, etc. 
+ let ctx = rt.create_ctx_manager_with_env( + BamlValue::String("wasm".to_string()), + serde_wasm_bindgen::from_value::>(env_vars) + .map_err(|e| JsValue::from_str(&format!("Failed to parse env_vars: {:?}", e)))?, + js_fn_to_baml_src_reader(get_baml_src_cb), + js_fn_to_aws_cred_provider(load_aws_creds_cb), + ); + + // Pass the sender to run_test_with_expr_events + let (test_response, span) = rt + .run_test_with_expr_events(&function_name, &test_name, &ctx, Some(cb), Some(tx), None) + .await; + + log::info!("test_response: {:#?}", test_response); + + Ok(WasmTestResponse { + test_response, + span, + tracing_project_id: rt.env_vars().get("BOUNDARY_PROJECT_ID").cloned(), + func_test_pair: WasmFunctionTestPair { + function_name, + test_name, + }, + }) + } + #[wasm_bindgen] pub async fn run_test( &self, diff --git a/engine/language_client_cffi/src/ctypes.rs b/engine/language_client_cffi/src/ctypes.rs index 5251579b7..1e509523d 100644 --- a/engine/language_client_cffi/src/ctypes.rs +++ b/engine/language_client_cffi/src/ctypes.rs @@ -204,13 +204,13 @@ impl From> for BamlValue { } impl From> for BamlValue { - fn from(value: CFFIValueChecked) -> Self { + fn from(_value: CFFIValueChecked) -> Self { unimplemented!("CFFIValueChecked is not supported"); } } impl From> for BamlValue { - fn from(value: CFFIValueStreamingState) -> Self { + fn from(_value: CFFIValueStreamingState) -> Self { unimplemented!("CFFIValueStreamingState is not supported"); } } @@ -605,12 +605,11 @@ where type_alias.as_union_value(), ) } - baml_types::FieldType::Tuple(field_types) => unimplemented!("Tuple is not supported"), - baml_types::FieldType::WithMetadata { - base, - constraints, - streaming_behavior, - } => unimplemented!("WithMetadata is not supported"), + baml_types::FieldType::Tuple(_field_types) => unimplemented!("Tuple is not supported"), + baml_types::FieldType::WithMetadata { .. 
} => { + unimplemented!("WithMetadata is not supported") + } + baml_types::FieldType::Arrow(_) => unimplemented!("Functions are not supported."), }; CFFIFieldTypeHolder::create( diff --git a/engine/language_client_codegen/src/go/generate_types.rs b/engine/language_client_codegen/src/go/generate_types.rs index 4662bf9ef..c0a1ffd6e 100644 --- a/engine/language_client_codegen/src/go/generate_types.rs +++ b/engine/language_client_codegen/src/go/generate_types.rs @@ -502,6 +502,7 @@ fn has_none_default(ir: &IntermediateRepr, field_type: &FieldType) -> bool { FieldType::WithMetadata { .. } => { unreachable!("FieldType::WithMetadata is always consumed by distribute_metadata") } + FieldType::Arrow(_) => panic!("Generation is not supported with expr fns"), } } @@ -595,6 +596,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { } None => base.to_type_ref_2(ir, use_module_prefix).name, }, + FieldType::Arrow(_) => panic!("Generation is not supported with expr fns"), } } @@ -672,6 +674,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { FieldType::WithMetadata { .. 
} => { unreachable!("distribute_metadata makes this branch unreachable.") } + FieldType::Arrow(_) => panic!("Generation is not supported with expr fns"), }; let base_type_ref = if is_partial_type { base_rep diff --git a/engine/language_client_codegen/src/go/mod.rs b/engine/language_client_codegen/src/go/mod.rs index 7b2984f92..71cc46132 100644 --- a/engine/language_client_codegen/src/go/mod.rs +++ b/engine/language_client_codegen/src/go/mod.rs @@ -166,6 +166,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } None => base.to_type_ref(ir, _with_checked), }, + FieldType::Arrow(_) => panic!("Generation is not supported with expr fns"), } } @@ -220,6 +221,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } None => base.to_partial_type_ref(ir, with_checked), }, + FieldType::Arrow(_) => panic!("Generation is not supported with expr fns"), } } } diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 5614b8ca8..a83f13f59 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -682,6 +682,7 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { } None => base.to_type_spec(_ir)?, }, + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), }) } } diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index f4dc7004a..aee68640e 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ b/engine/language_client_codegen/src/python/generate_types.rs @@ -271,6 +271,7 @@ fn has_none_default(ir: &IntermediateRepr, field_type: &FieldType) -> bool { FieldType::WithMetadata { .. 
} => { unreachable!("FieldType::WithMetadata is always consumed by distribute_metadata") } + FieldType::Arrow(_) => false, } } @@ -361,6 +362,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { } None => base.to_type_ref(ir, use_module_prefix), }, + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), } } @@ -458,6 +460,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { FieldType::WithMetadata { .. } => { unreachable!("distribute_metadata makes this branch unreachable.") } + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), }; let base_type_ref = if is_partial_type { base_rep diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index f5d29f890..aeda35310 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -323,6 +323,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } None => base.to_type_ref(ir, _with_checked), }, + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), } } @@ -378,6 +379,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } None => base.to_partial_type_ref(ir, with_checked), }, + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), } } } @@ -398,6 +400,7 @@ fn default_value_for_parameter_type(field_type: &FieldType) -> Option<&'static s FieldType::Primitive(_) => None, FieldType::Union(xs) => None, FieldType::WithMetadata { base, .. 
} => default_value_for_parameter_type(base), + FieldType::Arrow(_) => None, } } diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index 89dc7756f..f1510e8e7 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -57,6 +57,7 @@ impl ToRuby for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("T.nilable({})", inner.to_ruby()), + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(_) => { let base_type_ref = base.to_ruby(); diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index e865042b4..3dd8de621 100644 --- a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -253,6 +253,7 @@ impl ToTypeReferenceInTypeDefinition<'_> for FieldType { } } FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false), + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { let base_type_ref = base.to_partial_type_ref(ir, false); diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index 53f97c1fe..58ad0bd6a 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -594,6 +594,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { FieldType::WithMetadata { .. 
} => { unreachable!("distribute_metadata makes this field unreachable.") } + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), }; let base_type_ref = if is_partial_type { base_rep @@ -674,6 +675,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { FieldType::Optional(inner) => { format!("{} | null", inner.to_type_ref(ir, use_module_prefix)) } + FieldType::Arrow(_) => todo!("Arrow types should not be used in generated type definitions"), FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { let base_type_ref = base.to_type_ref(ir, use_module_prefix); diff --git a/flake.lock b/flake.lock index 2dab9fe68..457574a3e 100644 --- a/flake.lock +++ b/flake.lock @@ -8,11 +8,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1734935689, - "narHash": "sha256-yl/iko/0pvRN3PF6Z4FjQeb6AuGiavMENEisQWJ78h0=", + "lastModified": 1741070164, + "narHash": "sha256-zgHp8rxIbJFeF2DuEMAhKqfdUnclcjaVfdhLNgX5nUM=", "owner": "nix-community", "repo": "fenix", - "rev": "30616281e9bfe0883acb3369f2b89aad6850706f", + "rev": "c36306dbcc4ad8128e659ea072ad35e02936b03e", "type": "github" }, "original": { @@ -57,16 +57,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1734649271, - "narHash": "sha256-4EVBRhOjMDuGtMaofAIqzJbg4Ql7Ai0PSeuVZTHjyKQ=", + "lastModified": 1735563628, + "narHash": "sha256-OnSAY7XDSx7CtDoqNh8jwVwh4xNL/2HaJxGjryLWzX8=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "d70bd19e0a38ad4790d3913bf08fcbfc9eeca507", + "rev": "b134951a4c9f3c995fd7be05f3243f8ecd65d798", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-unstable", + "ref": "nixos-24.05", "repo": "nixpkgs", "type": "github" } @@ -82,11 +82,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1734874959, - "narHash": "sha256-NlsVD/fI32wsHFua9Xvc7IFHCUpQIOs6D6RS/3AhMT8=", + "lastModified": 1741011961, + "narHash": "sha256-bssSxw3Z9CUNB9+f3EHAX/2urT15e12Jy6YU8tHyWkk=", 
"owner": "rust-lang", "repo": "rust-analyzer", - "rev": "fa4a40bbe867ed54f5a7c905b591fd7d60ba35eb", + "rev": "02862f5d52c30b476a5dca909a17aa4386d1fdc5", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 3edc91dc2..636914454 100644 --- a/flake.nix +++ b/flake.nix @@ -1,6 +1,6 @@ { inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.05"; flake-utils.url = "github:numtide/flake-utils"; fenix = { url = "github:nix-community/fenix"; @@ -14,54 +14,21 @@ outputs = { self, nixpkgs, flake-utils, fenix, ... }: - let - - # buildTargets = { - # "x86_64-linux" = { - # crossSystemConfig = "x86_64-unknown-linux-musl"; - # rustTarget = "x86_64-unknown-linux-musl"; - # }; - # "aarch64-linux" = { - # crossSystemConfig = "x86_64-unknown-linux-musl"; - # rustTarget = "x86_64-unknown-linux-musl"; - # }; - # "aarch64-darwin" = {}; - # "wasm" = { - # crossSystemConfig = "wasm32-unknown-unknown"; - # rustTarget = "wasm32-unknown-unknown"; - # makeBuildPackageAttrs = pkgsCross: { - # OPENSSL_STATIC = null; - # OPENSSL_LIB_DIR = null; - # OPENSSL_INCLUDE_DIR = null; - # }; - # }; - # }; - - # mkPkgs = buildSystem: targetSystem: import nixpkgs ({ - # system = buildSystem; - # } // (if targetSystem == null then {} else { - # crossSystemcnofig = buildTargets.${targetSystem}.crossSystemConfig; - # })); - - # eachSystem = supportedSystems: callback: builtins.fold' - # (overall: system: overall // { ${system} = callback system; }) - # {} - # supportedSystems; - - in flake-utils.lib.eachDefaultSystem (system: let pkgs = nixpkgs.legacyPackages.${system}; - clang = pkgs.llvmPackages_19.clang; + clang = pkgs.llvmPackages_17.clang; pythonEnv = pkgs.python3.withPackages (ps: []); toolchain = with fenix.packages.${system}; combine [ minimal.cargo minimal.rustc minimal.rust-std + complete.rustfmt targets.wasm32-unknown-unknown.latest.rust-std + targets.x86_64-unknown-linux-musl.latest.rust-std ]; version = 
(builtins.fromTOML (builtins.readFile ./engine/Cargo.toml)).workspace.package.version; @@ -77,19 +44,47 @@ inherit (fenix.packages.${system}.latest) rust-std; }; + # wasm-bindgen-cli = pkgs.rustPlatform.buildRustPackage rec { + # pname = "wasm-bindgen-cli"; + # version = "0.2.92"; + # src = pkgs.fetchFromGitHub { + # owner = "rustwasm"; + # repo = "wasm-bindgen"; + # rev = "${version}"; + # sha256 = "sha256-VMt+J5sazHPqmAdsoueS2WW6Pn1tvugaJaPnSJq9038="; + # }; + # cargoHash = "sha256-+iIHleftJ+Yl9QHEBVI91NOhBw9qtUZfgooHKoyY1w4="; + # buildInputs = with pkgs; [ openssl ]; + # nativeBuildInputs = with pkgs; [ pkg-config ]; + # cargoBuildFlags = ["--package wasm-bindgen-cli"]; + # }; + buildInputs = (with pkgs; [ + cmake git openssl pkg-config - lld_19 + lld_17 pythonEnv ruby + ruby.devEnv maturin - nodePackages.pnpm - nodePackages.nodejs + vsce # VSCode extension packaging tool toolchain + nodejs uv wasm-pack + pkgs.gcc + napi-rs-cli + wasm-bindgen-cli + + # For building the typescript client. + pixman + cairo + pango + libjpeg + giflib + librsvg ]) ++ (if pkgs.stdenv.isDarwin then appleDeps else []); nativeBuildInputs = [ pkgs.openssl @@ -97,32 +92,45 @@ pkgs.ruby pythonEnv pkgs.maturin + pkgs.perl + pkgs.lld_17 + pkgs.gcc ]; + in { packages.default = rustPlatform.buildRustPackage { + + # Disable tests in this build - FFI is a little tricky. + doCheck = false; + + # Temporary: do a debug build instead of a release build, to speed up the dev cycle. 
+ buildType = "debug"; + pname = "baml-cli"; version = version; - src = let - extraFiles = pkgs.copyPathToStore ./engine/baml-runtime/src/cli/initial_project/baml_src; - in pkgs.symlinkJoin { - name = "source"; - paths = [ ./engine extraFiles ]; - }; + src = ./engine; LIBCLANG_PATH = pkgs.libclang.lib + "/lib/"; BINDGEN_EXTRA_CLANG_ARGS = if pkgs.stdenv.isDarwin then "" # Rely on default includes provided by stdenv.cc + libclang else - "-isystem ${pkgs.llvmPackages_19.libclang.lib}/lib/clang/19/include -isystem ${pkgs.llvmPackages_19.libclang.lib}/include -isystem ${pkgs.glibc.dev}/include"; + "-isystem ${pkgs.llvmPackages_17.libclang.lib}/lib/clang/17/include -isystem ${pkgs.llvmPackages_17.libclang.lib}/include -isystem ${pkgs.glibc.dev}/include"; cargoLock = { lockFile = ./engine/Cargo.lock; outputHashes = { - "pyo3-asyncio-0.21.0" = "sha256-5ZLzWkxp3e2u0B4+/JJTwO9SYKhtmBpMBiyIsTCW5Zw="; - "serde_magnus-0.9.0" = "sha256-+iIHleftJ+Yl9QHEBVI91NOhBw9qtUZfgooHKoyY1w4="; }; }; # Add build-time environment variables - RUSTFLAGS = "-C target-feature=+crt-static --cfg tracing_unstable"; + RUSTFLAGS = if pkgs.stdenv.isDarwin + then + "--cfg tracing_unstable -C linker=lld" + else + "--cfg tracing_unstable -Zlinker-features=+lld -C linker=gcc"; + + OPENSSL_STATIC = "1"; + OPENSSL_DIR = "${pkgs.openssl.dev}"; + OPENSSL_LIB_DIR = "${pkgs.openssl.out}/lib"; + OPENSSL_INCLUDE_DIR = "${pkgs.openssl.dev}/include"; # Modify the test phase to only run library tests checkPhase = '' @@ -132,11 +140,23 @@ runHook postCheck ''; + postPatch = '' + # Disable baml syntax validation tests in build. They require too much + # file system access to run. 
+ cat > baml-lib/baml/build.rs << 'EOF' + fn main() { + println!("cargo:warning=Skipping baml syntax validation tests"); + } + EOF + ''; + inherit buildInputs; + inherit nativeBuildInputs; + PYTHON_SYS_EXECUTABLE="${pythonEnv}/bin/python3"; LD_LIBRARY_PATH="${pythonEnv}/lib"; PYTHONPATH="${pythonEnv}/${pythonEnv.sitePackages}"; - CC="${clang}/bin/clang"; + # CC="${clang}/bin/clang"; # Temporarily commented out for linux testing. }; devShell = pkgs.mkShell rec { @@ -146,7 +166,7 @@ BINDGEN_EXTRA_CLANG_ARGS = if pkgs.stdenv.isDarwin then "" # Rely on default includes provided by stdenv.cc + libclang else - "-isystem ${pkgs.llvmPackages_19.libclang.lib}/lib/clang/19/include -isystem ${pkgs.llvmPackages_19.libclang.lib}/include -isystem ${pkgs.glibc.dev}/include"; + "-isystem ${pkgs.llvmPackages_17.libclang.lib}/lib/clang/17/include -isystem ${pkgs.llvmPackages_17.libclang.lib}/include -isystem ${pkgs.glibc.dev}/include"; }; } ); diff --git a/integ-tests/typescript/package.json b/integ-tests/typescript/package.json index 5019c3d38..07f4a8bbc 100644 --- a/integ-tests/typescript/package.json +++ b/integ-tests/typescript/package.json @@ -8,7 +8,7 @@ "build:debug": "cd ../../engine/language_client_typescript && pnpm run build:debug && cd - && pnpm i", "build": "cd ../../engine/language_client_typescript && npm run build && cd - && pnpm i", "integ-tests:ci": "pnpm tsc && infisical run --env=test -- pnpm test -- --ci --silent false --testTimeout 30000 --verbose=false --reporters=jest-junit", - "integ-tests": "pnpm tsc && infisical run --env=test -- pnpm test -- --silent false --testTimeout 30000", + "integ-tests": "pnpm tsc && infisical run --env=test -- pnpm test -- --silent false --testTimeout 60000", "integ-tests:dotenv": "pnpm tsc && pnpm test -- --silent false --testTimeout 30000", "generate": "baml-cli generate --from ../baml_src", "memory-test": "BAML_LOG=info infisical run --env=test -- pnpm test -- --silent false --testTimeout 60000 -t 'memory'" diff --git 
a/integ-tests/typescript/tests/logger.test.ts b/integ-tests/typescript/tests/logger.test.ts index 10ba6d9b4..f0d5e11bd 100644 --- a/integ-tests/typescript/tests/logger.test.ts +++ b/integ-tests/typescript/tests/logger.test.ts @@ -45,35 +45,35 @@ describe('Logger tests', () => { expect(getLogLevel()).toBe('INFO') let { result, output } = await captureStdout(() => b.TestOllama("banks using the word 'fiscal'")) - expect(result.toLowerCase()).toContain('fiscal') + expect(result?.toLowerCase()).toContain('fiscal') expect(output).toBe('') // Test with log level "WARN" setLogLevel('WARN') expect(getLogLevel()).toBe('WARN') ;({ result, output } = await captureStdout(() => b.TestOllama("banks using the word 'fiscal'"))) - expect(result.toLowerCase()).toContain('fiscal') + expect(result?.toLowerCase()).toContain('fiscal') expect(output).toBe('') // Finally, reset to "INFO" and test again setLogLevel('INFO') expect(getLogLevel()).toBe('INFO') ;({ result, output } = await captureStdout(() => b.TestOllama("banks using the word 'fiscal'"))) - expect(result.toLowerCase()).toContain('fiscal') + expect(result?.toLowerCase()).toContain('fiscal') expect(output).toBe('') // Test with log level "OFF" setLogLevel('OFF') expect(getLogLevel()).toBe('OFF') ;({ result, output } = await captureStdout(() => b.TestOllama("banks using the word 'fiscal'"))) - expect(result.toLowerCase()).toContain('fiscal') + expect(result?.toLowerCase()).toContain('fiscal') expect(output).toBe('') // Finally, reset to "INFO" and test again setLogLevel('INFO') expect(getLogLevel()).toBe('INFO') ;({ result, output } = await captureStdout(() => b.TestOllama("banks using the word 'fiscal'"))) - expect(result.toLowerCase()).toContain('fiscal') + expect(result?.toLowerCase()).toContain('fiscal') expect(output).toBe('') }) }) diff --git a/typescript/fiddle-frontend/next.config.mjs b/typescript/fiddle-frontend/next.config.mjs index 0a36442f2..cb216f586 100644 --- a/typescript/fiddle-frontend/next.config.mjs +++ 
b/typescript/fiddle-frontend/next.config.mjs @@ -18,6 +18,12 @@ const nextConfig = { // config.devtool = 'eval-source-map' } + // Fixed WebAssembly loading configuration + config.module.rules.push({ + test: /\.wasm$/, + type: 'asset/resource', + }) + if (!isServer) { // watch my locak pnpm package @gloo-ai/playground-common for changes config.watchOptions = { diff --git a/typescript/nextjs-plugin/src/index.ts b/typescript/nextjs-plugin/src/index.ts index a61f78d1d..77dd2d674 100644 --- a/typescript/nextjs-plugin/src/index.ts +++ b/typescript/nextjs-plugin/src/index.ts @@ -133,6 +133,15 @@ export function withBaml(bamlConfig: BamlNextConfig = {}) { }) } + // Ensure module and rules are defined + config.module = config.module || {} + config.module.rules = config.module.rules || [] + // Add WebAssembly loading configuration (properly indented) + config.module.rules.push({ + test: /\.wasm$/, + type: 'asset/resource', + }) + return config }, } as T diff --git a/typescript/playground-common/src/baml_wasm_web/EventListener.tsx b/typescript/playground-common/src/baml_wasm_web/EventListener.tsx index 757ba2f9e..26dd3ac06 100644 --- a/typescript/playground-common/src/baml_wasm_web/EventListener.tsx +++ b/typescript/playground-common/src/baml_wasm_web/EventListener.tsx @@ -135,6 +135,18 @@ export const EventListener: React.FC<{ children: React.ReactNode }> = ({ childre root_path: string } } + | { + command: 'set_flashing_regions' + content: { + spans: { + file_path: string + start_line: number + start: number + end_line: number + end: number + }[] + } + } | { command: 'select_function' content: { @@ -181,6 +193,10 @@ export const EventListener: React.FC<{ children: React.ReactNode }> = ({ childre } break + case 'set_flashing_regions': + console.log('DEBUG set_flashing_regions', content) + break + case 'select_function': console.log('select_function', content) setSelectedFunction(content.function_name) diff --git 
a/typescript/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/test-panel/test-runner.ts b/typescript/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/test-panel/test-runner.ts index 99ba043d7..6a4835217 100644 --- a/typescript/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/test-panel/test-runner.ts +++ b/typescript/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/test-panel/test-runner.ts @@ -1,8 +1,9 @@ -import type { WasmFunctionResponse, WasmTestResponse } from '@gloo-ai/baml-schema-wasm-web' +import type { WasmFunctionResponse, WasmSpan, WasmTestResponse } from '@gloo-ai/baml-schema-wasm-web' import { useAtomValue, useSetAtom } from 'jotai' import { findMediaFile } from '../media-utils' import { ctxAtom, runtimeAtom, wasmAtom } from '../../../atoms' import { useAtomCallback } from 'jotai/utils' +import { vscode } from '../../../vscode' import { useCallback } from 'react' import { type TestState, @@ -11,10 +12,20 @@ import { selectedTestcaseAtom, selectedFunctionAtom, } from '../../atoms' -import { vscode } from '../../../vscode' -import { testHistoryAtom, selectedHistoryIndexAtom, type TestHistoryRun, isParallelTestsEnabledAtom } from './atoms' +import { isParallelTestsEnabledAtom, testHistoryAtom, selectedHistoryIndexAtom, type TestHistoryRun } from './atoms' import { isClientCallGraphEnabledAtom } from '../../preview-toolbar' +// Helper function to clear highlights if in VSCode +const clearHighlights = () => { + try { + vscode.postMessage({ + command: 'clearHighlights', + }) + } catch (e) { + console.error('Failed to clear highlights in VSCode:', e) + } +} + // TODO: use a single hook for both run and parallel run const useRunTests = (maxBatchSize = 5) => { const { rt } = useAtomValue(runtimeAtom) @@ -69,18 +80,32 @@ const useRunTests = (maxBatchSize = 5) => { } const runTest = async (test: { functionName: string; testName: string }) 
=> { + console.log('runTest', test) + + // TEMPORARY DEBUGGING HELPER: + // console.log("Try to set flashing regions") + // try { + // vscode.postMessage({ + // command: 'set_flashing_regions', + // spans: [{file_path: "tmp", start: 1, end: 4, start_line:0, end_line: 0}], + // }) + // } catch (e) { + // console.error('Failed to set flashing regions in VSCode:', e) + // } + try { const testCase = get(testCaseAtom(test)) - console.log('test deps', testCase, rt, ctx, wasm) if (!rt || !ctx || !testCase || !wasm) { - setState(test, { status: 'error', message: 'Missing required dependencies' }) + setState(test, { status: 'error', message: 'Missing required dependencies.' }) console.error('Missing required dependencies') + clearHighlights() // Clear highlights on error return } const startTime = performance.now() setState(test, { status: 'running' }) - const result = await testCase.fn.run_test( + + const result = await testCase.fn.run_test_with_expr_events( rt, testCase.tc.name, vscode.loadEnv(), @@ -89,6 +114,26 @@ const useRunTests = (maxBatchSize = 5) => { }, findMediaFile, vscode.loadAwsCreds.bind(vscode), + + (spans: WasmSpan[]) => { + // Send spans to VSCode for highlighting if we're in the VSCode environment + const spans_to_send = spans.map((span) => ({ + file_path: span.file_path, + start_line: span.start_line, + start: span.start, + end_line: span.end_line, + end: span.end, + })) + console.log('spans_to_send: ', spans_to_send) + try { + vscode.postMessage({ + command: 'set_flashing_regions', + spans: spans_to_send, + }) + } catch (e) { + console.error('Failed to send spans to VSCode:', e) + } + }, ) console.log('result', result) @@ -110,9 +155,13 @@ const useRunTests = (maxBatchSize = 5) => { response_status: responseStatusMap[response_status] || 'error', latency_ms: endTime - startTime, }) + + // Clear highlights when test is completed, whether success or failure + clearHighlights() } catch (e) { console.log('test error!') console.error(e) + clearHighlights() 
// Clear highlights on error setState(test, { status: 'error', message: e instanceof Error ? e.message : 'Unknown error', @@ -153,6 +202,7 @@ const useRunTests = (maxBatchSize = 5) => { set(areTestsRunningAtom, true) await run().finally(() => { set(areTestsRunningAtom, false) + clearHighlights() // Clear highlights when all tests are done }) }, [maxBatchSize, rt, ctx, wasm], diff --git a/typescript/vscode-ext/packages/language-server/src/server.ts b/typescript/vscode-ext/packages/language-server/src/server.ts index fb6e7a7a8..136cca2c8 100644 --- a/typescript/vscode-ext/packages/language-server/src/server.ts +++ b/typescript/vscode-ext/packages/language-server/src/server.ts @@ -33,7 +33,7 @@ import { exec } from 'child_process' // import { FileChangeType } from 'vscode' import fs from 'fs' // import { cliBuild, cliCheckForUpdates, cliVersion } from './baml-cli' -import { type ParserDatabase, TestRequest } from '@baml/common' +// import { type ParserDatabase, TestRequest } from '@baml/common' import debounce from 'lodash/debounce' // import { TextDocumentIdentifier } from 'vscode-languageserver-protocol' import { TextDocument } from 'vscode-languageserver-textdocument' diff --git a/typescript/vscode-ext/packages/vscode/server/darwin/baml-cli b/typescript/vscode-ext/packages/vscode/server/darwin/baml-cli new file mode 120000 index 000000000..15ae75ca4 --- /dev/null +++ b/typescript/vscode-ext/packages/vscode/server/darwin/baml-cli @@ -0,0 +1 @@ +../baml-cli \ No newline at end of file diff --git a/typescript/vscode-ext/packages/vscode/server/linux/baml-cli b/typescript/vscode-ext/packages/vscode/server/linux/baml-cli new file mode 120000 index 000000000..15ae75ca4 --- /dev/null +++ b/typescript/vscode-ext/packages/vscode/server/linux/baml-cli @@ -0,0 +1 @@ +../baml-cli \ No newline at end of file diff --git a/typescript/vscode-ext/packages/vscode/src/extension.ts b/typescript/vscode-ext/packages/vscode/src/extension.ts index 33badc615..78c4ea8a6 100644 --- 
a/typescript/vscode-ext/packages/vscode/src/extension.ts +++ b/typescript/vscode-ext/packages/vscode/src/extension.ts @@ -20,6 +20,12 @@ let timeout: NodeJS.Timeout | undefined let statusBarItem: vscode.StatusBarItem let server: any +let glowOnDecoration: vscode.TextEditorDecorationType | null = null +let glowOffDecoration: vscode.TextEditorDecorationType | null = null +let isGlowOn: boolean = true +let animationTimer: NodeJS.Timeout | null = null +let highlightRanges: vscode.Range[] = [] + function scheduleDiagnostics(): void { if (timeout) { clearTimeout(timeout) @@ -164,6 +170,10 @@ export function activate(context: vscode.ExtensionContext) { providedCodeActionKinds: [vscode.CodeActionKind.QuickFix], }) + // Initialize the highlight effect. + createDecorations() + startAnimation() + context.subscriptions.push(codeActionProvider) const app: Express = require('express')() @@ -323,6 +333,40 @@ export function activate(context: vscode.ExtensionContext) { }, ) + vscode.commands.registerCommand( + 'baml.setFlashingRegions', + async (args: { + spans: { file_path: string; start_line: number; start: number; end_line: number; end: number }[] + }) => { + // A helpful thing to toggle on for debugging: + // console.log('HANDLER setFlashingRegions', args) + // vscode.window.showWarningMessage(`setFlashingRegions:` + JSON.stringify(args)) + + // Focus the editor to ensure styling updates are applied rapidly. 
+ if (vscode.window.activeTextEditor) { + vscode.window.showTextDocument( + vscode.window.activeTextEditor.document, + vscode.window.activeTextEditor.viewColumn, + ) + } + + context.subscriptions.push({ + dispose: () => { + stopAnimation() + if (glowOnDecoration) glowOnDecoration.dispose() + if (glowOffDecoration) glowOffDecoration.dispose() + }, + }) + const ranges = args.spans.map((span) => { + const start = new vscode.Position(span.start_line, span.start) + const end = new vscode.Position(span.end_line, span.end) + return new vscode.Range(start, end) + }) + highlightRanges = ranges + updateHighlight() + }, + ) + context.subscriptions.push(bamlPlaygroundCommand) console.log('pushing glooLens') @@ -439,3 +483,89 @@ class DiagnosticCodeActionProvider implements vscode.CodeActionProvider { return codeActions } } + +// Create our two decoration states +function createDecorations() { + // Bright neon color for the glow effect (bright green) + const glowColor = '#00FF00' + const offColor = '#009900' + + // Glow ON - attempt to create text glow with textDecoration property + glowOnDecoration = vscode.window.createTextEditorDecorationType({ + color: glowColor, + fontWeight: 'bold', + backgroundColor: 'transparent', + textDecoration: `none; text-shadow: 0 0 4px ${glowColor}, 0 0 6px ${glowColor}`, + // Try using before/after elements to reinforce the glow effect + before: { + contentText: '', + textDecoration: `none; text-shadow: 0 0 4px ${glowColor}, 0 0 6px ${glowColor}`, + color: glowColor, + }, + after: { + contentText: '', + textDecoration: `none; text-shadow: 0 0 4px ${glowColor}, 0 0 6px ${glowColor}`, + color: glowColor, + }, + }) + + // Glow OFF - text glow with textDecoration property. 
+ glowOffDecoration = vscode.window.createTextEditorDecorationType({ + color: offColor, + fontWeight: 'bold', + backgroundColor: 'transparent', + textDecoration: `none; `, + // Try using before/after elements to reinforce the glow effect + before: { + contentText: '', + textDecoration: `none; `, + color: offColor, + }, + after: { + contentText: '', + textDecoration: `none; `, + color: offColor, + }, + }) +} + +// Update the highlight based on current state +function updateHighlight() { + // vscode.window.showWarningMessage(`updateHighlight:` + isGlowOn) + const editor = vscode.window.activeTextEditor + if (!editor) return + + // Clear both decorations + // Apply appropriate decoration based on state + if (glowOnDecoration && glowOffDecoration && isGlowOn) { + editor.setDecorations(glowOffDecoration, []) + editor.setDecorations(glowOnDecoration, highlightRanges) + } + if (glowOnDecoration && glowOffDecoration && !isGlowOn) { + editor.setDecorations(glowOnDecoration, []) + editor.setDecorations(glowOffDecoration, highlightRanges) + } +} + +// Start the simple toggling animation +function startAnimation() { + console.log('startAnimation') + if (animationTimer) return + + // Toggle every 500ms (2 times per second) + animationTimer = setInterval(() => { + // Toggle between on and off states + isGlowOn = !isGlowOn + + // Update the highlight + updateHighlight() + }, 500) // 500ms = half a second +} + +// Stop animation +function stopAnimation(): void { + if (animationTimer) { + clearInterval(animationTimer) + animationTimer = null + } +} diff --git a/typescript/vscode-ext/packages/vscode/src/panels/WebviewPanelHost.ts b/typescript/vscode-ext/packages/vscode/src/panels/WebviewPanelHost.ts index 502715899..3053fa168 100644 --- a/typescript/vscode-ext/packages/vscode/src/panels/WebviewPanelHost.ts +++ b/typescript/vscode-ext/packages/vscode/src/panels/WebviewPanelHost.ts @@ -215,6 +215,10 @@ export class WebviewPanelHost { | { command: 'get_port' | 'add_project' | 
'cancelTestRun' | 'removeTest' } + | { + command: 'set_flashing_regions' + spans: { file_path: string; start_line: number; start_char: number; end_line: number; end_char: number }[] + } | { command: 'jumpToFile' span: StringSpan @@ -231,6 +235,7 @@ export class WebviewPanelHost { data: WebviewToVscodeRpc }, ) => { + console.log('DEBUG: webview message: ', message) if ('command' in message) { switch (message.command) { case 'add_project': @@ -262,6 +267,12 @@ export class WebviewPanelHost { }) return } + case 'set_flashing_regions': { + // Call the command handler with the spans + console.log('WEBPANELVIEW set_flashing_regions', message.spans) + vscode.commands.executeCommand('baml.setFlashingRegions', { spans: message.spans }) + return + } } }