From cdded5e7125d1207960fdf49d5115acfa7350974 Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Wed, 5 Feb 2025 17:22:06 +0530 Subject: [PATCH 1/8] Allow default types --- golem-rib/src/compiler/byte_code.rs | 30 +- golem-rib/src/compiler/desugar.rs | 2 +- golem-rib/src/compiler/mod.rs | 21 +- golem-rib/src/expr.rs | 15 +- golem-rib/src/inferred_type/unification.rs | 18 +- golem-rib/src/inferred_type/validation.rs | 40 +- golem-rib/src/interpreter/rib_interpreter.rs | 9 +- golem-rib/src/lib.rs | 1 + golem-rib/src/type_checker/mod.rs | 2 +- golem-rib/src/type_checker/path.rs | 23 + .../type_inference/global_input_inference.rs | 22 +- .../global_variable_type_default.rs | 1096 +++++++++++++++++ .../type_inference/identifier_inference.rs | 2 +- .../src/type_inference/inference_fix_point.rs | 3 +- golem-rib/src/type_inference/inferred_expr.rs | 6 +- golem-rib/src/type_inference/mod.rs | 97 +- golem-rib/src/type_inference/type_pull_up.rs | 3 +- .../src/type_inference/type_unification.rs | 6 +- .../src/gateway_rib_compiler/mod.rs | 9 +- 19 files changed, 1322 insertions(+), 83 deletions(-) create mode 100644 golem-rib/src/type_inference/global_variable_type_default.rs diff --git a/golem-rib/src/compiler/byte_code.rs b/golem-rib/src/compiler/byte_code.rs index 0bb83f3500..4465cff982 100644 --- a/golem-rib/src/compiler/byte_code.rs +++ b/golem-rib/src/compiler/byte_code.rs @@ -706,7 +706,7 @@ mod compiler_tests { fn test_instructions_for_literal() { let literal = Expr::Literal("hello".to_string(), InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&literal, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&literal, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -725,7 +725,7 @@ mod compiler_tests { let variable_id = VariableId::local("request", 0); let empty_registry = FunctionTypeRegistry::empty(); let expr = 
Expr::Identifier(variable_id.clone(), inferred_input_type); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -752,7 +752,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -787,7 +787,7 @@ mod compiler_tests { let expr = Expr::equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -826,7 +826,7 @@ mod compiler_tests { let expr = Expr::greater_than(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -865,7 +865,7 @@ mod compiler_tests { let expr = Expr::less_than(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -904,7 +904,7 @@ mod compiler_tests { let expr = Expr::greater_than_or_equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, 
&empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -943,7 +943,7 @@ mod compiler_tests { let expr = Expr::less_than_or_equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -983,7 +983,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1027,7 +1027,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1057,7 +1057,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1097,7 +1097,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1147,7 +1147,7 @@ mod compiler_tests { let expr = Expr::SelectField(Box::new(record), 
"bar_key".to_string(), InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1194,7 +1194,7 @@ mod compiler_tests { let expr = Expr::SelectIndex(Box::new(sequence), 1, InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1243,7 +1243,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); diff --git a/golem-rib/src/compiler/desugar.rs b/golem-rib/src/compiler/desugar.rs index 30911956a7..a36d5f02b2 100644 --- a/golem-rib/src/compiler/desugar.rs +++ b/golem-rib/src/compiler/desugar.rs @@ -574,7 +574,7 @@ mod desugar_tests { let function_type_registry = get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let desugared_expr = match internal::last_expr(&expr) { Expr::PatternMatch(predicate, match_arms, _) => { diff --git a/golem-rib/src/compiler/mod.rs b/golem-rib/src/compiler/mod.rs index e3bcb50aae..a392474be3 100644 --- a/golem-rib/src/compiler/mod.rs +++ b/golem-rib/src/compiler/mod.rs @@ -16,11 +16,12 @@ pub use byte_code::*; pub use compiler_output::*; use golem_wasm_ast::analysis::AnalysedExport; pub use ir::*; +use std::collections::HashSet; pub use 
type_with_unit::*; pub use worker_functions_in_rib::*; use crate::type_registry::FunctionTypeRegistry; -use crate::{Expr, InferredExpr, RibInputTypeInfo, RibOutputTypeInfo}; +use crate::{Expr, InferredExpr, RibInputTypeInfo, RibOutputTypeInfo, TypeDefault}; mod byte_code; mod compiler_output; @@ -33,25 +34,37 @@ pub fn compile( expr: &Expr, export_metadata: &Vec, ) -> Result { - compile_with_limited_globals(expr, export_metadata, None) + compile_with_restricted_global_variables(expr, export_metadata, None, None) } // Rib allows global input variables, however, we can choose to fail compilation // if they don't fall under a pre-defined set of global variables. // There is no restriction imposed to the type of this variable. -pub fn compile_with_limited_globals( +// Also we can specify types for certain global variables and if needed be specific +// on the path. Example: All variables under the variable `path` which is under the global variable `request` can be `Str` +pub fn compile_with_restricted_global_variables( expr: &Expr, export_metadata: &Vec, allowed_global_variables: Option>, + global_variable_type_default: Option, ) -> Result { let type_registry = FunctionTypeRegistry::from_export_metadata(export_metadata); - let inferred_expr = InferredExpr::from_expr(expr, &type_registry)?; + let inferred_expr = + InferredExpr::from_expr(expr, &type_registry, global_variable_type_default.as_ref())?; let function_calls_identified = WorkerFunctionsInRib::from_inferred_expr(&inferred_expr, &type_registry)?; let global_input_type_info = RibInputTypeInfo::from_expr(&inferred_expr).map_err(|e| format!("Error: {}", e))?; + let global_keys: HashSet<_> = global_input_type_info.types.keys().cloned().collect(); + + if let Some(info) = &global_variable_type_default { + if !info.variable_id.is_global() || !global_keys.contains(&info.variable_id.to_string()) { + return Err("Only global variables can have default types".to_string()); + } + } + let output_type_info = 
RibOutputTypeInfo::from_expr(&inferred_expr)?; if let Some(allowed_global_variables) = &allowed_global_variables { diff --git a/golem-rib/src/expr.rs b/golem-rib/src/expr.rs index e725b7d129..2d0d414ae5 100644 --- a/golem-rib/src/expr.rs +++ b/golem-rib/src/expr.rs @@ -18,7 +18,7 @@ use crate::parser::type_name::TypeName; use crate::type_registry::FunctionTypeRegistry; use crate::{ from_string, text, type_checker, type_inference, DynamicParsedFunctionName, InferredType, - ParsedFunctionName, VariableId, + ParsedFunctionName, TypeDefault, VariableId, }; use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; use combine::parser::char::spaces; @@ -392,6 +392,11 @@ impl Expr { ) } + pub fn override_types(&self, type_default: &TypeDefault) -> Result { + let result_expr = type_inference::tag_default_global_variable_type(self, type_default)?; + Ok(result_expr) + } + pub fn literal(value: impl AsRef) -> Self { Expr::Literal(value.as_ref().to_string(), InferredType::Str) } @@ -541,13 +546,13 @@ impl Expr { pub fn infer_types( &mut self, function_type_registry: &FunctionTypeRegistry, + type_default: Option<&TypeDefault>, ) -> Result<(), Vec> { - self.infer_types_initial_phase(function_type_registry)?; + self.infer_types_initial_phase(function_type_registry, type_default)?; self.infer_call_arguments_type(function_type_registry) .map_err(|x| vec![x])?; type_inference::type_inference_fix_point(Self::inference_scan, self) .map_err(|x| vec![x])?; - self.check_types(function_type_registry) .map_err(|x| vec![x])?; self.unify_types()?; @@ -557,7 +562,11 @@ impl Expr { pub fn infer_types_initial_phase( &mut self, function_type_registry: &FunctionTypeRegistry, + type_default: Option<&TypeDefault>, ) -> Result<(), Vec> { + if let Some(type_default) = type_default { + *self = self.override_types(type_default).map_err(|x| vec![x])?; + } self.bind_types(); self.bind_variables_of_list_comprehension(); self.bind_variables_of_list_reduce(); diff --git 
a/golem-rib/src/inferred_type/unification.rs b/golem-rib/src/inferred_type/unification.rs index 03acefe393..c841285a72 100644 --- a/golem-rib/src/inferred_type/unification.rs +++ b/golem-rib/src/inferred_type/unification.rs @@ -573,18 +573,22 @@ pub fn unify_with_required( result.unify_with_required(inferred_type) } - (inferred_type1, inferred_type2) => { - if inferred_type1 == inferred_type2 { - Ok(inferred_type1.clone()) - } else if inferred_type1.is_number() && inferred_type2.is_number() { + (inferred_type_left, inferred_type_right) => { + if inferred_type_left == inferred_type_right { + Ok(inferred_type_left.clone()) + } else if inferred_type_left.is_number() && inferred_type_right.is_number() { Ok(InferredType::AllOf(vec![ - inferred_type1.clone(), - inferred_type2.clone(), + inferred_type_left.clone(), + inferred_type_right.clone(), ])) + } else if inferred_type_left.is_string() && inferred_type_right.is_number() { + Ok(inferred_type_right.clone()) + } else if inferred_type_left.is_number() && inferred_type_right.is_string() { + Ok(inferred_type_left.clone()) } else { Err(format!( "Types do not match. Inferred to be both {:?} and {:?}", - inferred_type1, inferred_type2 + inferred_type_left, inferred_type_right )) } } diff --git a/golem-rib/src/inferred_type/validation.rs b/golem-rib/src/inferred_type/validation.rs index 9b4025f512..ef8794596b 100644 --- a/golem-rib/src/inferred_type/validation.rs +++ b/golem-rib/src/inferred_type/validation.rs @@ -99,7 +99,7 @@ pub fn validate_unified_type(inferred_type: &InferredType) -> UnificationResult } resource @ InferredType::Resource { .. 
} => Ok(Unified(resource.clone())), InferredType::OneOf(possibilities) => Err(format!("Cannot resolve {:?}", possibilities)), - InferredType::AllOf(possibilities) => Err(format!("Cannot be all of {:?}", possibilities)), + InferredType::AllOf(possibilities) => coerce_to_numerical_type(possibilities).map(Unified), InferredType::Unknown => Err("Unknown".to_string()), inferred_type @ InferredType::Sequence(inferred_types) => { for typ in inferred_types { @@ -110,3 +110,41 @@ pub fn validate_unified_type(inferred_type: &InferredType) -> UnificationResult } } } + +fn coerce_to_numerical_type(possibilities: &Vec) -> Result { + let mut number_type: Option = None; + let mut has_string = false; + + for ty in possibilities { + match ty { + InferredType::Str => has_string = true, + InferredType::U8 + | InferredType::U16 + | InferredType::U32 + | InferredType::U64 + | InferredType::S8 + | InferredType::S16 + | InferredType::S32 + | InferredType::S64 + | InferredType::F32 + | InferredType::F64 => { + if let Some(existing) = &number_type { + if existing != ty { + return Err("Cannot coerce to a mixed number type".to_string()); + } + } else { + number_type = Some(ty.clone()); + } + } + _ => return Err(format!("Cannot be all of {:?}", possibilities)), + } + } + + if let Some(num) = number_type { + Ok(num) + } else if has_string { + Ok(InferredType::Str) + } else { + Err("No valid number or string type found".to_string()) + } +} diff --git a/golem-rib/src/interpreter/rib_interpreter.rs b/golem-rib/src/interpreter/rib_interpreter.rs index 1341af1a46..cba539b042 100644 --- a/golem-rib/src/interpreter/rib_interpreter.rs +++ b/golem-rib/src/interpreter/rib_interpreter.rs @@ -1596,7 +1596,8 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); let result = 
interpreter.run(compiled.byte_code).await.unwrap(); @@ -1616,7 +1617,8 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); let result = interpreter.run(compiled.byte_code).await.unwrap(); @@ -1637,7 +1639,8 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); let result = interpreter.run(compiled.byte_code).await.unwrap(); diff --git a/golem-rib/src/lib.rs b/golem-rib/src/lib.rs index f2afffbde2..4f49f5c326 100644 --- a/golem-rib/src/lib.rs +++ b/golem-rib/src/lib.rs @@ -19,6 +19,7 @@ pub use inferred_type::*; pub use interpreter::*; pub use parser::type_name::TypeName; pub use text::*; +pub use type_checker::*; pub use type_inference::*; pub use type_registry::*; pub use variable_id::*; diff --git a/golem-rib/src/type_checker/mod.rs b/golem-rib/src/type_checker/mod.rs index 5b12f45aaa..c6a65bf59f 100644 --- a/golem-rib/src/type_checker/mod.rs +++ b/golem-rib/src/type_checker/mod.rs @@ -1,5 +1,5 @@ pub(crate) use missing_fields::*; -pub(crate) use path::*; +pub use path::*; pub(crate) use type_check_error::*; pub(crate) use type_mismatch::*; pub(crate) use unresolved_types::*; diff --git a/golem-rib/src/type_checker/path.rs b/golem-rib/src/type_checker/path.rs index 78327d620c..37e09c735c 100644 --- a/golem-rib/src/type_checker/path.rs +++ b/golem-rib/src/type_checker/path.rs @@ -5,10 +5,33 @@ use std::fmt::Display; pub struct Path(Vec); impl Path { + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn current(&self) -> Option<&PathElem> { + self.0.first() + } + + pub fn progress(&mut self) { + if !self.0.is_empty() { + 
self.0.remove(0); + } + } + pub fn from_elem(elem: PathElem) -> Self { Path(vec![elem]) } + pub fn from_elems(elems: Vec<&str>) -> Self { + Path( + elems + .iter() + .map(|x| PathElem::Field(x.to_string())) + .collect(), + ) + } + pub fn push_front(&mut self, elem: PathElem) { self.0.insert(0, elem); } diff --git a/golem-rib/src/type_inference/global_input_inference.rs b/golem-rib/src/type_inference/global_input_inference.rs index ae708a15af..ea4a04f81a 100644 --- a/golem-rib/src/type_inference/global_input_inference.rs +++ b/golem-rib/src/type_inference/global_input_inference.rs @@ -22,8 +22,9 @@ mod internal { use crate::{Expr, InferredType}; use std::collections::{HashMap, VecDeque}; - // Unlike inferring all identifiers, inputs don't have an associated let binding, - // and yet we need to propagate this type info all over + // request.path.user is used as a string in one place + // request.path.id is used an integer in some other + // request -> AllOf(path -> user, path -> id) pub(crate) fn infer_global_inputs(expr: &mut Expr) { let global_variables_dictionary = collect_all_global_variables_type(expr); // Updating the collected types in all positions of input @@ -54,12 +55,19 @@ mod internal { while let Some(expr) = queue.pop_back() { match expr { Expr::Identifier(variable_id, inferred_type) => { - // We are only interested in global variables if variable_id.is_global() { - all_types_of_global_variables - .entry(variable_id.name().clone()) - .or_insert(Vec::new()) - .push(inferred_type.clone()); + match all_types_of_global_variables.get_mut(&variable_id.name().clone()) { + None => { + all_types_of_global_variables + .insert(variable_id.name(), vec![inferred_type.clone()]); + } + + Some(v) => { + if !v.contains(inferred_type) { + v.push(inferred_type.clone()) + } + } + } } } _ => expr.visit_children_mut_bottom_up(&mut queue), diff --git a/golem-rib/src/type_inference/global_variable_type_default.rs b/golem-rib/src/type_inference/global_variable_type_default.rs 
new file mode 100644 index 0000000000..e4ab1d05b5 --- /dev/null +++ b/golem-rib/src/type_inference/global_variable_type_default.rs @@ -0,0 +1,1096 @@ +use crate::type_checker::Path; +use crate::{Expr, InferredType, VariableId}; +use std::collections::VecDeque; + +// The goal is to be able to specify the types associated with an identifier. +// i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` +// This can be represented using `TypeDefault { a, vec![], Str }`, `TypeDefault {a, b, Str}` and +// `TypeDefault {a, vec[b, c], Str}` respectively +// If you specify completely opposite types to be default, you will get a compilation error. +// If you tried to specify a variable is always string, but compiler identifies it's usage as `U64`, +// then it chooses `U64` and discards the default. If the compiler finds its usages as `Str` +#[derive(Clone, Debug)] +pub struct TypeDefault { + pub variable_id: VariableId, + pub path: Path, + pub inferred_type: InferredType, +} + +// +// Algorithm: +// +// The goal is to be able to specify the types associated with an identifier +// i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` +// This can be represented using `TypeDefault { a, vec![], Str }`, `TypeDefault {a, b, Str}` and +// `TypeDefault {a, vec[b, c], Str}` respectively +// +// We initially create queue of immutable Expr (to be able to push mutable version has to do into reference count logic in Rust) +// and then push it to an intermediate stack and recreate the Expr. This is similar to `type_pull_up` phase. +// This is verbose but will make the algorithm quite easy to handle. + +// Any other way of non-recursive way of overriding values requires RefCell. i.e, +// get a mutable expr, and send each mutable node into a queue, and then read these +// mutable expr and mutate it elsewhere requires Rc with RefCell in Rust. 
We +// decide from the beginning to keep the Expr tree as simple as possible with no Rc or RefCell structures +// just for 1 or 2 phases of compilation. +// +// Steps: +// // Pre-process +// Initialize a queue with all exprs in the tree, with the root node first: +// Example queue: +// [select_field(select_field(a, b), c), select_field(a, b), identifier(a)] +// +// Example Walkthrough: Given `TypeDefault { a, vec[b, c], Str }` +// +// 1. Pop the back element in the queue to get `identifier(a)`. +// - Check the `temp_stack` by popping from the front. +// - If it's `None`, push `(identifier(a), true)` to the stack. +// +// 2. Pop the back element in the queue again to get `select_field(a, b)`. +// - Check the `temp_stack`, which now has +// `(identifier(a), true)` at the front. We pop it out. +// - Given `b` in the current expr is part of the path, and given the path is still not empty, +// push (select_field(identifier(a), b), true) back to the stack (by this time the stack has only 1 elem) +// +// 3. Pop the final element from the queue: `select_field(select_field(a, b), c)`. +// - Check the `temp_stack`, which has `(select_field(identifier(a), b), true)` at the front. +// - Given the flag is true, and given `c` is also in the path (and the path has no more elements), +// push (select_field(select_field(identifier(a), b), c, InferredType::Str), false) +// where false indicates loop break +// +// The same algorithm above is tweaked even if users specified partial paths. 
Example: +// Everything under `a.b` (regardless of the existence of c and d) at their leafs follow the default type +pub fn tag_default_global_variable_type( + expr: &Expr, + type_default: &TypeDefault, +) -> Result { + let mut path = type_default.path.clone(); + + let mut expr_queue = VecDeque::new(); + + internal::make_expr_nodes_queue(expr, &mut expr_queue); + + let mut temp_stack = VecDeque::new(); + + while let Some(expr) = expr_queue.pop_back() { + match expr { + expr @ Expr::Identifier(variable_id, _) => { + if variable_id == &type_default.variable_id { + if path.is_empty() { + let continue_traverse = matches!(expr_queue.back(), Some(Expr::SelectField(inner, _, _)) if inner.as_ref() == expr); + + if continue_traverse { + temp_stack.push_front((expr.clone(), true)); + } else { + temp_stack.push_front(( + Expr::Identifier( + variable_id.clone(), + type_default.inferred_type.clone(), + ), + false, + )); + } + } else { + temp_stack.push_front((expr.clone(), true)); + } + } else { + temp_stack.push_front((expr.clone(), false)); + } + } + + outer @ Expr::SelectField(inner_expr, field, current_inferred_type) => { + let continue_search = matches!(expr_queue.back(), Some(Expr::SelectField(inner, _, _)) if inner.as_ref() == outer); + + internal::handle_select_field( + inner_expr, + field, + continue_search, + current_inferred_type, + &mut temp_stack, + &mut path, + &type_default.inferred_type, + )?; + } + + Expr::Tuple(tuple_elems, current_inferred_type) => { + internal::handle_tuple(tuple_elems, current_inferred_type, &mut temp_stack); + } + + expr @ Expr::Flags(_, _) => { + temp_stack.push_front((expr.clone(), false)); + } + + Expr::SelectIndex(expr, index, current_inferred_type) => { + internal::handle_select_index(expr, index, current_inferred_type, &mut temp_stack)?; + } + + Expr::Result(Ok(_), current_inferred_type) => { + internal::handle_result_ok(expr, current_inferred_type, &mut temp_stack); + } + + Expr::Result(Err(_), current_inferred_type) => { + 
internal::handle_result_error(expr, current_inferred_type, &mut temp_stack); + } + + Expr::Option(Some(expr), current_inferred_type) => { + internal::handle_option_some(expr, current_inferred_type, &mut temp_stack); + } + + Expr::Option(None, current_inferred_type) => { + temp_stack.push_front((Expr::Option(None, current_inferred_type.clone()), false)); + } + + Expr::Cond(pred, then, else_, current_inferred_type) => { + internal::handle_if_else(pred, then, else_, current_inferred_type, &mut temp_stack); + } + + // + Expr::PatternMatch(predicate, match_arms, current_inferred_type) => { + internal::handle_pattern_match( + predicate, + match_arms, + current_inferred_type, + &mut temp_stack, + ); + } + + Expr::Concat(exprs, _) => { + internal::handle_concat(exprs, &mut temp_stack); + } + + Expr::ExprBlock(exprs, current_inferred_type) => { + internal::handle_multiple(exprs, current_inferred_type, &mut temp_stack); + } + + Expr::Not(_, current_inferred_type) => { + internal::handle_not(expr, current_inferred_type, &mut temp_stack); + } + + Expr::GreaterThan(left, right, current_inferred_type) => { + internal::handle_comparison_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::GreaterThan, + ); + } + + Expr::GreaterThanOrEqualTo(left, right, current_inferred_type) => { + internal::handle_comparison_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::GreaterThanOrEqualTo, + ); + } + + Expr::LessThanOrEqualTo(left, right, current_inferred_type) => { + internal::handle_comparison_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::LessThanOrEqualTo, + ); + } + Expr::Plus(left, right, current_inferred_type) => { + internal::handle_math_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::Plus, + ); + } + + Expr::Minus(left, right, current_inferred_type) => { + internal::handle_math_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::Minus, + ); + } + + 
Expr::Multiply(left, right, current_inferred_type) => { + internal::handle_math_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::Multiply, + ); + } + + Expr::Divide(left, right, current_inferred_type) => { + internal::handle_math_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::Divide, + ); + } + + Expr::EqualTo(left, right, current_inferred_type) => { + internal::handle_comparison_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::EqualTo, + ); + } + + Expr::LessThan(left, right, current_inferred_type) => { + internal::handle_comparison_op( + left, + right, + current_inferred_type, + &mut temp_stack, + Expr::LessThan, + ); + } + + Expr::Let(variable_id, typ, expr, inferred_type) => { + internal::handle_let(variable_id, expr, typ, inferred_type, &mut temp_stack); + } + Expr::Sequence(exprs, current_inferred_type) => { + internal::handle_sequence(exprs, current_inferred_type, &mut temp_stack); + } + Expr::Record(expr, inferred_type) => { + internal::handle_record(expr, inferred_type, &mut temp_stack); + } + Expr::Literal(_, _) => { + temp_stack.push_front((expr.clone(), false)); + } + Expr::Number(_, _, _) => { + temp_stack.push_front((expr.clone(), false)); + } + Expr::Boolean(_, _) => { + temp_stack.push_front((expr.clone(), false)); + } + Expr::And(left, right, _) => { + internal::handle_comparison_op( + left, + right, + &InferredType::Bool, + &mut temp_stack, + Expr::And, + ); + } + + Expr::Or(left, right, _) => { + internal::handle_comparison_op( + left, + right, + &InferredType::Bool, + &mut temp_stack, + Expr::Or, + ); + } + + Expr::Call(call_type, exprs, inferred_type) => { + internal::handle_call(call_type, exprs, inferred_type, &mut temp_stack); + } + + Expr::Unwrap(expr, inferred_type) => { + internal::handle_unwrap(expr, inferred_type, &mut temp_stack); + } + + Expr::Throw(_, _) => { + temp_stack.push_front((expr.clone(), false)); + } + + Expr::GetTag(_, inferred_type) => { + 
internal::handle_get_tag(expr, inferred_type, &mut temp_stack); + } + + Expr::ListComprehension { + iterated_variable, + iterable_expr, + yield_expr, + inferred_type, + .. + } => { + internal::handle_list_comprehension( + iterated_variable, + iterable_expr, + yield_expr, + inferred_type, + &mut temp_stack, + ); + } + + Expr::ListReduce { + reduce_variable, + iterated_variable, + iterable_expr, + init_value_expr, + yield_expr, + inferred_type, + } => internal::handle_list_reduce( + reduce_variable, + iterated_variable, + iterable_expr, + init_value_expr, + yield_expr, + inferred_type, + &mut temp_stack, + ), + } + } + + temp_stack + .pop_front() + .map(|x| x.0) + .ok_or("Failed type inference during pull up".to_string()) +} + +mod internal { + use crate::call_type::CallType; + + use crate::type_checker::{Path, PathElem}; + use crate::{Expr, InferredType, MatchArm, VariableId}; + use std::collections::VecDeque; + use std::ops::Deref; + + pub(crate) fn make_expr_nodes_queue<'a>(expr: &'a Expr, expr_queue: &mut VecDeque<&'a Expr>) { + let mut stack = VecDeque::new(); + + stack.push_back(expr); + + while let Some(current_expr) = stack.pop_back() { + expr_queue.push_back(current_expr); + + current_expr.visit_children_bottom_up(&mut stack) + } + } + + pub(crate) fn handle_list_comprehension( + variable_id: &VariableId, + current_iterable_expr: &Expr, + current_yield_expr: &Expr, + current_comprehension_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let yield_expr_inferred = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(current_yield_expr.clone()); + let iterable_expr_inferred = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(current_iterable_expr.clone()); + + temp_stack.push_front(( + Expr::typed_list_comprehension( + variable_id.clone(), + iterable_expr_inferred, + yield_expr_inferred, + current_comprehension_type.clone(), + ), + false, + )) + } + + pub(crate) fn handle_list_reduce( + reduce_variable: &VariableId, + 
iterated_variable: &VariableId, + iterable_expr: &Expr, + initial_value_expr: &Expr, + yield_expr: &Expr, + reduce_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let new_yield_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(yield_expr.clone()); + let new_init_value_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(initial_value_expr.clone()); + let new_iterable_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(iterable_expr.clone()); + + let new_reduce_type = reduce_type.merge(new_init_value_expr.inferred_type()); + + temp_stack.push_front(( + Expr::typed_list_reduce( + reduce_variable.clone(), + iterated_variable.clone(), + new_iterable_expr, + new_init_value_expr, + new_yield_expr, + new_reduce_type, + ), + false, + )) + } + + pub(crate) fn handle_tuple( + tuple_elems: &[Expr], + current_tuple_type: &InferredType, + result_expr_queue: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_tuple_elems = vec![]; + + for current_tuple_elem in tuple_elems.iter().rev() { + let pulled_up_type = result_expr_queue.pop_front(); + let new_tuple_elem = pulled_up_type + .map(|x| x.0) + .unwrap_or(current_tuple_elem.clone()); + new_tuple_elems.push(new_tuple_elem); + } + + new_tuple_elems.reverse(); + + // Reform tuple + let new_tuple = Expr::Tuple(new_tuple_elems, current_tuple_type.clone()); + result_expr_queue.push_front((new_tuple, false)); + } + + pub(crate) fn handle_select_field( + original_selection_expr: &Expr, + field: &str, + continue_search: bool, + current_field_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + path: &mut Path, + override_type: &InferredType, + ) -> Result<(), String> { + let (expr, part_of_path) = temp_stack + .pop_front() + .unwrap_or((original_selection_expr.clone(), false)); + + if part_of_path { + match path.current() { + Some(PathElem::Field(name)) if name == field => path.progress(), + Some(_) => return Err("We disallow type overrides at indices".to_string()), + None 
=> {} + } + + if path.is_empty() { + let new_type = if continue_search { + current_field_type.clone() + } else { + current_field_type.merge(override_type.clone()) + }; + + temp_stack.push_front(( + Expr::SelectField(Box::new(expr.clone()), field.to_string(), new_type), + continue_search, + )); + } else { + temp_stack.push_front(( + Expr::SelectField( + Box::new(expr.clone()), + field.to_string(), + current_field_type.clone(), + ), + true, + )); + } + } else { + temp_stack.push_front(( + Expr::SelectField( + Box::new(expr.clone()), + field.to_string(), + current_field_type.clone(), + ), + false, + )); + } + + Ok(()) + } + + pub fn handle_select_index( + original_selection_expr: &Expr, + index: &usize, + current_index_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) -> Result<(), String> { + let expr = temp_stack + .pop_front() + .unwrap_or((original_selection_expr.clone(), false)); + + let new_select_index = + Expr::SelectIndex(Box::new(expr.0.clone()), *index, current_index_type.clone()); + temp_stack.push_front((new_select_index, false)); + + Ok(()) + } + + pub(crate) fn handle_result_ok( + original_ok_expr: &Expr, + current_ok_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let ok_expr = temp_stack + .pop_front() + .unwrap_or((original_ok_expr.clone(), false)); + + let new_result = Expr::Result(Ok(Box::new(ok_expr.0.clone())), current_ok_type.clone()); + temp_stack.push_front((new_result, true)); + } + + pub(crate) fn handle_result_error( + original_error_expr: &Expr, + current_error_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_error_expr.clone()); + + let new_result = Expr::Result(Err(Box::new(expr.clone())), current_error_type.clone()); + + temp_stack.push_front((new_result, false)); + } + + pub(crate) fn handle_option_some( + original_some_expr: &Expr, + current_some_type: &InferredType, + temp_stack: &mut 
VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack + .pop_front() + .unwrap_or((original_some_expr.clone(), false)); + let new_option = Expr::Option(Some(Box::new(expr.0.clone())), current_some_type.clone()); + temp_stack.push_front((new_option, false)); + } + + pub(crate) fn handle_if_else( + original_predicate: &Expr, + original_then_expr: &Expr, + original_else_expr: &Expr, + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let else_expr = temp_stack + .pop_front() + .unwrap_or((original_else_expr.clone(), false)); + let then_expr = temp_stack + .pop_front() + .unwrap_or((original_then_expr.clone(), false)); + let cond_expr = temp_stack + .pop_front() + .unwrap_or((original_predicate.clone(), false)); + + let new_expr = Expr::Cond( + Box::new(cond_expr.0), + Box::new(then_expr.0.clone()), + Box::new(else_expr.0.clone()), + current_inferred_type.clone(), + ); + + temp_stack.push_front((new_expr, false)); + } + + pub fn handle_pattern_match( + predicate: &Expr, + current_match_arms: &[MatchArm], + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_resolutions = vec![]; + let mut new_arm_patterns = vec![]; + for un_inferred_match_arm in current_match_arms.iter().rev() { + let arm_resolution = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(un_inferred_match_arm.arm_resolution_expr.deref().clone()); + + let mut arm_pattern = un_inferred_match_arm.arm_pattern.clone(); + let current_arm_pattern_exprs = arm_pattern.get_expr_literals_mut(); + + let mut new_arm_pattern_exprs = vec![]; + + for _ in &current_arm_pattern_exprs { + let arm_expr = temp_stack.pop_front().map(|x| x.0); + new_arm_pattern_exprs.push(arm_expr) + } + new_arm_pattern_exprs.reverse(); + + new_resolutions.push(arm_resolution); + new_arm_patterns.push(arm_pattern); + } + + let mut new_match_arms = new_arm_patterns + .iter() + .zip(new_resolutions.iter()) + .map(|(arm_pattern, arm_resolution)| crate::MatchArm {
+ arm_pattern: arm_pattern.clone(), + arm_resolution_expr: Box::new(arm_resolution.clone()), + }) + .collect::<Vec<_>>(); + + new_match_arms.reverse(); + + let pred = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(predicate.clone()); + + let new_expr = Expr::PatternMatch( + Box::new(pred.clone()), + new_match_arms, + current_inferred_type.clone(), + ); + + temp_stack.push_front((new_expr, false)); + } + + pub(crate) fn handle_concat(exprs: &Vec<Expr>, temp_stack: &mut VecDeque<(Expr, bool)>) { + let mut new_exprs = vec![]; + for expr in exprs { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + new_exprs.push(expr); + } + + new_exprs.reverse(); + + let new_concat = Expr::Concat(new_exprs, InferredType::Str); + temp_stack.push_front((new_concat, false)); + } + + pub(crate) fn handle_multiple( + current_expr_list: &Vec<Expr>, + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_exprs = vec![]; + for _ in current_expr_list { + let expr = temp_stack.pop_front(); + if let Some(expr) = expr { + new_exprs.push(expr.0); + } else { + break; + } + } + + new_exprs.reverse(); + + let new_multiple = Expr::ExprBlock(new_exprs, current_inferred_type.clone()); + temp_stack.push_front((new_multiple, false)); + } + + pub(crate) fn handle_not( + original_not_expr: &Expr, + current_not_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_not_expr.clone()); + let new_not = Expr::Not(Box::new(expr), current_not_type.clone()); + temp_stack.push_front((new_not, false)); + } + + pub(crate) fn handle_math_op<F>( + original_left_expr: &Expr, + original_right_expr: &Expr, + result_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + f: F, + ) where + F: Fn(Box<Expr>, Box<Expr>, InferredType) -> Expr, + { + let right_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_right_expr.clone()); + let left_expr = temp_stack +
.pop_front() + .map(|x| x.0) + .unwrap_or(original_left_expr.clone()); + + let right_expr_type = right_expr.inferred_type(); + let left_expr_type = left_expr.inferred_type(); + let new_result_type = result_type.merge(right_expr_type).merge(left_expr_type); + + let new_math_op = f( + Box::new(left_expr), + Box::new(right_expr), + new_result_type.clone(), + ); + + temp_stack.push_front((new_math_op, false)); + } + + pub(crate) fn handle_comparison_op<F>( + original_left_expr: &Expr, + original_right_expr: &Expr, + result_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + f: F, + ) where + F: Fn(Box<Expr>, Box<Expr>, InferredType) -> Expr, + { + let right_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_right_expr.clone()); + let left_expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_left_expr.clone()); + + let new_binary = f( + Box::new(left_expr), + Box::new(right_expr), + result_type.clone(), + ); + temp_stack.push_front((new_binary, false)); + } + + pub(crate) fn handle_call( + call_type: &CallType, + arguments: &[Expr], + inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_arg_exprs = vec![]; + + // retrieving all argument from the stack + for expr in arguments.iter().rev() { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + new_arg_exprs.push(expr); + } + + new_arg_exprs.reverse(); + + match call_type { + CallType::Function(fun_name) => { + let mut function_name = fun_name.clone(); + + // The resource params in the function name was also in stack and need to be retrieved back + let resource_params = function_name.function.raw_resource_params_mut(); + + if let Some(resource_params) = resource_params { + let mut new_resource_params = vec![]; + for expr in resource_params.iter().rev() { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + new_resource_params.push(expr); + } + + new_resource_params.reverse(); + + resource_params +
.iter_mut() + .zip(new_resource_params.iter()) + .for_each(|(param, new_expr)| { + *param = new_expr.clone(); + }); + } + + let new_call = Expr::Call( + CallType::Function(function_name), + new_arg_exprs, + inferred_type.clone(), + ); + temp_stack.push_front((new_call, false)); + } + + CallType::VariantConstructor(str) => { + let new_call = Expr::Call( + CallType::VariantConstructor(str.clone()), + new_arg_exprs, + inferred_type.clone(), + ); + temp_stack.push_front((new_call, false)); + } + + CallType::EnumConstructor(str) => { + let new_call = Expr::Call( + CallType::EnumConstructor(str.clone()), + new_arg_exprs, + inferred_type.clone(), + ); + temp_stack.push_front((new_call, false)); + } + } + } + + pub(crate) fn handle_unwrap( + expr: &Expr, + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + let new_unwrap = Expr::Unwrap(Box::new(expr.clone()), current_inferred_type.clone()); + temp_stack.push_front((new_unwrap, false)); + } + + pub(crate) fn handle_get_tag( + expr: &Expr, + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + let new_get_tag = Expr::GetTag(Box::new(expr.clone()), current_inferred_type.clone()); + temp_stack.push_front((new_get_tag, false)); + } + + pub(crate) fn handle_let( + original_variable_id: &VariableId, + original_expr: &Expr, + optional_type: &Option<TypeName>, + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(original_expr.clone()); + let new_let = Expr::Let( + original_variable_id.clone(), + optional_type.clone(), + Box::new(expr), + current_inferred_type.clone(), + ); + temp_stack.push_front((new_let, false)); + } + + pub(crate) fn handle_sequence( + current_expr_list: &[Expr], + current_inferred_type: &InferredType,
+ temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_exprs = vec![]; + + for expr in current_expr_list.iter().rev() { + let expr = temp_stack.pop_front().map(|x| x.0).unwrap_or(expr.clone()); + new_exprs.push(expr); + } + + new_exprs.reverse(); + + let expr = Expr::Sequence(new_exprs, current_inferred_type.clone()); + + temp_stack.push_front((expr, false)); + } + + pub(crate) fn handle_record( + current_expr_list: &[(String, Box<Expr>)], + current_inferred_type: &InferredType, + temp_stack: &mut VecDeque<(Expr, bool)>, + ) { + let mut new_exprs = vec![]; + + for (field, expr) in current_expr_list.iter().rev() { + let expr: Expr = temp_stack + .pop_front() + .map(|x| x.0) + .unwrap_or(expr.deref().clone()); + new_exprs.push((field.clone(), Box::new(expr.clone()))); + } + + new_exprs.reverse(); + + let new_record = Expr::Record(new_exprs.to_vec(), current_inferred_type.clone()); + temp_stack.push_front((new_record, false)); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FunctionTypeRegistry, Id, TypeName}; + use test_r::test; + + #[test] + fn test_override_types_1() { + let expr = Expr::from_text( + r#" + foo + "#, + ) + .unwrap(); + + let type_default = TypeDefault { + variable_id: VariableId::global("foo".to_string()), + path: Path::default(), + inferred_type: InferredType::Str, + }; + + let result = expr.override_types(&type_default).unwrap(); + + let expected = Expr::Identifier(VariableId::global("foo".to_string()), InferredType::Str); + + assert_eq!(result, expected); + } + + // Be able to + #[test] + fn test_override_types_2() { + let expr = Expr::from_text( + r#" + foo.bar.baz + "#, + ) + .unwrap(); + + let type_default = TypeDefault { + variable_id: VariableId::global("foo".to_string()), + path: Path::from_elems(vec!["bar"]), + inferred_type: InferredType::Str, + }; + + let result = expr.override_types(&type_default).unwrap(); + + let expected = Expr::SelectField( + Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), +
"baz".to_string(), + InferredType::Str, + ); + + assert_eq!(result, expected); + } + + #[test] + fn test_override_types_3() { + let expr = Expr::from_text( + r#" + foo.bar.baz + "#, + ) + .unwrap(); + + let type_default = TypeDefault { + variable_id: VariableId::global("foo".to_string()), + path: Path::from_elems(vec!["bar", "baz"]), + inferred_type: InferredType::Str, + }; + + let result = expr.override_types(&type_default).unwrap(); + + let expected = Expr::SelectField( + Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), + "baz".to_string(), + InferredType::Str, + ); + + assert_eq!(result, expected); + } + + #[test] + fn test_override_types_4() { + let expr = Expr::from_text( + r#" + foo.bar.baz + "#, + ) + .unwrap(); + + let type_default = TypeDefault { + variable_id: VariableId::global("foo".to_string()), + path: Path::default(), + inferred_type: InferredType::Str, + }; + + let result = expr.override_types(&type_default).unwrap(); + + let expected = Expr::SelectField( + Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), + "baz".to_string(), + InferredType::Str, + ); + + assert_eq!(result, expected); + } + + #[test] + fn test_override_types_5() { + let mut expr = Expr::from_text( + r#" + let res = foo.bar.user-id; + let hello: u64 = foo.bar.number; + hello + "#, + ) + .unwrap(); + + let type_default = TypeDefault { + variable_id: VariableId::global("foo".to_string()), + path: Path::from_elems(vec!["bar"]), + inferred_type: InferredType::Str, + }; + + expr.infer_types(&FunctionTypeRegistry::empty(), Some(&type_default)) + .unwrap(); + + let expected = Expr::ExprBlock( + vec![ + Expr::Let( + VariableId::Local("res".to_string(), Some(Id(0))), + None, + Box::new(Expr::SelectField( + Box::new(Expr::SelectField( + Box::new(Expr::Identifier( + VariableId::Global("foo".to_string()), + InferredType::Record(vec![( + "bar".to_string(), + InferredType::Record(vec![ + ("number".to_string(), InferredType::U64), + ("user-id".to_string(), 
InferredType::Str), + ]), + )]), + )), + "bar".to_string(), + InferredType::Record(vec![ + ("number".to_string(), InferredType::U64), + ("user-id".to_string(), InferredType::Str), + ]), + )), + "user-id".to_string(), + InferredType::Str, + )), + InferredType::Unknown, + ), + Expr::Let( + VariableId::Local("hello".to_string(), Some(Id(0))), + Some(TypeName::U64), + Box::new(Expr::SelectField( + Box::new(Expr::SelectField( + Box::new(Expr::Identifier( + VariableId::Global("foo".to_string()), + InferredType::Record(vec![( + "bar".to_string(), + InferredType::Record(vec![ + ("number".to_string(), InferredType::U64), + ("user-id".to_string(), InferredType::Str), + ]), + )]), + )), + "bar".to_string(), + InferredType::Record(vec![ + ("number".to_string(), InferredType::U64), + ("user-id".to_string(), InferredType::Str), + ]), + )), + "number".to_string(), + InferredType::U64, + )), + InferredType::Unknown, + ), + Expr::Identifier( + VariableId::Local("hello".to_string(), Some(Id(0))), + InferredType::U64, + ), + ], + InferredType::U64, + ); + + assert_eq!(expr, expected); + } +} diff --git a/golem-rib/src/type_inference/identifier_inference.rs b/golem-rib/src/type_inference/identifier_inference.rs index 148165ec46..7f61281f51 100644 --- a/golem-rib/src/type_inference/identifier_inference.rs +++ b/golem-rib/src/type_inference/identifier_inference.rs @@ -15,7 +15,7 @@ use crate::Expr; pub fn infer_all_identifiers(expr: &mut Expr) -> Result<(), String> { - // We scan top-down and bottom-up to inform the type info between the identifiers + // We scan top-down and bottom-up to inform the type between the identifiers // It doesn't matter which order we do it in (i.e, which identifier expression has the right type isn't a problem), // as we accumulate all the types in both directions internal::infer_all_identifiers_bottom_up(expr)?; diff --git a/golem-rib/src/type_inference/inference_fix_point.rs b/golem-rib/src/type_inference/inference_fix_point.rs index 0704a177e5..1e41fdad5d 
100644 --- a/golem-rib/src/type_inference/inference_fix_point.rs +++ b/golem-rib/src/type_inference/inference_fix_point.rs @@ -389,7 +389,8 @@ mod tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ Expr::Let( diff --git a/golem-rib/src/type_inference/inferred_expr.rs b/golem-rib/src/type_inference/inferred_expr.rs index 38acecd6f4..15b92a0aa4 100644 --- a/golem-rib/src/type_inference/inferred_expr.rs +++ b/golem-rib/src/type_inference/inferred_expr.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::call_type::CallType; -use crate::{DynamicParsedFunctionName, Expr, FunctionTypeRegistry, RegistryKey}; +use crate::{DynamicParsedFunctionName, Expr, FunctionTypeRegistry, RegistryKey, TypeDefault}; use std::collections::{HashSet, VecDeque}; #[derive(Debug, Clone)] @@ -27,10 +27,12 @@ impl InferredExpr { pub fn from_expr( expr: &Expr, function_type_registry: &FunctionTypeRegistry, + type_default: Option<&TypeDefault>, ) -> Result { let mut mutable_expr = expr.clone(); + mutable_expr - .infer_types(function_type_registry) + .infer_types(function_type_registry, type_default) .map_err(|err| err.join("\n"))?; Ok(InferredExpr(mutable_expr)) } diff --git a/golem-rib/src/type_inference/mod.rs b/golem-rib/src/type_inference/mod.rs index 67e9068b28..b0de3cb7cd 100644 --- a/golem-rib/src/type_inference/mod.rs +++ b/golem-rib/src/type_inference/mod.rs @@ -16,6 +16,7 @@ pub use call_arguments_inference::*; pub use enum_resolution::*; pub use expr_visitor::*; pub use global_input_inference::*; +pub use global_variable_type_default::*; pub use identifier_inference::*; pub use inference_fix_point::*; pub use inferred_expr::*; @@ -46,6 +47,7 @@ mod variant_resolution; mod enum_resolution; mod global_input_inference; +mod global_variable_type_default; mod inference_fix_point; mod inferred_expr; 
pub(crate) mod kind; @@ -57,6 +59,34 @@ mod variable_binding_list_reduce; #[cfg(test)] mod type_inference_tests { + mod global_variable { + use crate::type_checker::Path; + use crate::type_inference::global_variable_type_default::TypeDefault; + use crate::{Expr, FunctionTypeRegistry, InferredType, VariableId}; + use test_r::test; + + #[test] + fn test_global_variable_inference() { + let rib_expr = r#" + let res = request.path.user-id; + let hello: u64 = request.path.number; + let y: u64 = res; + hello + "#; + + let mut expr = Expr::from_text(rib_expr).unwrap(); + let type_default = TypeDefault { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["path"]), + inferred_type: InferredType::Str, + }; + + assert!(expr + .infer_types(&FunctionTypeRegistry::empty(), Some(&type_default)) + .is_ok()); + } + } + mod let_binding_tests { use bigdecimal::BigDecimal; use test_r::test; @@ -77,7 +107,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let let_binding = Expr::Let( VariableId::local("x", 0), @@ -125,7 +155,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let let_binding1 = Expr::Let( VariableId::local("x", 0), @@ -207,7 +237,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -241,7 +271,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + 
expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -281,7 +311,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -434,7 +464,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = internal::expected_expr_for_enum_test(); @@ -511,7 +541,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr).unwrap(); - let result = expr.infer_types(&function_type_registry); + let result = expr.infer_types(&function_type_registry, None); assert!(result.is_ok()); } } @@ -532,7 +562,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -578,7 +608,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -616,7 +646,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -704,7 +734,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); 
let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -743,7 +773,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -803,7 +833,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -867,7 +897,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -936,7 +966,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1000,7 +1030,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1056,7 +1086,8 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + 
expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1146,7 +1177,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let let_binding1 = Expr::Let( VariableId::local("x", 0), @@ -1251,7 +1282,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - let result = expr.infer_types(&function_type_registry); + let result = expr.infer_types(&function_type_registry, None); assert!(result.is_ok()); } @@ -1268,7 +1299,8 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1365,7 +1397,8 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1441,7 +1474,8 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1522,7 +1556,8 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty()).unwrap(); + expr.infer_types(&FunctionTypeRegistry::empty(), None) + .unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1674,7 +1709,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + 
expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1715,7 +1750,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1781,7 +1816,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1836,7 +1871,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1947,7 +1982,7 @@ mod type_inference_tests { let function_type_registry = FunctionTypeRegistry::from_export_metadata(&component_metadata); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = internal::expected_expr_for_select_index(); @@ -1973,7 +2008,7 @@ mod type_inference_tests { let expr = Expr::from_text(rib_expr).unwrap(); let inferred_expr = - InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty()).unwrap(); + InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -2052,7 +2087,7 @@ mod type_inference_tests { let expr = Expr::from_text(rib_expr).unwrap(); let inferred_expr = - InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty()).unwrap(); + InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), None).unwrap(); let expected = Expr::ExprBlock( vec![ @@ 
-2104,7 +2139,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry).unwrap(); + expr.infer_types(&function_type_registry, None).unwrap(); let expected = Expr::ExprBlock( vec![ diff --git a/golem-rib/src/type_inference/type_pull_up.rs b/golem-rib/src/type_inference/type_pull_up.rs index a31ca74d2a..94faad81e4 100644 --- a/golem-rib/src/type_inference/type_pull_up.rs +++ b/golem-rib/src/type_inference/type_pull_up.rs @@ -1273,7 +1273,8 @@ mod type_pull_up_tests { let mut expr = Expr::from_text(rib).unwrap(); let function_registry = FunctionTypeRegistry::empty(); - expr.infer_types_initial_phase(&function_registry).unwrap(); + expr.infer_types_initial_phase(&function_registry, None) + .unwrap(); expr.infer_all_identifiers().unwrap(); let new_expr = expr.pull_types_up().unwrap(); diff --git a/golem-rib/src/type_inference/type_unification.rs b/golem-rib/src/type_inference/type_unification.rs index 61ccfe5747..40146494e1 100644 --- a/golem-rib/src/type_inference/type_unification.rs +++ b/golem-rib/src/type_inference/type_unification.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::{ArmPattern, Expr}; +use crate::{text, ArmPattern, Expr}; pub fn unify_types(expr: &mut Expr) -> Result<(), Vec<String>> { let mut queue = vec![]; @@ -299,8 +299,8 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec<String>> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of identifier {:?}", - expr + "Unable to resolve the type of identifier {}", + text::to_string(expr).unwrap() )); errors.push(e); } diff --git a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs index c3377fd6c0..fe5e46d3d7 100644 --- a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs +++ b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use golem_wasm_ast::analysis::AnalysedExport; -use rib::{CompilerOutput, Expr}; +use rib::{CompilerOutput, Expr, InferredType, Path, TypeDefault, VariableId}; // A wrapper service over original Rib Compiler concerning // the details of the worker bridge.
@@ -25,10 +25,15 @@ pub struct DefaultWorkerServiceRibCompiler; impl WorkerServiceRibCompiler for DefaultWorkerServiceRibCompiler { fn compile(rib: &Expr, export_metadata: &[AnalysedExport]) -> Result { - rib::compile_with_limited_globals( + rib::compile_with_restricted_global_variables( rib, &export_metadata.to_vec(), Some(vec!["request".to_string()]), + Some(TypeDefault { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["path"]), + inferred_type: InferredType::Str + }) ) } } From 181b7bb93261e496c2571b9b6958d1b1a3035dea Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 13:23:55 +0530 Subject: [PATCH 2/8] Remove automatic coercison --- golem-rib/src/inferred_type/mod.rs | 11 +- golem-rib/src/inferred_type/unification.rs | 147 ++++++++++++++++- .../src/inferred_type/unification_result.rs | 10 -- golem-rib/src/inferred_type/validation.rs | 150 ------------------ golem-rib/src/parser/type_name.rs | 82 ++++++++++ golem-rib/src/type_inference/mod.rs | 1 - .../src/type_inference/type_unification.rs | 113 ++++++------- golem-worker-service-base/src/headers.rs | 4 +- 8 files changed, 281 insertions(+), 237 deletions(-) delete mode 100644 golem-rib/src/inferred_type/unification_result.rs delete mode 100644 golem-rib/src/inferred_type/validation.rs diff --git a/golem-rib/src/inferred_type/mod.rs b/golem-rib/src/inferred_type/mod.rs index 7353c11016..504ff28fbf 100644 --- a/golem-rib/src/inferred_type/mod.rs +++ b/golem-rib/src/inferred_type/mod.rs @@ -13,17 +13,12 @@ // limitations under the License. 
pub(crate) use flatten::*; -pub(crate) use unification_result::*; -pub(crate) use validation::*; mod flatten; mod unification; -mod unification_result; -mod validation; - -use std::collections::HashSet; - +use crate::TypeName; use bincode::{Decode, Encode}; use golem_wasm_ast::analysis::*; +use std::collections::HashSet; #[derive(Debug, Hash, Clone, Eq, PartialEq, PartialOrd, Ord, Encode, Decode)] pub enum InferredType { @@ -176,7 +171,7 @@ impl InferredType { } pub fn unify(&self) -> Result { - unification::unify(self) + unification::unify(self).map(|unified| unified.inferred_type()) } pub fn unify_all_alternative_types(types: &Vec) -> InferredType { diff --git a/golem-rib/src/inferred_type/unification.rs b/golem-rib/src/inferred_type/unification.rs index c841285a72..946cb8e47c 100644 --- a/golem-rib/src/inferred_type/unification.rs +++ b/golem-rib/src/inferred_type/unification.rs @@ -1,11 +1,19 @@ -use crate::inferred_type::{flatten_all_of_list, flatten_one_of_list, validate_unified_type}; +use crate::inferred_type::{flatten_all_of_list, flatten_one_of_list}; use crate::InferredType; use std::collections::{HashMap, HashSet}; -pub fn unify(inferred_type: &InferredType) -> Result { +pub struct Unified(InferredType); + +impl Unified { + pub fn inferred_type(&self) -> InferredType { + self.0.clone() + } +} + +pub fn unify(inferred_type: &InferredType) -> Result { let possibly_unified_type = try_unify_type(inferred_type)?; - validate_unified_type(&possibly_unified_type).map(|unified| unified.inferred_type()) + internal::validate_unified_type(&possibly_unified_type) } pub fn try_unify_type(inferred_type: &InferredType) -> Result { @@ -597,7 +605,8 @@ pub fn unify_with_required( } mod internal { - use crate::InferredType; + use crate::inferred_type::unification::Unified; + use crate::{InferredType, TypeName}; use std::collections::HashMap; pub(crate) fn sort_and_convert( @@ -607,4 +616,134 @@ mod internal { vec.sort_by(|a, b| a.0.cmp(&b.0)); vec } + + pub(crate) fn 
validate_unified_type(inferred_type: &InferredType) -> Result { + match inferred_type { + InferredType::Bool => Ok(Unified(InferredType::Bool)), + InferredType::S8 => Ok(Unified(InferredType::S8)), + InferredType::U8 => Ok(Unified(InferredType::U8)), + InferredType::S16 => Ok(Unified(InferredType::S16)), + InferredType::U16 => Ok(Unified(InferredType::U16)), + InferredType::S32 => Ok(Unified(InferredType::S32)), + InferredType::U32 => Ok(Unified(InferredType::U32)), + InferredType::S64 => Ok(Unified(InferredType::S64)), + InferredType::U64 => Ok(Unified(InferredType::U64)), + InferredType::F32 => Ok(Unified(InferredType::F32)), + InferredType::F64 => Ok(Unified(InferredType::F64)), + InferredType::Chr => Ok(Unified(InferredType::Chr)), + InferredType::Str => Ok(Unified(InferredType::Str)), + InferredType::List(inferred_type) => { + let verified = validate_unified_type(inferred_type)?; + Ok(Unified(InferredType::List(Box::new( + verified.inferred_type(), + )))) + } + InferredType::Tuple(types) => { + let mut verified_types = vec![]; + + for typ in types { + let verified = validate_unified_type(typ)?; + verified_types.push(verified.inferred_type()); + } + + Ok(Unified(InferredType::Tuple(verified_types))) + } + InferredType::Record(field) => { + for (field, typ) in field { + if let Err(unresolved) = validate_unified_type(typ) { + return Err(format!( + "Un-inferred type for field {} in record: {}", + field, unresolved + )); + } + } + + Ok(Unified(InferredType::Record(field.clone()))) + } + InferredType::Flags(flags) => Ok(Unified(InferredType::Flags(flags.clone()))), + InferredType::Enum(enums) => Ok(Unified(InferredType::Enum(enums.clone()))), + InferredType::Option(inferred_type) => { + let result = validate_unified_type(inferred_type)?; + Ok(Unified(InferredType::Option(Box::new( + result.inferred_type(), + )))) + } + result @ InferredType::Result { ok, error } => { + // For Result, we try to be flexible with types + // Example: Allow Rib script to simply return 
ok(x) as the final output, even if it doesn't know anything about error + match (ok, error) { + (Some(ok), Some(err)) => { + let ok_unified = validate_unified_type(ok); + let err_unified = validate_unified_type(err); + + match (ok_unified, err_unified) { + // We fail only if both are unknown + (Err(ok_err), Err(err_err)) => { + let err = format!("Ok: {}, Error: {}", ok_err, err_err); + Err(err) + } + (_, _) => Ok(Unified(result.clone())), + } + } + + (Some(ok), None) => { + let ok_unified = validate_unified_type(ok); + match ok_unified { + Err(ok_err) => Err(ok_err), + _ => Ok(Unified(result.clone())), + } + } + + (None, Some(err)) => { + let err_unified = validate_unified_type(err); + match err_unified { + Err(err_err) => Err(err_err), + _ => Ok(Unified(result.clone())), + } + } + + (None, None) => Ok(Unified(result.clone())), + } + } + inferred_type @ InferredType::Variant(variant) => { + for (_, typ) in variant { + if let Some(typ) = typ { + validate_unified_type(typ)?; + } + } + Ok(Unified(inferred_type.clone())) + } + resource @ InferredType::Resource { .. 
} => Ok(Unified(resource.clone())), + InferredType::OneOf(possibilities) => Err(format!( + "Conflicting types: {}", + display_multiple_types(possibilities) + )), + InferredType::AllOf(possibilities) => Err(format!( + "Conflicting types: {}", + display_multiple_types(possibilities) + )), + + InferredType::Unknown => Err("Unresolved types".to_string()), + inferred_type @ InferredType::Sequence(inferred_types) => { + for typ in inferred_types { + validate_unified_type(typ)?; + } + + Ok(Unified(inferred_type.clone())) + } + } + } + + fn display_multiple_types(types: &Vec) -> String { + let types = types + .iter() + .map(|x| { + TypeName::try_from(x.clone()) + .map(|x| x.to_string()) + .unwrap_or(format!("{:?}", x)) + }) + .collect::>(); + + types.join(", ") + } } diff --git a/golem-rib/src/inferred_type/unification_result.rs b/golem-rib/src/inferred_type/unification_result.rs deleted file mode 100644 index bbeb1a7c15..0000000000 --- a/golem-rib/src/inferred_type/unification_result.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::InferredType; - -pub type UnificationResult = Result; -pub struct Unified(pub InferredType); - -impl Unified { - pub fn inferred_type(&self) -> InferredType { - self.0.clone() - } -} diff --git a/golem-rib/src/inferred_type/validation.rs b/golem-rib/src/inferred_type/validation.rs deleted file mode 100644 index ef8794596b..0000000000 --- a/golem-rib/src/inferred_type/validation.rs +++ /dev/null @@ -1,150 +0,0 @@ -use crate::inferred_type::{UnificationResult, Unified}; -use crate::InferredType; - -pub fn validate_unified_type(inferred_type: &InferredType) -> UnificationResult { - match inferred_type { - InferredType::Bool => Ok(Unified(InferredType::Bool)), - InferredType::S8 => Ok(Unified(InferredType::S8)), - InferredType::U8 => Ok(Unified(InferredType::U8)), - InferredType::S16 => Ok(Unified(InferredType::S16)), - InferredType::U16 => Ok(Unified(InferredType::U16)), - InferredType::S32 => Ok(Unified(InferredType::S32)), - InferredType::U32 => 
Ok(Unified(InferredType::U32)), - InferredType::S64 => Ok(Unified(InferredType::S64)), - InferredType::U64 => Ok(Unified(InferredType::U64)), - InferredType::F32 => Ok(Unified(InferredType::F32)), - InferredType::F64 => Ok(Unified(InferredType::F64)), - InferredType::Chr => Ok(Unified(InferredType::Chr)), - InferredType::Str => Ok(Unified(InferredType::Str)), - InferredType::List(inferred_type) => { - let verified = validate_unified_type(inferred_type)?; - Ok(Unified(InferredType::List(Box::new( - verified.inferred_type(), - )))) - } - InferredType::Tuple(types) => { - let mut verified_types = vec![]; - - for typ in types { - let verified = validate_unified_type(typ)?; - verified_types.push(verified.inferred_type()); - } - - Ok(Unified(InferredType::Tuple(verified_types))) - } - InferredType::Record(field) => { - for (field, typ) in field { - if let Err(unresolved) = validate_unified_type(typ) { - return Err(format!( - "Un-inferred type for field {} in record: {}", - field, unresolved - )); - } - } - - Ok(Unified(InferredType::Record(field.clone()))) - } - InferredType::Flags(flags) => Ok(Unified(InferredType::Flags(flags.clone()))), - InferredType::Enum(enums) => Ok(Unified(InferredType::Enum(enums.clone()))), - InferredType::Option(inferred_type) => { - let result = validate_unified_type(inferred_type)?; - Ok(Unified(InferredType::Option(Box::new( - result.inferred_type(), - )))) - } - result @ InferredType::Result { ok, error } => { - // For Result, we try to be flexible with types - // Example: Allow Rib script to simply return ok(x) as the final output, even if it doesn't know anything about error - match (ok, error) { - (Some(ok), Some(err)) => { - let ok_unified = validate_unified_type(ok); - let err_unified = validate_unified_type(err); - - match (ok_unified, err_unified) { - // We fail only if both are unknown - (Err(ok_err), Err(err_err)) => { - let err = format!("Ok: {}, Error: {}", ok_err, err_err); - Err(err) - } - (_, _) => 
Ok(Unified(result.clone())), - } - } - - (Some(ok), None) => { - let ok_unified = validate_unified_type(ok); - match ok_unified { - Err(ok_err) => Err(ok_err), - _ => Ok(Unified(result.clone())), - } - } - - (None, Some(err)) => { - let err_unified = validate_unified_type(err); - match err_unified { - Err(err_err) => Err(err_err), - _ => Ok(Unified(result.clone())), - } - } - - (None, None) => Ok(Unified(result.clone())), - } - } - inferred_type @ InferredType::Variant(variant) => { - for (_, typ) in variant { - if let Some(typ) = typ { - validate_unified_type(typ)?; - } - } - Ok(Unified(inferred_type.clone())) - } - resource @ InferredType::Resource { .. } => Ok(Unified(resource.clone())), - InferredType::OneOf(possibilities) => Err(format!("Cannot resolve {:?}", possibilities)), - InferredType::AllOf(possibilities) => coerce_to_numerical_type(possibilities).map(Unified), - InferredType::Unknown => Err("Unknown".to_string()), - inferred_type @ InferredType::Sequence(inferred_types) => { - for typ in inferred_types { - validate_unified_type(typ)?; - } - - Ok(Unified(inferred_type.clone())) - } - } -} - -fn coerce_to_numerical_type(possibilities: &Vec) -> Result { - let mut number_type: Option = None; - let mut has_string = false; - - for ty in possibilities { - match ty { - InferredType::Str => has_string = true, - InferredType::U8 - | InferredType::U16 - | InferredType::U32 - | InferredType::U64 - | InferredType::S8 - | InferredType::S16 - | InferredType::S32 - | InferredType::S64 - | InferredType::F32 - | InferredType::F64 => { - if let Some(existing) = &number_type { - if existing != ty { - return Err("Cannot coerce to a mixed number type".to_string()); - } - } else { - number_type = Some(ty.clone()); - } - } - _ => return Err(format!("Cannot be all of {:?}", possibilities)), - } - } - - if let Some(num) = number_type { - Ok(num) - } else if has_string { - Ok(InferredType::Str) - } else { - Err("No valid number or string type found".to_string()) - } -} diff 
--git a/golem-rib/src/parser/type_name.rs b/golem-rib/src/parser/type_name.rs index 431f164982..f439c076e6 100644 --- a/golem-rib/src/parser/type_name.rs +++ b/golem-rib/src/parser/type_name.rs @@ -22,6 +22,7 @@ use combine::parser::choice::choice; use combine::{attempt, between, sep_by, Parser}; use combine::{parser, ParseError}; use golem_wasm_ast::analysis::{AnalysedType, TypeResult}; +use poem_openapi::types::Type; use crate::parser::errors::RibParseError; use crate::InferredType; @@ -280,6 +281,87 @@ impl From for InferredType { } } +impl TryFrom for TypeName { + type Error = String; + + fn try_from(value: InferredType) -> Result { + match value { + InferredType::Bool => Ok(TypeName::Bool), + InferredType::S8 => Ok(TypeName::S8), + InferredType::U8 => Ok(TypeName::U8), + InferredType::S16 => Ok(TypeName::S16), + InferredType::U16 => Ok(TypeName::U16), + InferredType::S32 => Ok(TypeName::S32), + InferredType::U32 => Ok(TypeName::U32), + InferredType::S64 => Ok(TypeName::S64), + InferredType::U64 => Ok(TypeName::U64), + InferredType::F32 => Ok(TypeName::F32), + InferredType::F64 => Ok(TypeName::F64), + InferredType::Chr => Ok(TypeName::Chr), + InferredType::Str => Ok(TypeName::Str), + InferredType::List(inferred_type) => { + let verified = inferred_type.deref().clone().try_into()?; + Ok(TypeName::List(Box::new(verified))) + } + InferredType::Tuple(inferred_types) => { + let mut verified_types = vec![]; + for typ in inferred_types { + let verified = typ.try_into()?; + verified_types.push(verified); + } + Ok(TypeName::Tuple(verified_types)) + } + InferredType::Record(name_and_types) => { + let mut fields = vec![]; + for (field, typ) in name_and_types { + let verified = typ.try_into()?; + fields.push((field, Box::new(verified))); + } + Ok(TypeName::Record(fields)) + } + InferredType::Flags(flags) => Ok(TypeName::Flags(flags)), + InferredType::Enum(enums) => Ok(TypeName::Enum(enums)), + InferredType::Option(inferred_type) => { + let result = 
inferred_type.deref().clone().try_into()?; + Ok(TypeName::Option(Box::new(result))) + } + InferredType::Result { ok, error } => { + let ok_unified = ok.map(|ok| ok.deref().clone().try_into()).transpose()?; + let err_unified = error + .map(|err| err.deref().clone().try_into()) + .transpose()?; + Ok(TypeName::Result { + ok: ok_unified.map(Box::new), + error: err_unified.map(Box::new), + }) + } + InferredType::Variant(variant) => { + let mut cases = vec![]; + for (case, typ) in variant { + let verified = typ.map(|x| TypeName::try_from(x)).transpose()?; + cases.push((case, verified.map(Box::new))); + } + Ok(TypeName::Variant { cases }) + } + InferredType::Resource { .. } => { + Err("Cannot convert a resource type to a type name".to_string()) + } + InferredType::OneOf(_) => { + Err("Cannot convert a one of type to a type name".to_string()) + } + InferredType::AllOf(_) => { + Err("Cannot convert a all of type to a type name".to_string()) + } + InferredType::Unknown => { + Err("Cannot convert an unknown type to a type name".to_string()) + } + InferredType::Sequence(_) => { + Err("Cannot convert a sequence type to a type name".to_string()) + } + } + } +} + pub fn parse_basic_type() -> impl Parser where Input: combine::Stream, diff --git a/golem-rib/src/type_inference/mod.rs b/golem-rib/src/type_inference/mod.rs index b0de3cb7cd..df6fbc4272 100644 --- a/golem-rib/src/type_inference/mod.rs +++ b/golem-rib/src/type_inference/mod.rs @@ -70,7 +70,6 @@ mod type_inference_tests { let rib_expr = r#" let res = request.path.user-id; let hello: u64 = request.path.number; - let y: u64 = res; hello "#; diff --git a/golem-rib/src/type_inference/type_unification.rs b/golem-rib/src/type_inference/type_unification.rs index 40146494e1..04b9614bdd 100644 --- a/golem-rib/src/type_inference/type_unification.rs +++ b/golem-rib/src/type_inference/type_unification.rs @@ -29,7 +29,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => 
*inferred_type = unified_type, Err(e) => { - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -42,8 +42,10 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of record {}", expr_str)); - errors.push(e); + errors.push(format!( + "unable to infer the type of record {}, {}", + expr_str, e + )); } } } @@ -55,8 +57,10 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of tuple {}", expr_str)); - errors.push(e); + errors.push(format!( + "unable to infer the type of tuple {}, {}", + expr_str, e + )); } } } @@ -68,10 +72,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of sequence {}", - expr_str + "unable to infer the type of sequence {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -82,8 +85,10 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of option {}", expr_str)); - errors.push(e); + errors.push(format!( + "unable to infer the type of option {}, {}", + expr_str, e + )); } } } @@ -94,8 +99,10 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of option {}", expr_str)); - errors.push(e); + errors.push(format!( + "unable to infer the type of option {}, {}", + expr_str, e + )); } } } @@ -108,10 +115,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => 
*inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of `result::ok` {}", - expr_str + "unable to infer the type of `result::ok` {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -124,10 +130,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of `result::err` {}", - expr_str + "unable to infer the type of `result::err` {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -142,10 +147,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of condition expression {}", - expr_str + "unable to infer the type of condition expression {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -165,11 +169,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of list comprehension {}", - expr_str + "unable to infer the type of list comprehension {}, {}", + expr_str, e )); - - errors.push(e) } } } @@ -191,11 +193,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of list aggregation {}", - expr_str + "unable to infer the type of list aggregation {}, {}", + expr_str, e )); - - errors.push(e) } } } @@ -214,10 +214,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of pattern match expression {}", - expr_str + "unable to infer the type of pattern match expression {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -230,10 +229,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = 
unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of function return {}", - function_call + "unable to infer the type of function return {}, {}", + function_call, e )); - errors.push(e); } } } @@ -245,10 +243,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of field selection {}", - expr_str + "unable to infer the type of field selection {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -260,10 +257,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { errors.push(format!( - "Unable to resolve the type of index selection {}", - expr_str + "unable to infer the type of index selection {}, {}", + expr_str, e )); - errors.push(e); } } } @@ -277,7 +273,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -287,8 +283,10 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of flags {}", expr_str)); - errors.push(e); + errors.push(format!( + "unable to infer the type of flags {}, {}", + expr_str, e + )); } } } @@ -298,11 +296,9 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!( - "Unable to resolve the type of identifier {}", - text::to_string(expr).unwrap() - )); - errors.push(e); + let expr_str = expr.to_string(); + let error = format!("unable to infer the type of {}, {}", expr_str, e); + errors.push(error); } } } @@ -317,9 +313,7 @@ pub fn unify_types(expr: &mut Expr) -> 
Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, - Err(e) => { - errors.push(e); - } + Err(_) => {} } } @@ -330,8 +324,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of {}", expr_str)); - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -342,8 +335,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of {}", expr_str)); - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -354,8 +346,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of {}", expr_str)); - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -366,8 +357,7 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of {}", expr_str)); - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } @@ -466,8 +456,7 @@ mod internal { match unified_inferred_type { Ok(unified_type) => *inferred_type = unified_type, Err(e) => { - errors.push(format!("Unable to resolve the type of {}", expr_str)); - errors.push(e); + errors.push(format!("unable to infer the type of {}, {}", expr_str, e)); } } } diff --git a/golem-worker-service-base/src/headers.rs b/golem-worker-service-base/src/headers.rs index 4f490d534a..ed17769116 100644 --- a/golem-worker-service-base/src/headers.rs +++ 
b/golem-worker-service-base/src/headers.rs @@ -39,7 +39,7 @@ impl ResolvedResponseHeaders { let value_str = value .get_literal() .map(|primitive| primitive.to_string()) - .unwrap_or_else(|| "Unable to resolve header".to_string()); + .unwrap_or_else(|| "unable to infer header".to_string()); resolved_headers.insert(field_def.name, value_str); } @@ -47,7 +47,7 @@ impl ResolvedResponseHeaders { let headers = (&resolved_headers) .try_into() .map_err(|e: http::Error| e.to_string()) - .map_err(|e| format!("Unable to resolve valid headers. Error: {e}"))?; + .map_err(|e| format!("unable to infer valid headers. Error: {e}"))?; Ok(ResolvedResponseHeaders { headers }) } From 1279382858a843072b9e3c05b1145e7773600c11 Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 13:26:29 +0530 Subject: [PATCH 3/8] Remove automatic coercison --- golem-rib/src/inferred_type/mod.rs | 1 - golem-rib/src/inferred_type/unification.rs | 2 +- golem-rib/src/parser/type_name.rs | 3 +-- golem-rib/src/type_inference/type_unification.rs | 7 +++---- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/golem-rib/src/inferred_type/mod.rs b/golem-rib/src/inferred_type/mod.rs index 504ff28fbf..680d3c8d82 100644 --- a/golem-rib/src/inferred_type/mod.rs +++ b/golem-rib/src/inferred_type/mod.rs @@ -15,7 +15,6 @@ pub(crate) use flatten::*; mod flatten; mod unification; -use crate::TypeName; use bincode::{Decode, Encode}; use golem_wasm_ast::analysis::*; use std::collections::HashSet; diff --git a/golem-rib/src/inferred_type/unification.rs b/golem-rib/src/inferred_type/unification.rs index 946cb8e47c..7ae851d4e9 100644 --- a/golem-rib/src/inferred_type/unification.rs +++ b/golem-rib/src/inferred_type/unification.rs @@ -734,7 +734,7 @@ mod internal { } } - fn display_multiple_types(types: &Vec) -> String { + fn display_multiple_types(types: &[InferredType]) -> String { let types = types .iter() .map(|x| { diff --git a/golem-rib/src/parser/type_name.rs b/golem-rib/src/parser/type_name.rs index 
f439c076e6..268901840c 100644 --- a/golem-rib/src/parser/type_name.rs +++ b/golem-rib/src/parser/type_name.rs @@ -22,7 +22,6 @@ use combine::parser::choice::choice; use combine::{attempt, between, sep_by, Parser}; use combine::{parser, ParseError}; use golem_wasm_ast::analysis::{AnalysedType, TypeResult}; -use poem_openapi::types::Type; use crate::parser::errors::RibParseError; use crate::InferredType; @@ -338,7 +337,7 @@ impl TryFrom for TypeName { InferredType::Variant(variant) => { let mut cases = vec![]; for (case, typ) in variant { - let verified = typ.map(|x| TypeName::try_from(x)).transpose()?; + let verified = typ.map(TypeName::try_from).transpose()?; cases.push((case, verified.map(Box::new))); } Ok(TypeName::Variant { cases }) diff --git a/golem-rib/src/type_inference/type_unification.rs b/golem-rib/src/type_inference/type_unification.rs index 04b9614bdd..6e77e69cf1 100644 --- a/golem-rib/src/type_inference/type_unification.rs +++ b/golem-rib/src/type_inference/type_unification.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::{text, ArmPattern, Expr}; +use crate::{ArmPattern, Expr}; pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { let mut queue = vec![]; @@ -311,9 +311,8 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), Vec> { let unified_inferred_type = inferred_type.unify(); - match unified_inferred_type { - Ok(unified_type) => *inferred_type = unified_type, - Err(_) => {} + if let Ok(unified_type) = unified_inferred_type { + *inferred_type = unified_type } } From 2e10cc84c15f61d0ae1775942db92bb01721e045 Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 17:53:06 +0530 Subject: [PATCH 4/8] Reformat code --- golem-worker-service-base/src/gateway_rib_compiler/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs index fe5e46d3d7..773e90554e 100644 --- a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs +++ b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs @@ -32,8 +32,8 @@ impl WorkerServiceRibCompiler for DefaultWorkerServiceRibCompiler { Some(TypeDefault { variable_id: VariableId::global("request".to_string()), path: Path::from_elems(vec!["path"]), - inferred_type: InferredType::Str - }) + inferred_type: InferredType::Str, + }), ) } } From efb835439a96a563a682e01979a44b1aaf0adeef Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 18:59:14 +0530 Subject: [PATCH 5/8] Reformat code --- golem-rib/src/compiler/byte_code.rs | 30 ++--- golem-rib/src/compiler/desugar.rs | 2 +- golem-rib/src/compiler/mod.rs | 14 +- golem-rib/src/expr.rs | 23 ++-- golem-rib/src/interpreter/rib_interpreter.rs | 116 +++++++++++++++-- ...efault.rs => global_variable_type_spec.rs} | 58 +++++---- .../src/type_inference/inference_fix_point.rs | 2 +- golem-rib/src/type_inference/inferred_expr.rs | 8 +- golem-rib/src/type_inference/mod.rs | 121 +++++++++++------- golem-rib/src/type_inference/type_pull_up.rs | 2 +- 
.../src/gateway_rib_compiler/mod.rs | 6 +- 11 files changed, 261 insertions(+), 121 deletions(-) rename golem-rib/src/type_inference/{global_variable_type_default.rs => global_variable_type_spec.rs} (95%) diff --git a/golem-rib/src/compiler/byte_code.rs b/golem-rib/src/compiler/byte_code.rs index 4465cff982..bd7f5e4638 100644 --- a/golem-rib/src/compiler/byte_code.rs +++ b/golem-rib/src/compiler/byte_code.rs @@ -706,7 +706,7 @@ mod compiler_tests { fn test_instructions_for_literal() { let literal = Expr::Literal("hello".to_string(), InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&literal, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&literal, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -725,7 +725,7 @@ mod compiler_tests { let variable_id = VariableId::local("request", 0); let empty_registry = FunctionTypeRegistry::empty(); let expr = Expr::Identifier(variable_id.clone(), inferred_input_type); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -752,7 +752,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -787,7 +787,7 @@ mod compiler_tests { let expr = Expr::equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = 
RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -826,7 +826,7 @@ mod compiler_tests { let expr = Expr::greater_than(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -865,7 +865,7 @@ mod compiler_tests { let expr = Expr::less_than(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -904,7 +904,7 @@ mod compiler_tests { let expr = Expr::greater_than_or_equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -943,7 +943,7 @@ mod compiler_tests { let expr = Expr::less_than_or_equal_to(number_f32, number_u32); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -983,7 +983,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1027,7 +1027,7 @@ mod 
compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1057,7 +1057,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1097,7 +1097,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1147,7 +1147,7 @@ mod compiler_tests { let expr = Expr::SelectField(Box::new(record), "bar_key".to_string(), InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1194,7 +1194,7 @@ mod compiler_tests { let expr = Expr::SelectIndex(Box::new(sequence), 1, InferredType::Str); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); @@ -1243,7 +1243,7 @@ mod compiler_tests { ); let empty_registry = FunctionTypeRegistry::empty(); - let inferred_expr = InferredExpr::from_expr(&expr, 
&empty_registry, None).unwrap(); + let inferred_expr = InferredExpr::from_expr(&expr, &empty_registry, &vec![]).unwrap(); let instructions = RibByteCode::from_expr(&inferred_expr).unwrap(); diff --git a/golem-rib/src/compiler/desugar.rs b/golem-rib/src/compiler/desugar.rs index a36d5f02b2..4242afb69f 100644 --- a/golem-rib/src/compiler/desugar.rs +++ b/golem-rib/src/compiler/desugar.rs @@ -574,7 +574,7 @@ mod desugar_tests { let function_type_registry = get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let desugared_expr = match internal::last_expr(&expr) { Expr::PatternMatch(predicate, match_arms, _) => { diff --git a/golem-rib/src/compiler/mod.rs b/golem-rib/src/compiler/mod.rs index a392474be3..46fbb41af7 100644 --- a/golem-rib/src/compiler/mod.rs +++ b/golem-rib/src/compiler/mod.rs @@ -21,7 +21,7 @@ pub use type_with_unit::*; pub use worker_functions_in_rib::*; use crate::type_registry::FunctionTypeRegistry; -use crate::{Expr, InferredExpr, RibInputTypeInfo, RibOutputTypeInfo, TypeDefault}; +use crate::{Expr, GlobalVariableTypeSpec, InferredExpr, RibInputTypeInfo, RibOutputTypeInfo}; mod byte_code; mod compiler_output; @@ -34,7 +34,7 @@ pub fn compile( expr: &Expr, export_metadata: &Vec, ) -> Result { - compile_with_restricted_global_variables(expr, export_metadata, None, None) + compile_with_restricted_global_variables(expr, export_metadata, None, &vec![]) } // Rib allows global input variables, however, we can choose to fail compilation @@ -42,15 +42,15 @@ pub fn compile( // There is no restriction imposed to the type of this variable. // Also we can specify types for certain global variables and if needed be specific // on the path. 
Example: All variables under the variable `path` which is under the global variable `request` can be `Str` +// Not all global variables require a type specification; the rest can be left to the compiler. pub fn compile_with_restricted_global_variables( expr: &Expr, export_metadata: &Vec, allowed_global_variables: Option>, - global_variable_type_default: Option, + global_variable_type_spec: &Vec, ) -> Result { let type_registry = FunctionTypeRegistry::from_export_metadata(export_metadata); - let inferred_expr = - InferredExpr::from_expr(expr, &type_registry, global_variable_type_default.as_ref())?; + let inferred_expr = InferredExpr::from_expr(expr, &type_registry, global_variable_type_spec)?; let function_calls_identified = WorkerFunctionsInRib::from_inferred_expr(&inferred_expr, &type_registry)?; @@ -59,7 +59,9 @@ pub fn compile_with_restricted_global_variables( let global_keys: HashSet<_> = global_input_type_info.types.keys().cloned().collect(); - if let Some(info) = &global_variable_type_default { + // We make sure that the global variable spec given by the user in fact corresponds to the real + // global variables identified by the compiler + for info in global_variable_type_spec { if !info.variable_id.is_global() || !global_keys.contains(&info.variable_id.to_string()) { return Err("Only global variables can have default types".to_string()); } diff --git a/golem-rib/src/expr.rs b/golem-rib/src/expr.rs index 2d0d414ae5..acf3fd7f13 100644 --- a/golem-rib/src/expr.rs +++ b/golem-rib/src/expr.rs @@ -17,8 +17,8 @@ use crate::parser::block::block; use crate::parser::type_name::TypeName; use crate::type_registry::FunctionTypeRegistry; use crate::{ - from_string, text, type_checker, type_inference, DynamicParsedFunctionName, InferredType, - ParsedFunctionName, TypeDefault, VariableId, + from_string, text, type_checker, type_inference, DynamicParsedFunctionName, + GlobalVariableTypeSpec, InferredType, ParsedFunctionName, VariableId, }; use bigdecimal::{BigDecimal, FromPrimitive,
ToPrimitive}; use combine::parser::char::spaces; @@ -392,8 +392,11 @@ impl Expr { ) } - pub fn override_types(&self, type_default: &TypeDefault) -> Result { - let result_expr = type_inference::tag_default_global_variable_type(self, type_default)?; + pub fn bind_global_variables_type( + &self, + type_spec: &Vec, + ) -> Result { + let result_expr = type_inference::bind_global_variables_type(self, type_spec)?; Ok(result_expr) } @@ -546,9 +549,9 @@ impl Expr { pub fn infer_types( &mut self, function_type_registry: &FunctionTypeRegistry, - type_default: Option<&TypeDefault>, + type_spec: &Vec, ) -> Result<(), Vec> { - self.infer_types_initial_phase(function_type_registry, type_default)?; + self.infer_types_initial_phase(function_type_registry, type_spec)?; self.infer_call_arguments_type(function_type_registry) .map_err(|x| vec![x])?; type_inference::type_inference_fix_point(Self::inference_scan, self) @@ -562,11 +565,11 @@ impl Expr { pub fn infer_types_initial_phase( &mut self, function_type_registry: &FunctionTypeRegistry, - type_default: Option<&TypeDefault>, + type_spec: &Vec, ) -> Result<(), Vec> { - if let Some(type_default) = type_default { - *self = self.override_types(type_default).map_err(|x| vec![x])?; - } + *self = self + .bind_global_variables_type(type_spec) + .map_err(|x| vec![x])?; self.bind_types(); self.bind_variables_of_list_comprehension(); self.bind_variables_of_list_reduce(); diff --git a/golem-rib/src/interpreter/rib_interpreter.rs b/golem-rib/src/interpreter/rib_interpreter.rs index cba539b042..91c6c96d49 100644 --- a/golem-rib/src/interpreter/rib_interpreter.rs +++ b/golem-rib/src/interpreter/rib_interpreter.rs @@ -1379,6 +1379,97 @@ mod interpreter_tests { assert_eq!(result.get_val().unwrap(), 2i32.into_value_and_type()); } + mod global_variable_tests { + use crate::interpreter::rib_interpreter::interpreter_tests::internal; + use crate::interpreter::rib_interpreter::interpreter_tests::internal::get_value_and_type; + use crate::{ + compiler, 
Expr, GlobalVariableTypeSpec, InferredType, Path, RibInput, VariableId, + }; + use golem_wasm_ast::analysis::analysed_type::{record, s8, str}; + use golem_wasm_ast::analysis::NameTypePair; + use golem_wasm_rpc::{Value, ValueAndType}; + use std::collections::HashMap; + use test_r::test; + + #[test] + async fn test_global_variable_custom() { + let mut rib_input = HashMap::new(); + + let value_and_type = get_value_and_type( + &record(vec![ + NameTypePair { + name: "path".to_string(), + typ: record(vec![NameTypePair { + name: "user-id".to_string(), + typ: str(), + }]), + }, + NameTypePair { + name: "headers".to_string(), + typ: record(vec![ + NameTypePair { + name: "name".to_string(), + typ: str(), + }, + NameTypePair { + name: "age".to_string(), + typ: str(), + }, + ]), + }, + ]), + r#"{path : { user-id: "1" }, headers: { name: "foo", age: "20" }}"#, + ); + + rib_input.insert("request".to_string(), value_and_type); + + let mut interpreter = internal::static_test_interpreter( + &ValueAndType::new(Value::S8(1), s8()), + Some(RibInput::new(rib_input)), + ); + + let rib_expr = r#" + let res1 = request.path.user-id; + let res2 = request.headers.name; + let res3 = request.headers.age; + "${res1}-${res2}-${res3}" + "#; + + let expr = Expr::from_text(rib_expr).unwrap(); + + let type_spec = vec![ + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["path"]), + inferred_type: InferredType::Str, + }, + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["headers"]), + inferred_type: InferredType::Str, + }, + ]; + + let compiled = compiler::compile_with_restricted_global_variables( + &expr, + &vec![], + None, + &type_spec, + ) + .unwrap(); + + let result = interpreter + .run(compiled.byte_code) + .await + .unwrap() + .get_val() + .unwrap() + .value; + + assert_eq!(result, Value::String("1-foo-20".to_string())) + } + } + mod list_reduce_interpreter_tests { use 
crate::interpreter::rib_interpreter::Interpreter; use crate::{compiler, Expr}; @@ -1596,7 +1687,7 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); let result = interpreter.run(compiled.byte_code).await.unwrap(); @@ -1617,7 +1708,7 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); let result = interpreter.run(compiled.byte_code).await.unwrap(); @@ -1639,7 +1730,7 @@ mod interpreter_tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let compiled = compiler::compile(&expr, &vec![]).unwrap(); @@ -1736,7 +1827,7 @@ mod interpreter_tests { let result_value = internal::get_value_and_type(&output_analysed_type, r#"ok(1)"#); - let mut interpreter = internal::static_test_interpreter(&result_value); + let mut interpreter = internal::static_test_interpreter(&result_value, None); let analysed_exports = internal::get_component_metadata( "my-worker-function", @@ -1774,7 +1865,7 @@ mod interpreter_tests { let result_value = internal::get_value_and_type(&output_analysed_type, r#"err("failed")"#); - let mut interpreter = internal::static_test_interpreter(&result_value); + let mut interpreter = internal::static_test_interpreter(&result_value, None); let analysed_exports = internal::get_component_metadata( "my-worker-function", @@ -1859,7 +1950,7 @@ mod interpreter_tests { internal::get_shopping_cart_metadata_with_cart_resource_with_parameters(); let compiled = compiler::compile(&expr, &component_metadata).unwrap(); - let mut 
rib_executor = internal::static_test_interpreter(&result_value); + let mut rib_executor = internal::static_test_interpreter(&result_value, None); let result = rib_executor.run(compiled.byte_code).await.unwrap(); assert_eq!(result.get_val().unwrap(), result_value); @@ -1893,7 +1984,7 @@ mod interpreter_tests { internal::get_shopping_cart_metadata_with_cart_resource_with_parameters(); let compiled = compiler::compile(&expr, &component_metadata).unwrap(); - let mut rib_executor = internal::static_test_interpreter(&result_value); + let mut rib_executor = internal::static_test_interpreter(&result_value, None); let result = rib_executor.run(compiled.byte_code).await.unwrap(); assert_eq!(result.get_val().unwrap(), "foo".into_value_and_type()); @@ -2004,7 +2095,7 @@ mod interpreter_tests { let component_metadata = internal::get_shopping_cart_metadata_with_cart_raw_resource(); let compiled = compiler::compile(&expr, &component_metadata).unwrap(); - let mut rib_executor = internal::static_test_interpreter(&result_value); + let mut rib_executor = internal::static_test_interpreter(&result_value, None); let result = rib_executor.run(compiled.byte_code).await.unwrap(); assert_eq!(result.get_val().unwrap(), "foo".into_value_and_type()); @@ -2058,7 +2149,7 @@ mod interpreter_tests { let component_metadata = internal::get_shopping_cart_metadata_with_cart_raw_resource(); let compiled = compiler::compile(&expr, &component_metadata).unwrap(); - let mut rib_executor = internal::static_test_interpreter(&result_value); + let mut rib_executor = internal::static_test_interpreter(&result_value, None); let result = rib_executor.run(compiled.byte_code).await.unwrap(); assert_eq!(result.get_val().unwrap(), result_value); @@ -2318,9 +2409,12 @@ mod interpreter_tests { golem_wasm_rpc::parse_value_and_type(analysed_type, wasm_wave_str).unwrap() } - pub(crate) fn static_test_interpreter(result_value: &ValueAndType) -> Interpreter { + pub(crate) fn static_test_interpreter( + result_value: 
&ValueAndType, + input: Option, + ) -> Interpreter { Interpreter { - input: RibInput::default(), + input: input.unwrap_or_default(), invoke: static_worker_invoke(result_value), } } diff --git a/golem-rib/src/type_inference/global_variable_type_default.rs b/golem-rib/src/type_inference/global_variable_type_spec.rs similarity index 95% rename from golem-rib/src/type_inference/global_variable_type_default.rs rename to golem-rib/src/type_inference/global_variable_type_spec.rs index e4ab1d05b5..be031b5806 100644 --- a/golem-rib/src/type_inference/global_variable_type_default.rs +++ b/golem-rib/src/type_inference/global_variable_type_spec.rs @@ -4,13 +4,13 @@ use std::collections::VecDeque; // The goal is to be able to specify the types associated with an identifier. // i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` -// This can be represented using `TypeDefault { a, vec![], Str }`, `TypeDefault {a, b, Str}` and -// `TypeDefault {a, vec[b, c], Str}` respectively +// This can be represented using `GlobalVariableTypeSpec { a, vec![], Str }`, `GlobalVariableTypeSpec {a, b, Str}` and +// `GlobalVariableTypeSpec {a, vec[b, c], Str}` respectively // If you specify completely opposite types to be default, you will get a compilation error. // If you tried to specify a variable is always string, but compiler identifies it's usage as `U64`, // then it chooses `U64` and discards the default.
If the compiler finds its usages as `Str` #[derive(Clone, Debug)] -pub struct TypeDefault { +pub struct GlobalVariableTypeSpec { pub variable_id: VariableId, pub path: Path, pub inferred_type: InferredType, @@ -21,8 +21,8 @@ pub struct TypeDefault { // // The goal is to be able to specify the types associated with an identifier // i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` -// This can be represented using `TypeDefault { a, vec![], Str }`, `TypeDefault {a, b, Str}` and -// `TypeDefault {a, vec[b, c], Str}` respectively +// This can be represented using `GlobalVariableTypeSpec { a, vec![], Str }`, `GlobalVariableTypeSpec {a, b, Str}` and +// `GlobalVariableTypeSpec {a, vec[b, c], Str}` respectively // // We initially create queue of immutable Expr (to be able to push mutable version has to do into reference count logic in Rust) // and then push it to an intermediate stack and recreate the Expr. This is similar to `type_pull_up` phase. @@ -40,7 +40,7 @@ pub struct TypeDefault { // Example queue: // [select_field(select_field(a, b), c), select_field(a, b), identifier(a)] // -// Example Walkthrough: Given `TypeDefault { a, vec[b, c], Str]` +// Example Walkthrough: Given `GlobalVariableTypeSpec { a, vec[b, c], Str }` // // 1. Pop the back element in the queue to get `identifier(a)`. // - Check the `temp_stack` by popping from the front. @@ -60,11 +60,22 @@ pub struct TypeDefault { // // The same algorithm above is tweaked even if users specified partial paths.
Example: // Everything under `a.b` (regardless of the existence of c and d) at their leafs follow the default type -pub fn tag_default_global_variable_type( + +pub fn bind_global_variables_type( expr: &Expr, - type_default: &TypeDefault, + type_pecs: &Vec, ) -> Result { - let mut path = type_default.path.clone(); + let mut result_expr = expr.clone(); + + for spec in type_pecs { + result_expr = bind_with_type_spec(&result_expr, spec)?; + } + + Ok(result_expr) +} + +fn bind_with_type_spec(expr: &Expr, type_spec: &GlobalVariableTypeSpec) -> Result { + let mut path = type_spec.path.clone(); let mut expr_queue = VecDeque::new(); @@ -75,7 +86,7 @@ pub fn tag_default_global_variable_type( while let Some(expr) = expr_queue.pop_back() { match expr { expr @ Expr::Identifier(variable_id, _) => { - if variable_id == &type_default.variable_id { + if variable_id == &type_spec.variable_id { if path.is_empty() { let continue_traverse = matches!(expr_queue.back(), Some(Expr::SelectField(inner, _, _)) if inner.as_ref() == expr); @@ -85,7 +96,7 @@ pub fn tag_default_global_variable_type( temp_stack.push_front(( Expr::Identifier( variable_id.clone(), - type_default.inferred_type.clone(), + type_spec.inferred_type.clone(), ), false, )); @@ -108,7 +119,7 @@ pub fn tag_default_global_variable_type( current_inferred_type, &mut temp_stack, &mut path, - &type_default.inferred_type, + &type_spec.inferred_type, )?; } @@ -472,7 +483,8 @@ mod internal { if part_of_path { match path.current() { Some(PathElem::Field(name)) if name == field => path.progress(), - Some(_) => return Err("We disallow type overrides at indices".to_string()), + Some(PathElem::Field(_)) => {} + Some(PathElem::Index(_)) => {} None => {} } @@ -917,13 +929,13 @@ mod tests { ) .unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("foo".to_string()), path: Path::default(), inferred_type: InferredType::Str, }; - let result = 
expr.override_types(&type_default).unwrap(); + let result = expr.bind_global_variables_type(&vec![type_spec]).unwrap(); let expected = Expr::Identifier(VariableId::global("foo".to_string()), InferredType::Str); @@ -940,13 +952,13 @@ mod tests { ) .unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("foo".to_string()), path: Path::from_elems(vec!["bar"]), inferred_type: InferredType::Str, }; - let result = expr.override_types(&type_default).unwrap(); + let result = expr.bind_global_variables_type(&vec![type_spec]).unwrap(); let expected = Expr::SelectField( Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), @@ -966,13 +978,13 @@ mod tests { ) .unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("foo".to_string()), path: Path::from_elems(vec!["bar", "baz"]), inferred_type: InferredType::Str, }; - let result = expr.override_types(&type_default).unwrap(); + let result = expr.bind_global_variables_type(&vec![type_spec]).unwrap(); let expected = Expr::SelectField( Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), @@ -992,13 +1004,13 @@ mod tests { ) .unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("foo".to_string()), path: Path::default(), inferred_type: InferredType::Str, }; - let result = expr.override_types(&type_default).unwrap(); + let result = expr.bind_global_variables_type(&vec![type_spec]).unwrap(); let expected = Expr::SelectField( Box::new(Expr::select_field(Expr::identifier("foo"), "bar")), @@ -1020,13 +1032,13 @@ mod tests { ) .unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("foo".to_string()), path: Path::from_elems(vec!["bar"]), inferred_type: InferredType::Str, }; - expr.infer_types(&FunctionTypeRegistry::empty(), Some(&type_default)) + 
expr.infer_types(&FunctionTypeRegistry::empty(), &vec![type_spec]) .unwrap(); let expected = Expr::ExprBlock( diff --git a/golem-rib/src/type_inference/inference_fix_point.rs b/golem-rib/src/type_inference/inference_fix_point.rs index 1e41fdad5d..fb3b8f7e7f 100644 --- a/golem-rib/src/type_inference/inference_fix_point.rs +++ b/golem-rib/src/type_inference/inference_fix_point.rs @@ -389,7 +389,7 @@ mod tests { "#; let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( vec![ diff --git a/golem-rib/src/type_inference/inferred_expr.rs b/golem-rib/src/type_inference/inferred_expr.rs index 15b92a0aa4..5bf5476f74 100644 --- a/golem-rib/src/type_inference/inferred_expr.rs +++ b/golem-rib/src/type_inference/inferred_expr.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::call_type::CallType; -use crate::{DynamicParsedFunctionName, Expr, FunctionTypeRegistry, RegistryKey, TypeDefault}; +use crate::{ + DynamicParsedFunctionName, Expr, FunctionTypeRegistry, GlobalVariableTypeSpec, RegistryKey, +}; use std::collections::{HashSet, VecDeque}; #[derive(Debug, Clone)] @@ -27,12 +29,12 @@ impl InferredExpr { pub fn from_expr( expr: &Expr, function_type_registry: &FunctionTypeRegistry, - type_default: Option<&TypeDefault>, + type_spec: &Vec, ) -> Result { let mut mutable_expr = expr.clone(); mutable_expr - .infer_types(function_type_registry, type_default) + .infer_types(function_type_registry, type_spec) .map_err(|err| err.join("\n"))?; Ok(InferredExpr(mutable_expr)) } diff --git a/golem-rib/src/type_inference/mod.rs b/golem-rib/src/type_inference/mod.rs index df6fbc4272..7234e291c7 100644 --- a/golem-rib/src/type_inference/mod.rs +++ b/golem-rib/src/type_inference/mod.rs @@ -16,7 +16,7 @@ pub use call_arguments_inference::*; pub use enum_resolution::*; pub use expr_visitor::*; pub use global_input_inference::*; -pub 
use global_variable_type_default::*; +pub use global_variable_type_spec::*; pub use identifier_inference::*; pub use inference_fix_point::*; pub use inferred_expr::*; @@ -34,39 +34,38 @@ pub use variable_binding_pattern_match::*; pub use variant_resolution::*; mod call_arguments_inference; +mod enum_resolution; mod expr_visitor; +mod global_input_inference; +mod global_variable_type_spec; mod identifier_inference; +mod inference_fix_point; +mod inferred_expr; +pub(crate) mod kind; mod rib_input_type; +mod rib_output_type; +mod type_binding; mod type_pull_up; mod type_push_down; mod type_reset; mod type_unification; mod variable_binding_let_assignment; -mod variable_binding_pattern_match; -mod variant_resolution; - -mod enum_resolution; -mod global_input_inference; -mod global_variable_type_default; -mod inference_fix_point; -mod inferred_expr; -pub(crate) mod kind; -mod rib_output_type; -mod type_binding; mod variable_binding_list_comprehension; mod variable_binding_list_reduce; +mod variable_binding_pattern_match; +mod variant_resolution; #[cfg(test)] mod type_inference_tests { mod global_variable { use crate::type_checker::Path; - use crate::type_inference::global_variable_type_default::TypeDefault; + use crate::type_inference::global_variable_type_spec::GlobalVariableTypeSpec; use crate::{Expr, FunctionTypeRegistry, InferredType, VariableId}; use test_r::test; #[test] - fn test_global_variable_inference() { + fn test_global_variable_inference_1() { let rib_expr = r#" let res = request.path.user-id; let hello: u64 = request.path.number; @@ -74,14 +73,42 @@ mod type_inference_tests { "#; let mut expr = Expr::from_text(rib_expr).unwrap(); - let type_default = TypeDefault { + let type_spec = GlobalVariableTypeSpec { variable_id: VariableId::global("request".to_string()), path: Path::from_elems(vec!["path"]), inferred_type: InferredType::Str, }; assert!(expr - .infer_types(&FunctionTypeRegistry::empty(), Some(&type_default)) + 
.infer_types(&FunctionTypeRegistry::empty(), &vec![type_spec]) + .is_ok()); + } + + #[test] + fn test_global_variable_inference_2() { + let rib_expr = r#" + let res1 = request.path.user-id; + let res2 = request.headers.name; + let res3 = request.headers.age; + "${res1}-${res2}-${res3}" + "#; + + let mut expr = Expr::from_text(rib_expr).unwrap(); + let type_spec = vec![ + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["path"]), + inferred_type: InferredType::Str, + }, + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["headers"]), + inferred_type: InferredType::Str, + }, + ]; + + assert!(expr + .infer_types(&FunctionTypeRegistry::empty(), &type_spec) .is_ok()); } } @@ -106,7 +133,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let let_binding = Expr::Let( VariableId::local("x", 0), @@ -154,7 +181,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let let_binding1 = Expr::Let( VariableId::local("x", 0), @@ -236,7 +263,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -270,7 +297,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = 
Expr::ExprBlock( vec![ @@ -310,7 +337,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -463,7 +490,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = internal::expected_expr_for_enum_test(); @@ -540,7 +567,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr).unwrap(); - let result = expr.infer_types(&function_type_registry, None); + let result = expr.infer_types(&function_type_registry, &vec![]); assert!(result.is_ok()); } } @@ -561,7 +588,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -607,7 +634,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -645,7 +672,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -733,7 +760,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = 
Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -772,7 +799,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -832,7 +859,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -896,7 +923,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -965,7 +992,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1029,7 +1056,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1085,7 +1112,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + 
expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( @@ -1176,7 +1203,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let let_binding1 = Expr::Let( VariableId::local("x", 0), @@ -1281,7 +1308,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - let result = expr.infer_types(&function_type_registry, None); + let result = expr.infer_types(&function_type_registry, &vec![]); assert!(result.is_ok()); } @@ -1298,7 +1325,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( @@ -1396,7 +1423,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( @@ -1473,7 +1500,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( @@ -1555,7 +1582,7 @@ mod type_inference_tests { let mut expr = Expr::from_text(expr_str).unwrap(); - expr.infer_types(&FunctionTypeRegistry::empty(), None) + expr.infer_types(&FunctionTypeRegistry::empty(), &vec![]) .unwrap(); let expected = Expr::ExprBlock( @@ -1708,7 +1735,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + 
expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1749,7 +1776,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1815,7 +1842,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1870,7 +1897,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -1981,7 +2008,7 @@ mod type_inference_tests { let function_type_registry = FunctionTypeRegistry::from_export_metadata(&component_metadata); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = internal::expected_expr_for_select_index(); @@ -2007,7 +2034,7 @@ mod type_inference_tests { let expr = Expr::from_text(rib_expr).unwrap(); let inferred_expr = - InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), None).unwrap(); + InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -2086,7 +2113,7 @@ mod type_inference_tests { let expr = Expr::from_text(rib_expr).unwrap(); let inferred_expr = - InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), None).unwrap(); + InferredExpr::from_expr(&expr, &FunctionTypeRegistry::empty(), 
&vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ @@ -2138,7 +2165,7 @@ mod type_inference_tests { let function_type_registry = internal::get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_types(&function_type_registry, None).unwrap(); + expr.infer_types(&function_type_registry, &vec![]).unwrap(); let expected = Expr::ExprBlock( vec![ diff --git a/golem-rib/src/type_inference/type_pull_up.rs b/golem-rib/src/type_inference/type_pull_up.rs index 94faad81e4..79feca32e4 100644 --- a/golem-rib/src/type_inference/type_pull_up.rs +++ b/golem-rib/src/type_inference/type_pull_up.rs @@ -1273,7 +1273,7 @@ mod type_pull_up_tests { let mut expr = Expr::from_text(rib).unwrap(); let function_registry = FunctionTypeRegistry::empty(); - expr.infer_types_initial_phase(&function_registry, None) + expr.infer_types_initial_phase(&function_registry, &vec![]) .unwrap(); expr.infer_all_identifiers().unwrap(); let new_expr = expr.pull_types_up().unwrap(); diff --git a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs index 773e90554e..ff6e387718 100644 --- a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs +++ b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use golem_wasm_ast::analysis::AnalysedExport; -use rib::{CompilerOutput, Expr, InferredType, Path, TypeDefault, VariableId}; +use rib::{CompilerOutput, Expr, InferredType, Path, GlobalVariableTypeSpec, VariableId}; // A wrapper service over original Rib Compiler concerning // the details of the worker bridge. 
@@ -29,11 +29,11 @@ impl WorkerServiceRibCompiler for DefaultWorkerServiceRibCompiler { rib, &export_metadata.to_vec(), Some(vec!["request".to_string()]), - Some(TypeDefault { + &vec![GlobalVariableTypeSpec { variable_id: VariableId::global("request".to_string()), path: Path::from_elems(vec!["path"]), inferred_type: InferredType::Str, - }), + }], ) } } From 951672e7af377820c9685c1900f39ab75ca2a33f Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 19:40:44 +0530 Subject: [PATCH 6/8] Fix tests --- golem-rib/src/compiler/mod.rs | 19 +++++++++---------- .../src/gateway_rib_compiler/mod.rs | 19 +++++++++++++------ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/golem-rib/src/compiler/mod.rs b/golem-rib/src/compiler/mod.rs index 46fbb41af7..c740b37ce3 100644 --- a/golem-rib/src/compiler/mod.rs +++ b/golem-rib/src/compiler/mod.rs @@ -49,6 +49,15 @@ pub fn compile_with_restricted_global_variables( allowed_global_variables: Option>, global_variable_type_spec: &Vec, ) -> Result { + for info in global_variable_type_spec { + if !info.variable_id.is_global() { + return Err(format!( + "Only global variables can have default types, but found {}", + info.variable_id + )); + } + } + let type_registry = FunctionTypeRegistry::from_export_metadata(export_metadata); let inferred_expr = InferredExpr::from_expr(expr, &type_registry, global_variable_type_spec)?; let function_calls_identified = @@ -57,16 +66,6 @@ pub fn compile_with_restricted_global_variables( let global_input_type_info = RibInputTypeInfo::from_expr(&inferred_expr).map_err(|e| format!("Error: {}", e))?; - let global_keys: HashSet<_> = global_input_type_info.types.keys().cloned().collect(); - - // We make the global variable spec given by the user is infact corresponds to the real - // global variables identified by the compiler - for info in global_variable_type_spec { - if !info.variable_id.is_global() || !global_keys.contains(&info.variable_id.to_string()) { - return Err("Only global 
variables can have default types".to_string()); - } - } - let output_type_info = RibOutputTypeInfo::from_expr(&inferred_expr)?; if let Some(allowed_global_variables) = &allowed_global_variables { diff --git a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs index ff6e387718..943cffddfb 100644 --- a/golem-worker-service-base/src/gateway_rib_compiler/mod.rs +++ b/golem-worker-service-base/src/gateway_rib_compiler/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use golem_wasm_ast::analysis::AnalysedExport; -use rib::{CompilerOutput, Expr, InferredType, Path, GlobalVariableTypeSpec, VariableId}; +use rib::{CompilerOutput, Expr, GlobalVariableTypeSpec, InferredType, Path, VariableId}; // A wrapper service over original Rib Compiler concerning // the details of the worker bridge. @@ -29,11 +29,18 @@ impl WorkerServiceRibCompiler for DefaultWorkerServiceRibCompiler { rib, &export_metadata.to_vec(), Some(vec!["request".to_string()]), - &vec![GlobalVariableTypeSpec { - variable_id: VariableId::global("request".to_string()), - path: Path::from_elems(vec!["path"]), - inferred_type: InferredType::Str, - }], + &vec![ + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["path"]), + inferred_type: InferredType::Str, + }, + GlobalVariableTypeSpec { + variable_id: VariableId::global("request".to_string()), + path: Path::from_elems(vec!["headers"]), + inferred_type: InferredType::Str, + }, + ], ) } } From a8a886302b544aed94c8c08e12bea5d7cbf808a2 Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 19:42:08 +0530 Subject: [PATCH 7/8] Fix tests --- golem-rib/src/compiler/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/golem-rib/src/compiler/mod.rs b/golem-rib/src/compiler/mod.rs index c740b37ce3..b8be2282da 100644 --- a/golem-rib/src/compiler/mod.rs +++ b/golem-rib/src/compiler/mod.rs @@ -16,7 +16,6 @@ pub use byte_code::*; 
pub use compiler_output::*; use golem_wasm_ast::analysis::AnalysedExport; pub use ir::*; -use std::collections::HashSet; pub use type_with_unit::*; pub use worker_functions_in_rib::*; From 57f1972ca3c55577da66d415e2b904eeb258cf5f Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Sat, 8 Feb 2025 20:07:38 +0530 Subject: [PATCH 8/8] Fix comments --- .../type_inference/global_variable_type_spec.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/golem-rib/src/type_inference/global_variable_type_spec.rs b/golem-rib/src/type_inference/global_variable_type_spec.rs index be031b5806..7cadf96a5d 100644 --- a/golem-rib/src/type_inference/global_variable_type_spec.rs +++ b/golem-rib/src/type_inference/global_variable_type_spec.rs @@ -4,11 +4,9 @@ use std::collections::VecDeque; // The goal is to be able to specify the types associated with an identifier. // i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` -// This can be represented using `TypeSpecification { a, vec![], Str }`, `TypeSpecification {a, b, Str}` and -// `TypeSpecification {a, vec[b, c], Str}` respectively +// This can be represented using `GlobalVariableTypeSpec { a, vec![], Str }`, `GlobalVariableTypeSpec {a, b, Str}` and +// `GlobalVariableTypeSpec {a, vec[b, c], Str}` respectively // If you specify completely opposite types to be default, you will get a compilation error. -// If you tried to specify a variable is always string, but compiler identifies it's usage as `U64`, -// then it chooses `U64` and discards the default. 
If the compiler finds its usages as `Str` #[derive(Clone, Debug)] pub struct GlobalVariableTypeSpec { pub variable_id: VariableId, @@ -18,12 +16,6 @@ pub struct GlobalVariableTypeSpec { // // Algorithm: -// -// The goal is to be able to specify the types associated with an identifier -// i.e, `a.*` is always `Str`, or `a.b.*` is always `Str`, or `a.b.c` is always `Str` -// This can be represented using `TypeSpecification { a, vec![], Str }`, `TypeSpecification {a, b, Str}` and -// `TypeSpecification {a, vec[b, c], Str}` respectively -// // We initially create queue of immutable Expr (to be able to push mutable version has to do into reference count logic in Rust) // and then push it to an intermediate stack and recreate the Expr. This is similar to `type_pull_up` phase. // This is verbose but will make the algorithm quite easy to handle. @@ -40,7 +32,7 @@ pub struct GlobalVariableTypeSpec { // Example queue: // [select_field(select_field(a, b), c), select_field(a, b), identifier(a)] // -// Example Walkthrough: Given `TypeSpecification { a, vec[b, c], Str]` +// Example Walkthrough: Given `GlobalVariableTypeSpec { a, vec[b, c], Str]` // // 1. Pop the back element in the queue to get `identifier(a)`. // - Check the `temp_stack` by popping from the front.