diff --git a/cedar-drt/fuzz/fuzz_targets/common-type-resolution.rs b/cedar-drt/fuzz/fuzz_targets/common-type-resolution.rs index 48626ae38..cec1f16f3 100644 --- a/cedar-drt/fuzz/fuzz_targets/common-type-resolution.rs +++ b/cedar-drt/fuzz/fuzz_targets/common-type-resolution.rs @@ -15,7 +15,7 @@ */ #![no_main] -use cedar_drt_inner::{schemas::validator_schema_attr_types_equivalent, *}; +use cedar_drt_inner::*; use cedar_policy_core::ast; use cedar_policy_generators::{ schema::{downgrade_frag_to_raw, Schema}, @@ -84,16 +84,13 @@ fuzz_target!(|i: Input| { downgrade_frag_to_raw(i.schema_with_common_types).try_into(); match (validator_schema1, validator_schema2) { (Ok(s1), Ok(s2)) => { - assert!( - validator_schema_attr_types_equivalent(&s1, &s2), - "reduced to different validator schemas: {:?}\n{:?}\n", - s1, - s2 - ); + if let Err(e) = schemas::Equiv::equiv(&s1, &s2) { + panic!("reduced to different validator schemas: {s1:?}\n{s2:?}\n\n{e}\n"); + } } (Err(_), Err(_)) => {} (Ok(s), Err(_)) | (Err(_), Ok(s)) => { - panic!("reduction results differ, got validator schema: {:?}\n", s); + panic!("reduction results differ, got validator schema: {s:?}\n"); } } }); diff --git a/cedar-drt/fuzz/src/schemas.rs b/cedar-drt/fuzz/src/schemas.rs index 1b3c3dea3..4b0e30c90 100644 --- a/cedar-drt/fuzz/src/schemas.rs +++ b/cedar-drt/fuzz/src/schemas.rs @@ -14,16 +14,17 @@ * limitations under the License. */ -use cedar_policy_core::ast::{Id, InternalName, UnreservedId}; +use cedar_policy_core::ast::{Id, InternalName}; use cedar_policy_validator::json_schema::{ - ApplySpec, EntityType, RecordType, Type, TypeOfAttribute, TypeVariant, + self, ApplySpec, EntityAttributeType, EntityAttributeTypeInternal, EntityAttributes, + EntityAttributesInternal, EntityType, RecordAttributeType, RecordType, Type, TypeVariant, }; use cedar_policy_validator::RawName; use itertools::Itertools; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use cedar_policy_validator::json_schema; use std::fmt::{Debug, Display}; +use std::hash::Hash; /// Check if two schema fragments are equivalent, modulo empty apply specs. /// We do this because there are schemas that are representable in the JSON that are not @@ -62,7 +63,7 @@ pub fn equivalence_check(schema: &mut json_schema::Fragment) { } } -fn namespace_equivalence( - lhs: json_schema::NamespaceDefinition, - rhs: json_schema::NamespaceDefinition, -) -> Result<(), String> { - entity_types_equivalence(lhs.entity_types, rhs.entity_types)?; - if lhs.common_types != rhs.common_types { - Err("Common types differ".to_string()) - } else if lhs.actions.len() != rhs.actions.len() { - Err("Different number of actions".to_string()) - } else { - lhs.actions - .into_iter() - .map(|(name, lhs_action)| { - let rhs_action = rhs - .actions - .get(&name) - .ok_or_else(|| format!("Action `{name}` not present on rhs"))?; - action_type_equivalence(name.as_ref(), lhs_action, rhs_action.clone()) - }) - .fold(Ok(()), Result::and) +pub trait Equiv { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String>; +} + +impl<'a, T: Equiv> Equiv for &'a T { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + Equiv::equiv(*lhs, *rhs) } } -type EntityData = HashMap>; +impl Equiv + for json_schema::NamespaceDefinition +{ + fn equiv( + lhs: &json_schema::NamespaceDefinition, + rhs: &json_schema::NamespaceDefinition, + ) -> Result<(), String> { + Equiv::equiv(&lhs.entity_types, &rhs.entity_types)?; + if &lhs.common_types != &rhs.common_types { + Err("Common types differ".to_string()) + } else if lhs.actions.len() != rhs.actions.len() { + Err("Different number of actions".to_string()) + } else { + lhs.actions + .iter() + .map(|(name, lhs_action)| { + let rhs_action = rhs + .actions + .get(name) + .ok_or_else(|| format!("Action `{name}` not present on rhs"))?; + action_type_equivalence(name.as_ref(), lhs_action, rhs_action) + }) + .fold(Ok(()), Result::and) + } + } +} -fn entity_types_equivalence( - lhs: EntityData, - rhs: EntityData, -) -> Result<(), String> { - if lhs.len() == rhs.len() { - let errors = lhs - .into_iter() - .filter_map(|lhs| entity_type_equivalence(lhs, &rhs).err()) - .collect::>(); - if errors.is_empty() { +/// `Equiv` for `HashSet` requires that the items in the set are exactly equal, +/// not equivalent by `Equiv`. (It would be hard to line up which item is +/// supposed to correspond to which, given an arbitrary `Equiv` implementation.) +impl Equiv for HashSet { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs != rhs { + let missing_elems = lhs.symmetric_difference(&rhs).join(", "); + Err(format!("missing set elements: {missing_elems}")) + } else { Ok(()) + } + } +} + +/// `Equiv` for `BTreeSet` requires that the items in the set are exactly equal, +/// not equivalent by `Equiv`. (It would be hard to line up which item is +/// supposed to correspond to which, given an arbitrary `Equiv` implementation.) +impl Equiv for BTreeSet { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs != rhs { + let missing_elems = lhs.symmetric_difference(&rhs).join(", "); + Err(format!("missing set elements: {missing_elems}")) } else { - Err(format!( - "Found the following entity type mismatches: {}", - errors.into_iter().join("\n") - )) + Ok(()) } - } else { - let lhs_keys: HashSet<_> = lhs.keys().collect(); - let rhs_keys: HashSet<_> = rhs.keys().collect(); - let missing_keys = lhs_keys.symmetric_difference(&rhs_keys).join(", "); - Err(format!("Missing keys: {missing_keys}")) } } -fn entity_type_equivalence( - (name, lhs_type): (UnreservedId, EntityType), - rhs: &EntityData, -) -> Result<(), String> { - let rhs_type = rhs - .get(&name) - .ok_or_else(|| format!("Type `{name}` was missing from right-hand-side"))?; - - if !vector_equiv(&lhs_type.member_of_types, &rhs_type.member_of_types) { - Err(format!( - "For `{name}`: lhs and rhs membership are not equal. LHS: [{}], RHS: [{}].", - lhs_type - .member_of_types - .into_iter() - .map(|id| id.to_string()) - .join(","), - rhs_type - .member_of_types +impl Equiv for HashMap { + fn equiv(lhs: &HashMap, rhs: &HashMap) -> Result<(), String> { + if lhs.len() == rhs.len() { + let errors = lhs .iter() - .map(|id| id.to_string()) - .join(",") - )) - } else if shape_equiv(&lhs_type.shape.0, &rhs_type.shape.0) { - Ok(()) - } else { - Err(format!("`{name}` has mismatched types")) + .filter_map(|(k, lhs_v)| match rhs.get(k) { + Some(rhs_v) => Equiv::equiv(lhs_v, rhs_v).err(), + None => Some(format!("`{k}` missing from rhs")), + }) + .collect::>(); + if errors.is_empty() { + Ok(()) + } else { + Err(format!( + "Found the following mismatches: {}", + errors.into_iter().join("\n") + )) + } + } else { + let lhs_keys: HashSet<_> = lhs.keys().collect(); + let rhs_keys: HashSet<_> = rhs.keys().collect(); + let missing_keys = lhs_keys.symmetric_difference(&rhs_keys).join(", "); + Err(format!("Missing keys: {missing_keys}")) + } } } -fn shape_equiv(lhs: &Type, rhs: &Type) -> bool { - match (lhs, rhs) { - (Type::Type(lhs), Type::Type(rhs)) => type_varient_equiv(lhs, rhs), - (Type::CommonTypeRef { type_name: lhs }, Type::CommonTypeRef { type_name: rhs }) => { - lhs == rhs +impl Equiv for BTreeMap { + fn equiv(lhs: &BTreeMap, rhs: &BTreeMap) -> Result<(), String> { + if lhs.len() == rhs.len() { + let errors = lhs + .iter() + .filter_map(|(k, lhs_v)| match rhs.get(k) { + Some(rhs_v) => Equiv::equiv(lhs_v, rhs_v).err(), + None => Some(format!("`{k}` missing from rhs")), + }) + .collect::>(); + if errors.is_empty() { + Ok(()) + } else { + Err(format!( + "Found the following mismatches: {}", + errors.into_iter().join("\n") + )) + } + } else { + let lhs_keys: BTreeSet<_> = lhs.keys().collect(); + let rhs_keys: BTreeSet<_> = rhs.keys().collect(); + let missing_keys = lhs_keys.symmetric_difference(&rhs_keys).join(", "); + Err(format!("Missing keys: {missing_keys}")) } - _ => false, } } -/// Type Variant equivalence. See the arms of each match for details -fn type_varient_equiv( - lhs: &TypeVariant, - rhs: &TypeVariant, -) -> bool { - match (lhs, rhs) { - // Records are equivalent iff - // A) They have all the same required keys - // B) Each key has a value that is equivalent - // C) the `additional_attributes` field is equal - ( - TypeVariant::Record(RecordType { - attributes: lhs_attributes, - additional_attributes: lhs_additional_attributes, - }), - TypeVariant::Record(RecordType { - attributes: rhs_attributes, - additional_attributes: rhs_additional_attributes, - }), - ) => { - let lhs_required_keys = lhs_attributes.keys().collect::>(); - let rhs_required_keys = rhs_attributes.keys().collect::>(); - if lhs_required_keys == rhs_required_keys { - lhs_attributes - .into_iter() - .all(|(key, lhs)| attribute_equiv(&lhs, rhs_attributes.get(key).unwrap())) - && lhs_additional_attributes == rhs_additional_attributes - } else { - false +impl Equiv for EntityType { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + Equiv::equiv( + &lhs.member_of_types.iter().collect::>(), + &rhs.member_of_types.iter().collect::>(), + ) + .map_err(|e| format!("memberOfTypes are not equal: {e}"))?; + Equiv::equiv(&lhs.shape, &rhs.shape).map_err(|e| format!("mismatched types: {e}")) + } +} + +impl Equiv for cedar_policy_validator::ValidatorEntityType { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + Equiv::equiv(&lhs.descendants, &rhs.descendants)?; + Equiv::equiv( + &lhs.attributes().collect::>(), + &rhs.attributes().collect::>(), + )?; + Ok(()) + } +} + +impl Equiv for EntityAttributes { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + match (lhs, rhs) { + ( + EntityAttributes::RecordAttributes(rca_l), + EntityAttributes::RecordAttributes(rca_r), + ) => Equiv::equiv(&rca_l.0, &rca_r.0) + .map_err(|e| format!("entity attributes not equivalent: {e}")), + ( + EntityAttributes::EntityAttributes(EntityAttributesInternal { + attrs: attrs_l, .. + }), + EntityAttributes::EntityAttributes(EntityAttributesInternal { + attrs: attrs_r, .. + }), + ) => { + if attrs_l.additional_attributes != attrs_r.additional_attributes { + return Err("attributes differ in additional_attributes flag".into()); + } + Equiv::equiv(&attrs_l.attributes, &attrs_r.attributes) + .map_err(|e| format!("entity attributes not equivalent: {e}")) + } + (_, _) => { + // these could still be equivalent in some cases + unimplemented!() } } - // Sets are equivalent if their elements are equivalent - ( - TypeVariant::Set { - element: lhs_element, - }, - TypeVariant::Set { - element: rhs_element, - }, - ) => shape_equiv(lhs_element.as_ref(), rhs_element.as_ref()), - - // Base types are equivalent to `EntityOrCommon` variants where the type_name is of the - // form `__cedar::` - (TypeVariant::String, TypeVariant::EntityOrCommon { type_name }) - | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::String) => { - is_internal_type(type_name, "String") + } +} + +impl Equiv for EntityAttributeType { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs.required != rhs.required { + return Err("attributes differ in required flag".into()); } - (TypeVariant::Long, TypeVariant::EntityOrCommon { type_name }) - | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Long) => { - is_internal_type(type_name, "Long") + Equiv::equiv(&lhs.ty, &rhs.ty) + } +} + +impl Equiv for cedar_policy_validator::types::AttributeType { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs.is_required != rhs.is_required { + return Err("attributes differ in required flag".into()); } - (TypeVariant::Boolean, TypeVariant::EntityOrCommon { type_name }) - | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Boolean) => { - is_internal_type(type_name, "Bool") + Equiv::equiv(&lhs.attr_type, &rhs.attr_type) + } +} + +impl Equiv for EntityAttributeTypeInternal { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + match (lhs, rhs) { + (EntityAttributeTypeInternal::Type(ty_l), EntityAttributeTypeInternal::Type(ty_r)) => { + Equiv::equiv(ty_l, ty_r) + } + ( + EntityAttributeTypeInternal::EAMap { + value_type: val_ty_l, + }, + EntityAttributeTypeInternal::EAMap { + value_type: val_ty_r, + }, + ) => Equiv::equiv(val_ty_l, val_ty_r), + (_, _) => Err("EAMap is not equivalent to non-EAMap type".into()), } - (TypeVariant::Extension { name }, TypeVariant::EntityOrCommon { type_name }) - | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Extension { name }) => { - is_internal_type(type_name, &name.to_string()) + } +} + +impl Equiv for Type { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + match (lhs, rhs) { + (Type::Type(lhs), Type::Type(rhs)) => Equiv::equiv(lhs, rhs), + (Type::CommonTypeRef { type_name: lhs }, Type::CommonTypeRef { type_name: rhs }) => { + if lhs == rhs { + Ok(()) + } else { + Err(format!( + "common type names do not match: `{lhs}` != `{rhs}`" + )) + } + } + (Type::Type(lhs), Type::CommonTypeRef { type_name: rhs }) => Err(format!( + "lhs is ordinary type `{lhs:?}`, rhs is common type `{rhs}`" + )), + (Type::CommonTypeRef { type_name: lhs }, Type::Type(rhs)) => Err(format!( + "lhs is common type `{lhs}`, rhs is ordinary type `{rhs:?}`" + )), } + } +} - (TypeVariant::Entity { name }, TypeVariant::EntityOrCommon { type_name }) - | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Entity { name }) => { - type_name == name +impl Equiv for cedar_policy_validator::types::Type { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs != rhs { + Err(format!("types are not equal: {lhs} != {rhs}")) + } else { + Ok(()) } + } +} + +impl Equiv for TypeVariant { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + match (lhs, rhs) { + // Records are equivalent iff + // A) They have all the same required keys + // B) Each key has a value that is equivalent + // C) the `additional_attributes` field is equal + ( + TypeVariant::Record(RecordType { + attributes: lhs_attributes, + additional_attributes: lhs_additional_attributes, + }), + TypeVariant::Record(RecordType { + attributes: rhs_attributes, + additional_attributes: rhs_additional_attributes, + }), + ) => { + let lhs_required_keys = lhs_attributes.keys().collect::>(); + let rhs_required_keys = rhs_attributes.keys().collect::>(); + if lhs_required_keys != rhs_required_keys { + return Err( + "records are not equivalent because they have different keysets".into(), + ); + } + if lhs_additional_attributes != rhs_additional_attributes { + return Err("records are not equivalent because they have different additional_attributes flags".into()); + } + lhs_attributes + .into_iter() + .map(|(key, lhs)| Equiv::equiv(lhs, rhs_attributes.get(key).unwrap())) + .collect::>() + } + // Sets are equivalent if their elements are equivalent + ( + TypeVariant::Set { + element: lhs_element, + }, + TypeVariant::Set { + element: rhs_element, + }, + ) => Equiv::equiv(lhs_element.as_ref(), rhs_element.as_ref()), - // Types that are exactly equal are of course equivalent - (lhs, rhs) => lhs == rhs, + // Base types are equivalent to `EntityOrCommon` variants where the type_name is of the + // form `__cedar::` + (TypeVariant::String, TypeVariant::EntityOrCommon { type_name }) + | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::String) => { + if is_internal_type(type_name, "String") { + Ok(()) + } else { + Err(format!( + "entity-or-common `{type_name}` is not equivalent to String" + )) + } + } + (TypeVariant::Long, TypeVariant::EntityOrCommon { type_name }) + | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Long) => { + if is_internal_type(type_name, "Long") { + Ok(()) + } else { + Err(format!( + "entity-or-common `{type_name}` is not equivalent to Long" + )) + } + } + (TypeVariant::Boolean, TypeVariant::EntityOrCommon { type_name }) + | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Boolean) => { + if is_internal_type(type_name, "Bool") { + Ok(()) + } else { + Err(format!( + "entity-or-common `{type_name}` is not equivalent to Boolean" + )) + } + } + (TypeVariant::Extension { name }, TypeVariant::EntityOrCommon { type_name }) + | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Extension { name }) => { + if is_internal_type(type_name, &name.to_string()) { + Ok(()) + } else { + Err(format!( + "entity-or-common `{type_name}` is not equivalent to Extension `{name}` " + )) + } + } + + (TypeVariant::Entity { name }, TypeVariant::EntityOrCommon { type_name }) + | (TypeVariant::EntityOrCommon { type_name }, TypeVariant::Entity { name }) => { + if type_name == name { + Ok(()) + } else { + Err(format!( + "entity `{name}` is not equivalent to entity-or-common `{type_name}`" + )) + } + } + + // Types that are exactly equal are of course equivalent + (lhs, rhs) => { + if lhs == rhs { + Ok(()) + } else { + Err("types are not equivalent".into()) + } + } + } } } -/// Attributes are equivalent iff their shape is equivalent and they have the same required status -fn attribute_equiv( - lhs: &TypeOfAttribute, - rhs: &TypeOfAttribute, -) -> bool { - lhs.required == rhs.required && shape_equiv(&lhs.ty, &rhs.ty) +impl Equiv for RecordAttributeType { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + if lhs.required != rhs.required { + return Err(format!("attribute `{lhs:?}` is not equivalent to attribute `{rhs:?}` because of difference in .required")); + } + Equiv::equiv(&lhs.ty, &rhs.ty) + } } /// Is the given type name the `__cedar` alias for an internal type @@ -264,15 +456,6 @@ fn is_internal_type(type_name: &N, expected: &str) -> bool == vec!["__cedar"] } -/// Vectors are equivalent if they contain the same items, regardless of order -fn vector_equiv(lhs: &[N], rhs: &[N]) -> bool { - let mut lhs = lhs.iter().collect::>(); - let mut rhs = rhs.iter().collect::>(); - lhs.sort(); - rhs.sort(); - lhs == rhs -} - /// Trait for taking either `N` to a concrete type we can do equality over pub trait TypeName { fn qualify(self) -> InternalName; @@ -294,19 +477,19 @@ impl TypeName for InternalName { fn action_type_equivalence( name: &str, - lhs: json_schema::ActionType, - rhs: json_schema::ActionType, + lhs: &json_schema::ActionType, + rhs: &json_schema::ActionType, ) -> Result<(), String> { - if lhs.attributes != rhs.attributes { + if &lhs.attributes != &rhs.attributes { Err(format!("Attributes don't match for `{name}`")) - } else if lhs.member_of != rhs.member_of { + } else if &lhs.member_of != &rhs.member_of { Err(format!("Member of don't match for `{name}`")) } else { - match (lhs.applies_to, rhs.applies_to) { + match (&lhs.applies_to, &rhs.applies_to) { (None, None) => Ok(()), (Some(lhs), Some(rhs)) => { // If either of them has at least one empty appliesTo list, the other must have the same attribute. - if (either_empty(&lhs) && either_empty(&rhs)) || apply_spec_equiv(&lhs, &rhs) { + if (either_empty(&lhs) && either_empty(&rhs)) || Equiv::equiv(lhs, rhs).is_ok() { Ok(()) } else { Err(format!( @@ -318,7 +501,7 @@ fn action_type_equivalence { + (Some(applies_to), None) | (None, Some(applies_to)) if either_empty(applies_to) => { Ok(()) } (Some(_), None) => Err(format!( @@ -331,53 +514,48 @@ fn action_type_equivalence( - lhs: &ApplySpec, - rhs: &ApplySpec, -) -> bool { - shape_equiv(&lhs.context.0, &rhs.context.0) - && vector_equiv(&lhs.principal_types, &rhs.principal_types) - && vector_equiv(&lhs.resource_types, &rhs.resource_types) +impl Equiv for ApplySpec { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + // ApplySpecs are equivalent iff + // A) the principal and resource type lists are equal + // B) the context shapes are equivalent + Equiv::equiv(&lhs.context.0, &rhs.context.0)?; + Equiv::equiv( + &lhs.principal_types.iter().collect::>(), + &rhs.principal_types.iter().collect::>(), + )?; + Equiv::equiv( + &lhs.resource_types.iter().collect::>(), + &rhs.resource_types.iter().collect::>(), + )?; + Ok(()) + } } fn either_empty(spec: &json_schema::ApplySpec) -> bool { spec.principal_types.is_empty() || spec.resource_types.is_empty() } -/// Just compare entity attribute types and context types are equivalent -pub fn validator_schema_attr_types_equivalent( - schema1: &cedar_policy_validator::ValidatorSchema, - schema2: &cedar_policy_validator::ValidatorSchema, -) -> bool { - let entity_attr_tys1: HashMap< - &cedar_drt::ast::EntityType, - HashMap<&smol_str::SmolStr, &cedar_policy_validator::types::AttributeType>, - > = HashMap::from_iter( - schema1 - .entity_types() - .map(|(name, ty)| (name, HashMap::from_iter(ty.attributes()))), - ); - let entity_attr_tys2 = HashMap::from_iter( - schema2 - .entity_types() - .map(|(name, ty)| (name, HashMap::from_iter(ty.attributes()))), - ); - let context_ty1: HashSet<&cedar_policy_validator::types::Type> = HashSet::from_iter( - schema1 - .action_entities() - .unwrap() - .iter() - .map(|e| schema1.get_action_id(e.uid()).unwrap().context_type()), - ); - let context_ty2: HashSet<&cedar_policy_validator::types::Type> = HashSet::from_iter( - schema2 - .action_entities() - .unwrap() - .iter() - .map(|e| schema1.get_action_id(e.uid()).unwrap().context_type()), - ); - entity_attr_tys1 == entity_attr_tys2 && context_ty1 == context_ty2 +impl Equiv for cedar_policy_validator::ValidatorSchema { + fn equiv(lhs: &Self, rhs: &Self) -> Result<(), String> { + Equiv::equiv( + &lhs.entity_types().collect::>(), + &rhs.entity_types().collect::>(), + ) + .map_err(|e| format!("entity attributes are not equivalent: {e}"))?; + Equiv::equiv( + &lhs.action_entities() + .unwrap() + .iter() + .map(|e| lhs.get_action_id(e.uid()).unwrap().context_type()) + .collect::>(), + &rhs.action_entities() + .unwrap() + .iter() + .map(|e| rhs.get_action_id(e.uid()).unwrap().context_type()) + .collect::>(), + ) + .map_err(|e| format!("contexts are not equivalent: {e}"))?; + Ok(()) + } } diff --git a/cedar-policy-generators/src/expr.rs b/cedar-policy-generators/src/expr.rs index 557601998..eb66b82bc 100644 --- a/cedar-policy-generators/src/expr.rs +++ b/cedar-policy-generators/src/expr.rs @@ -21,8 +21,8 @@ use crate::hierarchy::{ arbitrary_specified_uid, generate_uid_with_type, EntityUIDGenMode, Hierarchy, }; use crate::schema::{ - attrs_from_attrs_or_context, entity_type_name_to_schema_type, lookup_common_type, - uid_for_action_name, Schema, + attr_names_from_ea, entity_type_name_to_schema_type, lookup_common_type, uid_for_action_name, + Schema, }; use crate::settings::ABACSettings; use crate::size_hint_utils::{size_hint_for_choose, size_hint_for_range, size_hint_for_ratio}; @@ -555,11 +555,7 @@ impl<'a> ExprGenerator<'a> { .expect("Failed to select entity index."), ) .expect("Failed to select entity from map."); - let attr_names: Vec<&SmolStr> = - attrs_from_attrs_or_context(&self.schema.schema, &entity_type.shape) - .attrs - .keys() - .collect::>(); + let attr_names: Vec = attr_names_from_ea(&self.schema.schema, &entity_type.shape).collect(); let attr_name = SmolStr::clone(u.choose(&attr_names)?); Ok(ast::Expr::has_attr( self.generate_expr_for_schematype( @@ -1691,8 +1687,11 @@ impl<'a> ExprGenerator<'a> { let mut r = HashMap::new(); u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { let (attr_name, attr_ty) = self.schema.arbitrary_attr(u)?.clone(); - let attr_val = - self.generate_attr_value_for_schematype(&attr_ty, max_depth - 1, u)?; + let attr_val = self.generate_attr_value_for_eatypeinternal( + &attr_ty, + max_depth - 1, + u, + )?; r.insert(attr_name, attr_val); Ok(std::ops::ControlFlow::Continue(())) })?; @@ -1810,7 +1809,7 @@ impl<'a> ExprGenerator<'a> { // maybe add some "additional" attributes not mentioned in schema u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { let (attr_name, attr_ty) = self.schema.arbitrary_attr(u)?.clone(); - let attr_val = self.generate_attr_value_for_schematype( + let attr_val = self.generate_attr_value_for_eatypeinternal( &attr_ty, max_depth - 1, u, @@ -1953,7 +1952,7 @@ impl<'a> ExprGenerator<'a> { u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { let (attr_name, attr_ty) = self.schema.arbitrary_attr(u)?.clone(); let attr_val = - self.generate_value_for_schematype(&attr_ty, max_depth - 1, u)?; + self.generate_value_for_eatypeinternal(&attr_ty, max_depth - 1, u)?; r.insert(attr_name, attr_val); Ok(std::ops::ControlFlow::Continue(())) })?; @@ -2036,7 +2035,7 @@ impl<'a> ExprGenerator<'a> { u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { let (attr_name, attr_ty) = self.schema.arbitrary_attr(u)?.clone(); let attr_val = - self.generate_value_for_schematype(&attr_ty, max_depth - 1, u)?; + self.generate_value_for_eatypeinternal(&attr_ty, max_depth - 1, u)?; r.insert(attr_name, attr_val); Ok(std::ops::ControlFlow::Continue(())) })?; @@ -2086,6 +2085,73 @@ impl<'a> ExprGenerator<'a> { } } + /// generate an arbitrary [`ast::Value`] of the given [`json_schema::EntityAttributeTypeInternal`] + fn generate_value_for_eatypeinternal( + &self, + target_type: &json_schema::EntityAttributeTypeInternal, + max_depth: usize, + u: &mut Unstructured<'_>, + ) -> Result { + match target_type { + json_schema::EntityAttributeTypeInternal::Type(ty) => { + self.generate_value_for_schematype(ty, max_depth, u) + } + json_schema::EntityAttributeTypeInternal::EAMap { value_type } => { + if max_depth == 0 { + // no recursion allowed: just return empty-record + Ok(ast::Value::empty_record(None)) + } else { + let mut r = HashMap::new(); + // add an arbitrary number of attributes with the appropriate type + u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { + let attr_name: SmolStr = u.arbitrary()?; + let attr_val = + self.generate_value_for_schematype(&value_type, max_depth - 1, u)?; + r.insert(attr_name, attr_val); + Ok(std::ops::ControlFlow::Continue(())) + })?; + Ok(ast::Value::record(r, None)) + } + } + } + } + + /// get an [`AttrValue`] of the given [`json_schema::EntityAttributeTypeInternal`] + /// which conforms to this schema + /// + /// `max_depth`: maximum depth of the attribute value expression. + /// For instance, maximum depth of nested sets. Not to be confused with the + /// `depth` parameter to size_hint. + pub fn generate_attr_value_for_eatypeinternal( + &self, + target_type: &json_schema::EntityAttributeTypeInternal, + max_depth: usize, + u: &mut Unstructured<'_>, + ) -> Result { + match target_type { + json_schema::EntityAttributeTypeInternal::Type(ty) => { + self.generate_attr_value_for_schematype(ty, max_depth, u) + } + json_schema::EntityAttributeTypeInternal::EAMap { value_type } => { + if max_depth == 0 { + // no recursion allowed: just return empty-record + Ok(AttrValue::Record(HashMap::new())) + } else { + let mut r = HashMap::new(); + // add an arbitrary number of attributes with the appropriate type + u.arbitrary_loop(None, Some(self.settings.max_width as u32), |u| { + let attr_name: SmolStr = u.arbitrary()?; + let attr_val = + self.generate_attr_value_for_schematype(&value_type, max_depth - 1, u)?; + r.insert(attr_name, attr_val); + Ok(std::ops::ControlFlow::Continue(())) + })?; + Ok(AttrValue::Record(r)) + } + } + } + } + /// get a (fully general) arbitrary constant, as an expression. #[allow(dead_code)] pub fn generate_const_expr(&self, u: &mut Unstructured<'_>) -> Result { @@ -2254,7 +2320,7 @@ fn record_schematype_with_attr( json_schema::Type::Type(json_schema::TypeVariant::Record(json_schema::RecordType { attributes: [( attr_name, - json_schema::TypeOfAttribute { + json_schema::RecordAttributeType { ty: attr_type.into(), required: true, }, diff --git a/cedar-policy-generators/src/hierarchy.rs b/cedar-policy-generators/src/hierarchy.rs index 7baa816f4..6851d10fb 100644 --- a/cedar-policy-generators/src/hierarchy.rs +++ b/cedar-policy-generators/src/hierarchy.rs @@ -17,7 +17,7 @@ use crate::abac::Type; use crate::collections::{HashMap, HashSet}; use crate::err::{while_doing, Error, Result}; -use crate::schema::{attrs_from_attrs_or_context, Schema}; +use crate::schema::{attrs_from_ea, Schema}; use crate::size_hint_utils::{size_hint_for_choose, size_hint_for_ratio}; use arbitrary::{Arbitrary, Unstructured}; use cedar_policy_core::ast::{self, Eid, Entity, EntityUID}; @@ -592,7 +592,7 @@ impl<'a, 'u> HierarchyGenerator<'a, 'u> { let Some(entitytypes_by_type) = &entitytypes_by_type else { unreachable!("in schema-based mode, this should always be Some") }; - let attributes = attrs_from_attrs_or_context( + let attributes = attrs_from_ea( &schema.schema, &entitytypes_by_type .get(name) @@ -636,7 +636,7 @@ impl<'a, 'u> HierarchyGenerator<'a, 'u> { if ty.required || self.u.ratio::(1, 2)? { let attr_val = schema .exprgenerator(Some(&hierarchy_no_attrs)) - .generate_attr_value_for_schematype( + .generate_attr_value_for_eatypeinternal( &ty.ty, schema.settings.max_depth, self.u, diff --git a/cedar-policy-generators/src/schema.rs b/cedar-policy-generators/src/schema.rs index 0dd41b159..47ba70ce2 100644 --- a/cedar-policy-generators/src/schema.rs +++ b/cedar-policy-generators/src/schema.rs @@ -63,30 +63,31 @@ pub struct Schema { /// list of entity types that occur as a valid resource for at least one /// action in the `schema` pub resource_types: Vec, - /// list of (attribute, type) pairs that occur in the `schema` - attributes: Vec<(SmolStr, json_schema::Type)>, + /// list of (attribute, attribute type) pairs that occur in the `schema` + attributes: Vec<( + SmolStr, + json_schema::EntityAttributeTypeInternal, + )>, /// map from type to (entity type, attribute name) pairs indicating - /// attributes in the `schema` that have that type. - /// note that we can't make a similar map for json_schema::Type because it - /// isn't Hash or Ord + /// attributes in the `schema` that have that type attributes_by_type: HashMap>, } -/// internal helper function, basically `impl Arbitrary for AttributesOrContext` -fn arbitrary_attrspec>( +/// internal helper function, basically `impl Arbitrary for RecordOrContextAttributes` +fn arbitrary_rca>( settings: &ABACSettings, entity_types: &[ast::EntityType], u: &mut Unstructured<'_>, -) -> Result> { +) -> Result> { let attr_names: Vec = u .arbitrary() .map_err(|e| while_doing("generating attribute names for an attrspec".into(), e))?; - Ok(json_schema::AttributesOrContext(json_schema::Type::Type( + Ok(json_schema::RecordOrContextAttributes(json_schema::Type::Type( json_schema::TypeVariant::Record(json_schema::RecordType { attributes: attr_names .into_iter() .map(|attr| { - let mut ty = arbitrary_typeofattribute_with_bounded_depth::( + let mut ty = arbitrary_recordattributetype_with_bounded_depth::( settings, entity_types, settings.max_depth, @@ -95,7 +96,7 @@ fn arbitrary_attrspec>( if !settings.enable_extensions { // can't have extension types. regenerate until morale improves while ty.ty.is_extension().expect("DRT does not generate schema type using type defs, so `is_extension` should be `Some`") { - ty = arbitrary_typeofattribute_with_bounded_depth::( + ty = arbitrary_recordattributetype_with_bounded_depth::( settings, entity_types, settings.max_depth, @@ -114,19 +115,32 @@ fn arbitrary_attrspec>( }), ))) } -/// size hint for arbitrary_attrspec -fn arbitrary_attrspec_size_hint(depth: usize) -> (usize, Option) { +/// size hint for [`arbitrary_rca()`] +fn arbitrary_rca_size_hint(depth: usize) -> (usize, Option) { arbitrary::size_hint::recursion_guard(depth, |depth| { arbitrary::size_hint::and_all(&[ as Arbitrary>::size_hint(depth), - arbitrary_typeofattribute_size_hint(depth), + arbitrary_recordattributetype_size_hint(depth), ::size_hint(depth), ]) }) } +/// internal helper function, basically `impl Arbitrary for EntityAttributes` +fn arbitrary_entityattributes>( + settings: &ABACSettings, + entity_types: &[ast::EntityType], + u: &mut Unstructured<'_>, +) -> Result> { + // RFC 68 is not yet fully supported. + // Currently, we never generate `EAMap`s in this function. + Ok(json_schema::EntityAttributes::RecordAttributes( + arbitrary_rca(settings, entity_types, u)?, + )) +} + /// internal helper function, an alternative to the `Arbitrary` impl for -/// `TypeOfAttribute` that implements a bounded maximum depth. +/// [`json_schema::RecordAttributeType`] that implements a bounded maximum depth. /// For instance, if `max_depth` is 3, then Set types (or Record types) /// won't be nested more than 3 deep. /// @@ -137,19 +151,19 @@ fn arbitrary_attrspec_size_hint(depth: usize) -> (usize, Option) { /// settings.enable_additional_attributes; it always behaves as if that setting /// is `true` (ie, it may generate `additional_attributes` as either `true` or /// `false`). -fn arbitrary_typeofattribute_with_bounded_depth>( +fn arbitrary_recordattributetype_with_bounded_depth>( settings: &ABACSettings, entity_types: &[ast::EntityType], max_depth: usize, u: &mut Unstructured<'_>, -) -> Result> { - Ok(json_schema::TypeOfAttribute { +) -> Result> { + Ok(json_schema::RecordAttributeType { ty: arbitrary_schematype_with_bounded_depth::(settings, entity_types, max_depth, u)?, required: u.arbitrary()?, }) } -/// size hint for arbitrary_typeofattribute_with_bounded_depth -fn arbitrary_typeofattribute_size_hint(depth: usize) -> (usize, Option) { +/// size hint for [`arbitrary_recordattributetype_with_bounded_depth()`] +fn arbitrary_recordattributetype_size_hint(depth: usize) -> (usize, Option) { arbitrary::size_hint::and( arbitrary_schematype_size_hint(depth), ::size_hint(depth), @@ -218,7 +232,7 @@ pub fn arbitrary_schematype_with_bounded_depth>( .map(|attr_name| { Ok(( attr_name.into(), - arbitrary_typeofattribute_with_bounded_depth( + arbitrary_recordattributetype_with_bounded_depth( settings, entity_types, max_depth - 1, @@ -330,43 +344,134 @@ fn schematype_to_type( } } +/// internal helper function, convert a +/// [`json_schema::EntityAttributeTypeInternal`] to a [`Type`] (loses some +/// information) +fn eatypeinternal_to_type( + schema: &json_schema::NamespaceDefinition, + eatypeinternal: &json_schema::EntityAttributeTypeInternal, +) -> Type { + match eatypeinternal { + json_schema::EntityAttributeTypeInternal::Type(ty) => schematype_to_type(schema, ty), + json_schema::EntityAttributeTypeInternal::EAMap { .. } => Type::record(), // For these purposes, EAMaps are just records, as runtime values of type EAMap are valid runtime values of type Record + } +} + /// Get an arbitrary namespace for a schema. The namespace may be absent. fn arbitrary_namespace(u: &mut Unstructured<'_>) -> Result> { u.arbitrary() .map_err(|e| while_doing("generating namespace".into(), e)) } -/// Information about attributes from the schema -pub(crate) struct Attributes<'a> { +/// Information about record or context attributes +pub(crate) struct RecordOrContextAttributes<'a> { /// the actual attributes - pub attrs: &'a BTreeMap>, + pub attrs: &'a BTreeMap>, /// whether `additional_attributes` is set pub additional_attrs: bool, } -/// Given a [`json_schema::AttributesOrContext`], get the actual attributes map -/// from it, and whether it has `additional_attributes` set -pub(crate) fn attrs_from_attrs_or_context<'a>( +/// Information about entity attributes +pub(crate) struct EntityAttributes { + /// the actual attributes + pub attrs: BTreeMap>, + /// whether `additional_attributes` is set + pub additional_attrs: bool, +} + +/// Given a [`json_schema::RecordOrContextAttributes`], get the +/// [`RecordOrContextAttributes`] describing it +pub(crate) fn attrs_from_rca<'a>( schema: &'a json_schema::NamespaceDefinition, - attrsorctx: &'a json_schema::AttributesOrContext, -) -> Attributes<'a> { - match &attrsorctx.0 { + rca: &'a json_schema::RecordOrContextAttributes, +) -> RecordOrContextAttributes<'a> { + match &rca.0 { json_schema::Type::CommonTypeRef { type_name } => match lookup_common_type(schema, type_name).unwrap_or_else(|| panic!("reference to undefined common type: {type_name}")) { json_schema::Type::CommonTypeRef { .. } => panic!("common type `{type_name}` refers to another common type, which is not allowed as of this writing?"), - json_schema::Type::Type(json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes })) => Attributes { attrs: attributes, additional_attrs: *additional_attributes }, + json_schema::Type::Type(json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes })) => RecordOrContextAttributes { attrs: attributes, additional_attrs: *additional_attributes }, json_schema::Type::Type(ty) => panic!("expected attributes or context to be a record, got {ty:?}"), } - json_schema::Type::Type(json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes })) => Attributes { attrs: attributes, additional_attrs: *additional_attributes }, + json_schema::Type::Type(json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes })) => RecordOrContextAttributes { attrs: attributes, additional_attrs: *additional_attributes }, json_schema::Type::Type(ty) => panic!("expected attributes or context to be a record, got {ty:?}"), } } +/// Given a [`json_schema::EntityAttributes`], get the [`EntityAttributes`] +/// describing it +pub(crate) fn attrs_from_ea( + schema: &json_schema::NamespaceDefinition, + ea: &json_schema::EntityAttributes, +) -> EntityAttributes { + match ea { + json_schema::EntityAttributes::RecordAttributes(rca) => { + let RecordOrContextAttributes { + attrs, + additional_attrs, + } = attrs_from_rca(schema, rca); + EntityAttributes { + attrs: attrs + .iter() + .map(|(k, v)| { + ( + k.clone(), + json_schema::EntityAttributeType { + ty: json_schema::EntityAttributeTypeInternal::Type(v.ty.clone()), + required: v.required, + }, + ) + }) + .collect(), + additional_attrs, + } + } + json_schema::EntityAttributes::EntityAttributes( + json_schema::EntityAttributesInternal { + attrs: + json_schema::RecordType { + attributes, + additional_attributes, + }, + .. + }, + ) => EntityAttributes { + attrs: attributes + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + additional_attrs: *additional_attributes, + }, + } +} + +/// Given a [`json_schema::EntityAttributes`], get just the attribute names in it +pub(crate) fn attr_names_from_ea<'a>( + schema: &'a json_schema::NamespaceDefinition, + ea: &'a json_schema::EntityAttributes, +) -> Box + 'a> { + match ea { + json_schema::EntityAttributes::RecordAttributes(rca) => { + let attrs = attrs_from_rca(schema, rca); + Box::new(attrs.attrs.keys().cloned()) + } + json_schema::EntityAttributes::EntityAttributes(rty) => { + Box::new(rty.attrs.attributes.keys().cloned()) + } + } +} + /// Given a [`json_schema::Type`], return all (attribute, type) pairs that occur /// inside it fn attrs_in_schematype( schema: &json_schema::NamespaceDefinition, schematype: &json_schema::Type, -) -> Box)>> { +) -> Box< + dyn Iterator< + Item = ( + SmolStr, + json_schema::EntityAttributeTypeInternal, + ), + >, +> { match schematype { json_schema::Type::Type(variant) => match variant { json_schema::TypeVariant::Boolean => Box::new(std::iter::empty()), @@ -392,11 +497,16 @@ fn attrs_in_schematype( json_schema::TypeVariant::Record(json_schema::RecordType { attributes, .. }) => { let toplevel = attributes .iter() - .map(|(k, v)| (k.clone(), v.ty.clone())) + .map(|(k, v)| { + ( + k.clone(), + json_schema::EntityAttributeTypeInternal::Type(v.ty.clone()), + ) + }) .collect::>(); let recursed = toplevel .iter() - .flat_map(|(_, v)| attrs_in_schematype(schema, v)) + .flat_map(|(_, v)| attrs_in_eatypeinternal(schema, v)) .collect::>(); Box::new(toplevel.into_iter().chain(recursed)) } @@ -409,6 +519,65 @@ fn attrs_in_schematype( } } +/// Given a [`json_schema::EntityType`], return all (attribute, +/// type) pairs that occur inside it +fn attrs_in_etype( + schema: &json_schema::NamespaceDefinition, + etype: &json_schema::EntityType, +) -> Box< + dyn Iterator< + Item = ( + SmolStr, + json_schema::EntityAttributeTypeInternal, + ), + >, +> { + match &etype.shape { + json_schema::EntityAttributes::RecordAttributes( + json_schema::RecordOrContextAttributes(ty), + ) => attrs_in_schematype(schema, ty), + json_schema::EntityAttributes::EntityAttributes( + json_schema::EntityAttributesInternal { + attrs: json_schema::RecordType { attributes, .. }, + .. + }, + ) => { + let toplevel = attributes + .iter() + .map(|(k, v)| (k.clone(), v.ty.clone())) + .collect::>(); + let recursed = toplevel + .iter() + .flat_map(|(_, v)| attrs_in_eatypeinternal(schema, v)) + .collect::>(); + Box::new(toplevel.into_iter().chain(recursed)) + } + } +} + +/// Given a [`json_schema::EntityAttributeTypeInternal`], return all +/// (attribute, type) pairs that occur inside it +fn attrs_in_eatypeinternal( + schema: &json_schema::NamespaceDefinition, + eatypeinternal: &json_schema::EntityAttributeTypeInternal, +) -> Box< + dyn Iterator< + Item = ( + SmolStr, + json_schema::EntityAttributeTypeInternal, + ), + >, +> { + match eatypeinternal { + json_schema::EntityAttributeTypeInternal::Type(ty) => attrs_in_schematype(schema, ty), + json_schema::EntityAttributeTypeInternal::EAMap { value_type } => { + // we can't return any attributes from the EAMap itself because we + // are not guaranteed that any particular attribute names exist + attrs_in_schematype(schema, value_type) + } + } +} + /// Build `attributes_by_type` from other components of `Schema` fn build_attributes_by_type<'a>( schema: &json_schema::NamespaceDefinition, @@ -425,16 +594,20 @@ fn build_attributes_by_type<'a>( .map(|(name, et)| { ( ast::EntityType::from(ast::Name::from(name.clone())).qualify_with(namespace), - attrs_from_attrs_or_context(schema, &et.shape), + attrs_from_ea(schema, &et.shape), ) }) .flat_map(|(tyname, attributes)| { - attributes.attrs.iter().map(move |(attr_name, ty)| { - ( - schematype_to_type(schema, &ty.ty), - (tyname.clone(), attr_name.clone()), - ) - }) + attributes + .attrs + .iter() + .map(move |(attr_name, ty)| { + ( + eatypeinternal_to_type(schema, &ty.ty), + (tyname.clone(), attr_name.clone()), + ) + }) + .collect::>() }); let mut hm: HashMap> = HashMap::new(); for (ty, pair) in triples { @@ -524,7 +697,7 @@ impl Bindings { .map(|(attr, attr_ty)| { Ok(( attr.to_owned(), - json_schema::TypeOfAttribute { + json_schema::RecordAttributeType { ty: self.rewrite_type(u, &attr_ty.ty)?, required: attr_ty.required.to_owned(), }, @@ -539,33 +712,104 @@ impl Bindings { } } - // Replace attribute types in an entity type with common types + /// Replace attribute types in an entity type with common types fn rewrite_entity_type( &self, u: &mut Unstructured<'_>, et: &json_schema::EntityType, ) -> Result> { - let ty = &et.shape.0; Ok(json_schema::EntityType { member_of_types: et.member_of_types.clone(), - shape: json_schema::AttributesOrContext(self.rewrite_record_type(u, ty)?), + shape: self.rewrite_entity_attributes(u, &et.shape)?, }) } - // Replace attribute types in a record type with common types + /// Replace attribute types in a [`json_schema::EntityAttributes`] with common types + fn rewrite_entity_attributes( + &self, + u: &mut Unstructured<'_>, + ea: &json_schema::EntityAttributes, + ) -> Result> { + match ea { + json_schema::EntityAttributes::RecordAttributes(attrs) => Ok( + json_schema::EntityAttributes::RecordAttributes(self.rewrite_rca(u, attrs)?), + ), + json_schema::EntityAttributes::EntityAttributes(attrs) => Ok( + json_schema::EntityAttributes::from(self.rewrite_record_type(u, &attrs.attrs)?), + ), + } + } + + /// Replace attribute types in a [`json_schema::RecordOrContextAttributes`] with common types + fn rewrite_rca( + &self, + u: &mut Unstructured<'_>, + rca: &json_schema::RecordOrContextAttributes, + ) -> Result> { + Ok(json_schema::RecordOrContextAttributes( + self.rewrite_or_replace_type(u, &rca.0)?, + )) + } + fn rewrite_record_type( + &self, + u: &mut Unstructured<'_>, + rty: &json_schema::RecordType>, + ) -> Result>> { + Ok(json_schema::RecordType { + attributes: rty + .attributes + .iter() + .map(|(k, v)| Ok((k.clone(), self.rewrite_eatype(u, v)?))) + .collect::>()?, + additional_attributes: rty.additional_attributes, + }) + } + + fn rewrite_eatype( + &self, + u: &mut Unstructured<'_>, + eatype: &json_schema::EntityAttributeType, + ) -> Result> { + Ok(json_schema::EntityAttributeType { + ty: self.rewrite_eatypeinternal(u, &eatype.ty)?, + required: eatype.required, + }) + } + + fn rewrite_eatypeinternal( + &self, + u: &mut Unstructured<'_>, + eatypeinternal: &json_schema::EntityAttributeTypeInternal, + ) -> Result> { + match eatypeinternal { + json_schema::EntityAttributeTypeInternal::Type(ty) => { + Ok(json_schema::EntityAttributeTypeInternal::Type( + self.rewrite_or_replace_type(u, ty)?, + )) + } + json_schema::EntityAttributeTypeInternal::EAMap { value_type } => { + Ok(json_schema::EntityAttributeTypeInternal::EAMap { + value_type: self.rewrite_or_replace_type(u, value_type)?, + }) + } + } + } + + /// Replace the type with a common-type reference, or rewrite the type to + /// possibly replace subcomponents of the type with common-type references + fn rewrite_or_replace_type( &self, u: &mut Unstructured<'_>, ty: &json_schema::Type, ) -> Result> { - let new_ty = if let Some(ids) = self.bindings.get(ty) { - json_schema::Type::CommonTypeRef { + if let Some(ids) = self.bindings.get(ty) { + Ok(json_schema::Type::CommonTypeRef { type_name: ast::Name::unqualified_name(u.choose(ids)?.clone()).into(), - } + }) } else { - self.rewrite_type(u, ty)? - }; - Ok(new_ty) + self.rewrite_type(u, ty) + } } // Generate common types based on the bindings @@ -639,10 +883,11 @@ impl Schema { &self, u: &mut Unstructured<'_>, ) -> Result> { - let attribute_types = &self.attributes; let mut bindings = Bindings::new(); - for (_, ty) in attribute_types { - bind_type(ty, u, &mut bindings)?; + for (_, ty) in &self.attributes { + if let json_schema::EntityAttributeTypeInternal::Type(ty) = ty { + bind_type(ty, u, &mut bindings)?; + } } let common_types = bindings.to_common_types(u)?; @@ -668,8 +913,8 @@ impl Schema { Some(applies) => Some(json_schema::ApplySpec { resource_types: applies.resource_types.clone(), principal_types: applies.principal_types.clone(), - context: json_schema::AttributesOrContext( - bindings.rewrite_record_type(u, &applies.context.0)?, + context: json_schema::RecordOrContextAttributes( + bindings.rewrite_or_replace_type(u, &applies.context.0)?, ), }), None => None, @@ -760,13 +1005,12 @@ impl Schema { } } let mut attributes = Vec::new(); - for schematype in nsdef - .common_types - .values() - .chain(nsdef.entity_types.values().map(|etype| &etype.shape.0)) - { + for schematype in nsdef.common_types.values() { attributes.extend(attrs_in_schematype(&nsdef, schematype)); } + for etype in nsdef.entity_types.values() { + attributes.extend(attrs_in_etype(&nsdef, etype)); + } let attributes_by_type = build_attributes_by_type(&nsdef, &nsdef.entity_types, namespace.as_ref()); Ok(Schema { @@ -898,7 +1142,7 @@ impl Schema { id.clone(), json_schema::EntityType { member_of_types: vec![], - shape: arbitrary_attrspec(&settings, &entity_type_names, u)?, + shape: arbitrary_entityattributes(&settings, &entity_type_names, u)?, }, )) }) @@ -991,7 +1235,7 @@ impl Schema { Some(json_schema::ApplySpec { resource_types: picked_resource_types, principal_types: picked_principal_types, - context: arbitrary_attrspec(&settings, &entity_type_names, u)?, + context: arbitrary_rca(&settings, &entity_type_names, u)?, }) }, member_of: if settings.enable_action_groups_and_attrs { @@ -1031,18 +1275,33 @@ impl Schema { entity_types: entity_types.into_iter().collect(), actions: actions.into_iter().collect(), }; - let attrsorcontexts /* : impl Iterator */ = nsdef.entity_types.values().map(|et| attrs_from_attrs_or_context(&nsdef, &et.shape)) - .chain(nsdef.actions.iter().filter_map(|(_, action)| action.applies_to.as_ref()).map(|a| attrs_from_attrs_or_context(&nsdef, &a.context))); - let attributes: Vec<(SmolStr, json_schema::Type<_>)> = attrsorcontexts - .flat_map(|attributes| { - attributes.attrs.iter().map(|(s, ty)| { + let entity_attributes = nsdef + .entity_types + .values() + .map(|et| attrs_from_ea(&nsdef, &et.shape)) + .flat_map(|attrs| { + attrs.attrs.into_iter().map(|(s, ty)| { ( s.parse().expect("attribute names should be valid Ids"), - ty.ty.clone(), + ty.ty, ) }) - }) - .collect(); + }); + let context_attributes = nsdef + .actions + .iter() + .filter_map(|(_, action)| action.applies_to.as_ref()) + .map(|a| attrs_from_rca(&nsdef, &a.context)) + .flat_map(|attrs| { + attrs.attrs.into_iter().map(|(s, ty)| { + ( + s.parse().expect("attribute names should be valid Ids"), + json_schema::EntityAttributeTypeInternal::Type(ty.ty.clone()), + ) + }) + }); + let attributes: Vec<(SmolStr, json_schema::EntityAttributeTypeInternal<_>)> = + entity_attributes.chain(context_attributes).collect(); let attributes_by_type = build_attributes_by_type(&nsdef, nsdef.entity_types.iter(), namespace.as_ref()); let actions_eids = nsdef @@ -1077,13 +1336,13 @@ impl Schema { pub fn arbitrary_size_hint(depth: usize) -> (usize, Option) { arbitrary::size_hint::and_all(&[ as Arbitrary>::size_hint(depth), - arbitrary_attrspec_size_hint(depth), // actually we do one of these per Name that was generated - size_hint_for_ratio(1, 2), // actually many of these calls + arbitrary_rca_size_hint(depth), // actually we do one of these per Name that was generated + size_hint_for_ratio(1, 2), // actually many of these calls as Arbitrary>::size_hint(depth), size_hint_for_ratio(1, 8), // actually many of these calls size_hint_for_ratio(1, 4), // zero to many of these calls size_hint_for_ratio(1, 2), // zero to many of these calls - arbitrary_attrspec_size_hint(depth), + arbitrary_rca_size_hint(depth), size_hint_for_ratio(1, 2), // actually many of these calls ::size_hint(depth), ]) @@ -1176,11 +1435,14 @@ impl Schema { .map(json_schema::Type::Type)) } - /// get an attribute name and its `json_schema::Type`, from the schema + /// get an attribute name and its attribute type, from the schema pub fn arbitrary_attr( &self, u: &mut Unstructured<'_>, - ) -> Result<&(SmolStr, json_schema::Type)> { + ) -> Result<&( + SmolStr, + json_schema::EntityAttributeTypeInternal, + )> { u.choose(&self.attributes) .map_err(|e| while_doing("getting arbitrary attr from schema".into(), e)) } @@ -1223,14 +1485,14 @@ impl Schema { ( ast::EntityType::from(ast::Name::from(name.clone())) .qualify_with(self.namespace()), - attrs_from_attrs_or_context(&self.schema, &et.shape), + attrs_from_ea(&self.schema, &et.shape), ) }) .flat_map(|(tyname, attributes)| { attributes .attrs - .iter() - .filter(|(_, ty)| ty.ty == target_type) + .into_iter() + .filter(|(_, ty)| matches!(&ty.ty, json_schema::EntityAttributeTypeInternal::Type(t) if t == &target_type)) .map(move |(attr_name, _)| (tyname.clone(), attr_name.clone())) }) .collect(); @@ -1445,7 +1707,7 @@ impl Schema { let mut attributes: Vec<_> = action .applies_to .as_ref() - .map(|a| attrs_from_attrs_or_context(&self.schema, &a.context)) + .map(|a| attrs_from_rca(&self.schema, &a.context)) .iter() .flat_map(|attributes| attributes.attrs.iter()) .collect(); @@ -1596,21 +1858,21 @@ fn downgrade_schematypevariant_to_raw( }) => json_schema::TypeVariant::Record(json_schema::RecordType { attributes: attributes .into_iter() - .map(|(k, v)| (k, downgrade_toa_to_raw(v))) + .map(|(k, v)| (k, downgrade_rat_to_raw(v))) .collect(), additional_attributes, }), } } -/// Utility function to "downgrade" a [`TypeOfAttribute`] with fully-qualified +/// Utility function to "downgrade" a [`json_schema::RecordAttributeType`] with fully-qualified /// names into one with [`RawName`]s. See notes on [`downgrade_frag_to_raw()`]. -fn downgrade_toa_to_raw( - toa: json_schema::TypeOfAttribute, -) -> json_schema::TypeOfAttribute { - json_schema::TypeOfAttribute { - ty: downgrade_schematype_to_raw(toa.ty), - required: toa.required, +fn downgrade_rat_to_raw( + rat: json_schema::RecordAttributeType, +) -> json_schema::RecordAttributeType { + json_schema::RecordAttributeType { + ty: downgrade_schematype_to_raw(rat.ty), + required: rat.required, } } @@ -1626,17 +1888,79 @@ fn downgrade_entitytype_to_raw( .into_iter() .map(RawName::from_name) .collect(), - shape: downgrade_aoc_to_raw(entitytype.shape), + shape: downgrade_ea_to_raw(entitytype.shape), } } -/// Utility function to "downgrade" a [`AttributesOrContext`] with -/// fully-qualified names into one with [`RawName`]s. See notes on +/// Utility function to "downgrade" a [`json_schema::RecordOrContextAttributes`] +/// with fully-qualified names into one with [`RawName`]s. See notes on /// [`downgrade_frag_to_raw()`]. -fn downgrade_aoc_to_raw( - aoc: json_schema::AttributesOrContext, -) -> json_schema::AttributesOrContext { - json_schema::AttributesOrContext(downgrade_schematype_to_raw(aoc.0)) +fn downgrade_rca_to_raw( + rca: json_schema::RecordOrContextAttributes, +) -> json_schema::RecordOrContextAttributes { + json_schema::RecordOrContextAttributes(downgrade_schematype_to_raw(rca.0)) +} + +/// Utility function to "downgrade" a [`json_schema::EntityAttributes`] +/// with fully-qualified names into one with [`RawName`]s. See notes on +/// [`downgrade_frag_to_raw()`]. +fn downgrade_ea_to_raw( + ea: json_schema::EntityAttributes, +) -> json_schema::EntityAttributes { + match ea { + json_schema::EntityAttributes::RecordAttributes(rca) => { + json_schema::EntityAttributes::RecordAttributes(downgrade_rca_to_raw(rca)) + } + json_schema::EntityAttributes::EntityAttributes( + json_schema::EntityAttributesInternal { attrs, .. }, + ) => downgrade_rty_to_raw(attrs).into(), + } +} + +/// Utility function to "downgrade" a [`json_schema::RecordType`] +/// with fully-qualified names into one with [`RawName`]s. +/// See notes on [`downgrade_frag_to_raw()`]. +fn downgrade_rty_to_raw( + rty: json_schema::RecordType>, +) -> json_schema::RecordType> { + json_schema::RecordType { + attributes: rty + .attributes + .into_iter() + .map(|(k, v)| (k, downgrade_eatype_to_raw(v))) + .collect(), + additional_attributes: rty.additional_attributes, + } +} + +/// Utility function to "downgrade" a [`json_schema::EntityAttributeType`] +/// with fully-qualified names into one with [`RawName`]s. +/// See notes on [`downgrade_frag_to_raw()`]. +fn downgrade_eatype_to_raw( + eatype: json_schema::EntityAttributeType, +) -> json_schema::EntityAttributeType { + json_schema::EntityAttributeType { + ty: downgrade_eatypeinternal_to_raw(eatype.ty), + required: eatype.required, + } +} + +/// Utility function to "downgrade" a [`json_schema::EntityAttributeTypeInternal`] +/// with fully-qualified names into one with [`RawName`]s. +/// See notes on [`downgrade_frag_to_raw()`]. +fn downgrade_eatypeinternal_to_raw( + eatypeinternal: json_schema::EntityAttributeTypeInternal, +) -> json_schema::EntityAttributeTypeInternal { + match eatypeinternal { + json_schema::EntityAttributeTypeInternal::Type(ty) => { + json_schema::EntityAttributeTypeInternal::Type(downgrade_schematype_to_raw(ty)) + } + json_schema::EntityAttributeTypeInternal::EAMap { value_type } => { + json_schema::EntityAttributeTypeInternal::EAMap { + value_type: downgrade_schematype_to_raw(value_type), + } + } + } } /// Utility function to "downgrade" an [`ActionType`] with fully-qualified names @@ -1669,7 +1993,7 @@ fn downgrade_applyspec_to_raw( .into_iter() .map(RawName::from_name) .collect(), - context: downgrade_aoc_to_raw(applyspec.context), + context: downgrade_rca_to_raw(applyspec.context), } }