Skip to content

Commit

Permalink
generate protodatas for unit tests using cedar-policy interface (#551)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <cdiss@amazon.com>
  • Loading branch information
cdisselkoen authored Feb 26, 2025
1 parent bd2004c commit 9c12a80
Show file tree
Hide file tree
Showing 19 changed files with 330 additions and 295 deletions.
Binary file modified cedar-lean/UnitTest/CedarProto-test-data/abac.protodata
Binary file not shown.
Binary file modified cedar-lean/UnitTest/CedarProto-test-data/entity.protodata
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ edition = "2021"
repository = "https://github.com/cedar-policy/cedar-spec"

[dependencies]
cedar-policy-core = { path = "../../../../cedar/cedar-policy-core", version = "*" }
cedar-policy-validator = { path = "../../../../cedar/cedar-policy-validator", version = "*" }
cedar-policy = { path = "../../../../cedar/cedar-policy", version = "*", features = ["protobufs"] }
miette = { version = "7.1.0", features = ["fancy"] }
prost = "0.13"

[lints.rust]
unsafe_code = "forbid"
Expand Down
236 changes: 88 additions & 148 deletions cedar-lean/UnitTest/CedarProto-test-data/generate-protodata/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,84 +1,55 @@
use cedar_policy_core::{
ast, entities,
extensions::Extensions,
parser::{parse_policy, parse_policy_or_template, parse_policyset, Loc},
use cedar_policy::{
proto::traits::Protobuf, Entities, Entity, Expression, PolicyId, PolicySet, Request, Schema,
};
use cedar_policy_validator::types as validator_types;
use cedar_policy::proto;
use prost::Message;
use std::collections::{HashMap, HashSet};
use cedar_policy::{Context, EntityUid, Policy, RestrictedExpression, SlotId};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::str::FromStr;

fn output_dir() -> PathBuf {
std::env::var("OUTPUT_DIR")
.map(PathBuf::from)
.unwrap_or(PathBuf::from_str(".").unwrap())
.unwrap_or_else(|_| PathBuf::from_str(".").unwrap())
}

#[track_caller]
fn encode_expr(path: impl AsRef<Path>, e: &str) {
let expr: ast::Expr = e.parse().unwrap();
let proto: proto::models::Expr = (&expr).into();
let encoded = proto.encode_to_vec();
let expr: Expression = e.parse().unwrap();
let encoded = expr.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

/// Encodes using the protobuf TemplateBody format
#[track_caller]
fn encode_policy_as_template(path: impl AsRef<Path>, p: &str) {
let policy: ast::Template = parse_policy_or_template(None, p).unwrap().into();
let proto: proto::models::TemplateBody = (&policy).into();
let encoded = proto.encode_to_vec();
fn encode_policyset(path: impl AsRef<Path>, ps: &PolicySet) {
let encoded = ps.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_policyset(path: impl AsRef<Path>, ps: &ast::PolicySet) {
let proto: proto::models::PolicySet = ps.into();
let encoded = proto.encode_to_vec();
fn encode_request(path: impl AsRef<Path>, r: &Request) {
let encoded = r.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_request(path: impl AsRef<Path>, r: &ast::Request) {
let proto: proto::models::Request = r.into();
let encoded = proto.encode_to_vec();
fn encode_entity(path: impl AsRef<Path>, e: &Entity) {
let encoded = e.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_entity(path: impl AsRef<Path>, e: &ast::Entity) {
let proto: proto::models::Entity = e.into();
let encoded = proto.encode_to_vec();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_entities(path: impl AsRef<Path>, es: &entities::Entities) {
let proto: proto::models::Entities = es.into();
let encoded = proto.encode_to_vec();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_val_type(path: impl AsRef<Path>, ty: &validator_types::Type) {
let proto: proto::models::Type = ty.into();
let encoded = proto.encode_to_vec();
fn encode_entities(path: impl AsRef<Path>, es: &Entities) {
let encoded = es.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

#[track_caller]
fn encode_schema(path: impl AsRef<Path>, s: &str) {
let (schema, warnings) = cedar_policy_validator::ValidatorSchema::from_cedarschema_str(
s,
&Extensions::all_available(),
)
.map_err(|e| format!("{:?}", miette::Report::new(e)))
.unwrap();
let (schema, warnings) = Schema::from_cedarschema_str(s)
.map_err(|e| format!("{:?}", miette::Report::new(e)))
.unwrap();
assert_eq!(warnings.count(), 0);
let proto: proto::models::Schema = (&schema).into();
let encoded = proto.encode_to_vec();
let encoded = schema.encode();
std::fs::write(output_dir().join(path.as_ref()), encoded).unwrap();
}

Expand Down Expand Up @@ -145,16 +116,19 @@ fn main() {
r#"decimal("3.14").lessThan(decimal("3.1416"))"#,
);

encode_policy_as_template(
encode_policyset(
"rbac.protodata",
r#"permit(principal == User::"a b c", action, resource is App::Widget);"#,
&PolicySet::from_str(
r#"permit(principal == User::"a b c", action, resource is App::Widget);"#,
)
.unwrap(),
);
encode_policy_as_template(
encode_policyset(
"abac.protodata",
r#"permit(principal, action, resource) when { principal == resource.owner } unless { resource.sensitive };"#,
&PolicySet::from_str(r#"permit(principal, action, resource) when { principal == resource.owner } unless { resource.sensitive };"#).unwrap(),
);

let mut policyset: ast::PolicySet = parse_policyset(r#"
let mut policyset = PolicySet::from_str(r#"
permit(principal, action == Action::"do", resource == Blob::"thing") when { context.foo - 7 == context.bar };
forbid(principal is UnauthenticatedUser, action, resource) when { resource.requiresAuthentication };
@foo("bar")
Expand All @@ -163,142 +137,108 @@ fn main() {
"#).unwrap();
policyset
.link(
ast::PolicyID::from_string("policy3"),
ast::PolicyID::from_string("linkedpolicy"),
PolicyId::from_str("policy3").unwrap(),
PolicyId::from_str("linkedpolicy").unwrap(),
HashMap::from_iter([(
ast::SlotId::principal(),
ast::EntityUID::with_eid_and_type("User", "alice").unwrap(),
SlotId::principal(),
EntityUid::from_type_name_and_id("User".parse().unwrap(), "alice".parse().unwrap()),
)]),
)
.unwrap();
encode_policyset("policyset.protodata", &policyset);

// regression test for #500: a policyset with only templates, no static or linked policies
let policyset: ast::PolicySet = parse_policyset(r#"
encode_policyset(
"policyset_just_templates.protodata",
&PolicySet::from_str(
r#"
permit(principal == ?principal, action, resource);
"#).unwrap();
encode_policyset("policyset_just_templates.protodata", &policyset);
"#,
)
.unwrap(),
);

// regression test for #505: a policyset with exactly one static policy, with id "" (empty string)
let policy: ast::StaticPolicy = parse_policy(Some(ast::PolicyID::from_string("")), r#"
let policy = Policy::parse(
Some(PolicyId::from_str("").unwrap()),
r#"
permit(principal, action, resource) when { true };
"#).unwrap();
let mut policyset = ast::PolicySet::new();
policyset.add_static(policy).unwrap();
"#,
)
.unwrap();
let policyset = PolicySet::from_policies([policy]).unwrap();
encode_policyset("policyset_one_static_policy.protodata", &policyset);

encode_request(
"request.protodata",
&ast::Request::new(
(
ast::EntityUID::with_eid_and_type("User", "alice").unwrap(),
None,
),
(
ast::EntityUID::with_eid_and_type("Action", "access").unwrap(),
Some(Loc::new(2..5, "source code".into())),
),
(
ast::EntityUID::with_eid_and_type("Folder", "data").unwrap(),
None,
),
ast::Context::from_pairs(
[("foo".into(), ast::RestrictedExpr::val(true))],
Extensions::all_available(),
)
.unwrap(),
None::<&ast::RequestSchemaAllPass>,
Extensions::all_available(),
&Request::new(
EntityUid::from_type_name_and_id("User".parse().unwrap(), "alice".parse().unwrap()),
EntityUid::from_type_name_and_id("Action".parse().unwrap(), "access".parse().unwrap()),
EntityUid::from_type_name_and_id("Folder".parse().unwrap(), "data".parse().unwrap()),
Context::from_pairs([("foo".into(), RestrictedExpression::new_bool(true))]).unwrap(),
None,
)
.unwrap(),
);

encode_entity(
"entity.protodata",
&ast::Entity::new(
ast::EntityUID::from_components(
ast::EntityType::from_normalized_str("A::B").unwrap(),
ast::Eid::new("C"),
None,
),
&Entity::new_with_tags(
EntityUid::from_type_name_and_id("A::B".parse().unwrap(), "C".parse().unwrap()),
[
("foo".into(), "[1, -1]".parse().unwrap()),
("bar".into(), ast::RestrictedExpr::val(false)),
("bar".into(), RestrictedExpression::new_bool(false)),
],
[
EntityUid::from_type_name_and_id("Parent".parse().unwrap(), "1".parse().unwrap()),
EntityUid::from_type_name_and_id(
"Grandparent".parse().unwrap(),
"A".parse().unwrap(),
),
],
HashSet::from_iter([
ast::EntityUID::with_eid_and_type("Parent", "1").unwrap(),
ast::EntityUID::with_eid_and_type("Grandparent", "A").unwrap(),
]),
HashSet::new(),
[
("tag1".into(), ast::RestrictedExpr::val("val1")),
("tag2".into(), ast::RestrictedExpr::val("val2")),
(
"tag1".into(),
RestrictedExpression::new_string("val1".into()),
),
(
"tag2".into(),
RestrictedExpression::new_string("val2".into()),
),
],
Extensions::all_available(),
)
.unwrap(),
);

encode_entities(
"entities.protodata",
&entities::Entities::from_entities(
&Entities::from_entities(
[
ast::Entity::with_uid(ast::EntityUID::with_eid_and_type("ABC", "123").unwrap()),
ast::Entity::with_uid(ast::EntityUID::with_eid_and_type("DEF", "234").unwrap()),
Entity::with_uid(EntityUid::from_type_name_and_id(
"ABC".parse().unwrap(),
"123".parse().unwrap(),
)),
Entity::with_uid(EntityUid::from_type_name_and_id(
"DEF".parse().unwrap(),
"234".parse().unwrap(),
)),
],
None::<&entities::AllEntitiesNoAttrsSchema>,
entities::TCComputation::ComputeNow,
Extensions::all_available(),
None,
)
.unwrap(),
);

let primitive_bool = validator_types::Type::Primitive {
primitive_type: validator_types::Primitive::Bool,
};
let primitive_long = validator_types::Type::Primitive {
primitive_type: validator_types::Primitive::Long,
};
let primitive_string = validator_types::Type::Primitive {
primitive_type: validator_types::Primitive::String,
};

encode_val_type("type_true.protodata", &validator_types::Type::True);
encode_val_type("type_false.protodata", &validator_types::Type::False);
encode_val_type("type_bool.protodata", &primitive_bool);
encode_val_type("type_long.protodata", &primitive_long.clone());
encode_val_type("type_string.protodata", &primitive_string.clone());
encode_val_type(
encode_schema("type_bool.protodata", "entity E { attr: Bool };");
encode_schema("type_long.protodata", "entity E { attr: Long };");
encode_schema("type_string.protodata", "entity E { attr: String };");
encode_schema(
"type_set_of_string.protodata",
&validator_types::Type::Set {
element_type: Some(Box::new(primitive_string.clone())),
},
);
encode_val_type(
"type_ip.protodata",
&validator_types::Type::ExtensionType {
name: ast::Name::parse_unqualified_name("ipaddr").unwrap(),
},
"entity E { attr: Set<String> };",
);
encode_val_type(
encode_schema("type_ip.protodata", "entity E { attr: ipaddr };");
encode_schema(
"type_record.protodata",
&validator_types::Type::EntityOrRecord(validator_types::EntityRecordKind::Record {
attrs: validator_types::Attributes::with_attributes(
[
(
"ham".into(),
validator_types::AttributeType::required_attribute(
primitive_string.clone(),
),
),
(
"eggs".into(),
validator_types::AttributeType::optional_attribute(primitive_long.clone()),
),
]
),
open_attributes: validator_types::OpenTag::ClosedAttributes,
}),
"entity E { attr: { ham: String, eggs?: Long } };",
);

encode_schema(
Expand Down
Binary file modified cedar-lean/UnitTest/CedarProto-test-data/nested_record.protodata
Binary file not shown.
Binary file modified cedar-lean/UnitTest/CedarProto-test-data/policyset.protodata
Binary file not shown.
Binary file modified cedar-lean/UnitTest/CedarProto-test-data/rbac.protodata
Binary file not shown.
10 changes: 5 additions & 5 deletions cedar-lean/UnitTest/CedarProto-test-data/record.protodata
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

!z

eggs



ham




eggs


Loading

0 comments on commit 9c12a80

Please sign in to comment.