Skip to content

pre-evaluate let bindings and cache evaluations of free vars #1801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions engine/baml-runtime/src/eval_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use futures::channel::mpsc;
use futures::stream::{self as stream, StreamExt};
use internal_baml_core::internal_baml_diagnostics::SerializedSpan;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use crate::{BamlRuntime, FunctionResult};
use baml_types::expr::{Expr, ExprMetadata, Name, VarIndex};
Expand All @@ -17,6 +17,8 @@ pub struct EvalEnv<'a> {
pub context: HashMap<Name, Expr<ExprMetadata>>,
pub runtime: &'a BamlRuntime,
pub expr_tx: Option<mpsc::UnboundedSender<Vec<SerializedSpan>>>,
/// Evaluated top-level expressions.
pub evaluated_cache: Arc<Mutex<HashMap<Name, Expr<ExprMetadata>>>>,
}

impl<'a> EvalEnv<'a> {
Expand Down Expand Up @@ -131,18 +133,19 @@ async fn beta_reduce<'a>(
match expr {
Expr::Atom(_) => Ok(expr.clone()),
Expr::Let(name, value, body, meta) => {
// Rewrite the let binding as an application.
// e.g. (let x = y in f) => (\x y => f)
// TODO: Should rewriting be done here in beta_reduce? Or elsewhere?
// First evaluate the bound expression
let evaluated_value = Box::pin(beta_reduce(env, value)).await?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

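Review the use of Box::pin on beta_reduce calls. If beta_reduce already returns an async future, you might simplify by awaiting it directly. Note, however, that this call is made from inside beta_reduce itself: a recursive async fn needs indirection for its own future (otherwise the future type would be infinitely sized, rustc error E0733), so the Box::pin can only be dropped if the recursion is removed.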

Suggested change
let evaluated_value = Box::pin(beta_reduce(env, value)).await?;
let evaluated_value = beta_reduce(env, value).await?;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

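Box::pin(...).await looks unnecessary here, since beta_reduce is already async and its result could in principle simply be awaited. However, this is a recursive call from within beta_reduce, and a recursive async fn must box its own future (rustc error E0733 otherwise), so the Box::pin has to stay unless the recursion is restructured.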

Suggested change
let evaluated_value = Box::pin(beta_reduce(env, value)).await?;
let evaluated_value = beta_reduce(env, value).await?;


// Then substitute the evaluated value into the body
let target = VarIndex {
de_bruijn: 0,
tuple: 0,
};
let closed_body = body.close(&target, name);
let arity = 1;
let lambda = Expr::Lambda(arity, Arc::new(closed_body), meta.clone());
let app = Expr::App(Arc::new(lambda), value.clone(), meta.clone());
Box::pin(beta_reduce(env, &app)).await
let substituted_body = subst2(&closed_body, &target, &evaluated_value, env)?;

// Finally evaluate the body with the substitution
Box::pin(beta_reduce(env, &substituted_body)).await
}
Expr::App(f, x, meta) => {
match (f.as_ref(), x.as_ref()) {
Expand Down Expand Up @@ -252,11 +255,25 @@ async fn beta_reduce<'a>(
}
}
Expr::FreeVar(name, _) => {
if let Some(cached) = env.evaluated_cache.lock().unwrap().get(name) {
return Ok(cached.clone());
}

let var_lookup = env
.context
.get(name)
.context(format!("Variable not found: {:?}", name))?;
Ok(var_lookup.clone())

// Evaluate the expression
let evaluated = Box::pin(beta_reduce(env, var_lookup)).await?;

// Cache the result
env.evaluated_cache
.lock()
.unwrap()
.insert(name.clone(), evaluated.clone());

Ok(evaluated)
}
Expr::BoundVar(_, _) => Ok(expr.clone()),
Expr::List(_, _) => Ok(expr.clone()),
Expand Down
1 change: 1 addition & 0 deletions engine/baml-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ impl BamlRuntime {
context,
runtime: self,
expr_tx: expr_tx.clone(),
evaluated_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
};
let param_baml_values = params
.iter()
Expand Down
Loading