diff --git a/engine/baml-runtime/src/eval_expr.rs b/engine/baml-runtime/src/eval_expr.rs index 8733018ba..4a221a187 100644 --- a/engine/baml-runtime/src/eval_expr.rs +++ b/engine/baml-runtime/src/eval_expr.rs @@ -3,7 +3,7 @@ use futures::channel::mpsc; use futures::stream::{self as stream, StreamExt}; use internal_baml_core::internal_baml_diagnostics::SerializedSpan; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use crate::{BamlRuntime, FunctionResult}; use baml_types::expr::{Expr, ExprMetadata, Name, VarIndex}; @@ -17,6 +17,8 @@ pub struct EvalEnv<'a> { pub context: HashMap>, pub runtime: &'a BamlRuntime, pub expr_tx: Option>>, + /// Evaluated top-level expressions. + pub evaluated_cache: Arc>>>, } impl<'a> EvalEnv<'a> { @@ -131,18 +133,19 @@ async fn beta_reduce<'a>( match expr { Expr::Atom(_) => Ok(expr.clone()), Expr::Let(name, value, body, meta) => { - // Rewrite the let binding as an application. - // e.g. (let x = y in f) => (\x y => f) - // TODO: Should rewriting be done here in beta_reduce? Or elsewhere? + // First evaluate the bound expression + let evaluated_value = Box::pin(beta_reduce(env, value)).await?; + + // Then substitute the evaluated value into the body let target = VarIndex { de_bruijn: 0, tuple: 0, }; let closed_body = body.close(&target, name); - let arity = 1; - let lambda = Expr::Lambda(arity, Arc::new(closed_body), meta.clone()); - let app = Expr::App(Arc::new(lambda), value.clone(), meta.clone()); - Box::pin(beta_reduce(env, &app)).await + let substituted_body = subst2(&closed_body, &target, &evaluated_value, env)?; + + // Finally evaluate the body with the substitution + Box::pin(beta_reduce(env, &substituted_body)).await } Expr::App(f, x, meta) => { match (f.as_ref(), x.as_ref()) { @@ -252,11 +255,25 @@ async fn beta_reduce<'a>( } } Expr::FreeVar(name, _) => { + if let Some(cached) = env.evaluated_cache.lock().unwrap().get(name) { + return Ok(cached.clone()); + } + let var_lookup = env .context .get(name) .context(format!("Variable not found: {:?}", name))?; - Ok(var_lookup.clone()) + + // Evaluate the expression + let evaluated = Box::pin(beta_reduce(env, var_lookup)).await?; + + // Cache the result + env.evaluated_cache + .lock() + .unwrap() + .insert(name.clone(), evaluated.clone()); + + Ok(evaluated) } Expr::BoundVar(_, _) => Ok(expr.clone()), Expr::List(_, _) => Ok(expr.clone()), diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index 38bc0a960..3a9ec2f8a 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -560,6 +560,7 @@ impl BamlRuntime { context, runtime: self, expr_tx: expr_tx.clone(), + evaluated_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), }; let param_baml_values = params .iter()