Skip to content

Commit

Permalink
use Protobuf for additional DRT targets (#562)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <cdiss@amazon.com>
  • Loading branch information
cdisselkoen authored Mar 4, 2025
1 parent 9236b32 commit 434f09f
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 134 deletions.
14 changes: 7 additions & 7 deletions cedar-drt/fuzz/fuzz_targets/protobuf-roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use serde::Serialize;

use crate::arbitrary::Arbitrary;
use crate::arbitrary::Unstructured;
use cedar_drt::{AuthorizationRequestMsg, OwnedAuthorizationRequestMsg};
use cedar_drt::{AuthorizationRequest, OwnedAuthorizationRequest};
use cedar_drt_inner::{fuzz_target, schemas::Equiv};
use cedar_policy::proto;
use cedar_policy_core::{
Expand Down Expand Up @@ -97,27 +97,27 @@ fuzz_target!(|input: FuzzTargetInput| {
let s_policy: ast::StaticPolicy = input.policy.into();
let mut policies: ast::PolicySet = ast::PolicySet::new();
policies.add(s_policy.into()).expect("Failed to add policy");
roundtrip_authz_request_msg(AuthorizationRequestMsg {
roundtrip_authz_request_msg(AuthorizationRequest {
request: &input.request.into(),
policies: &policies,
entities: &input.entities,
});
roundtrip_schema(input.schema);
});

fn roundtrip_authz_request_msg(auth_request: AuthorizationRequestMsg) {
fn roundtrip_authz_request_msg(auth_request: AuthorizationRequest) {
// AST -> Protobuf
let auth_request_proto = cedar_drt::proto::AuthorizationRequestMsg::from(&auth_request);
let auth_request_proto = cedar_drt::proto::AuthorizationRequest::from(&auth_request);

// Protobuf -> Bytes
let buf = auth_request_proto.encode_to_vec();

// Bytes -> Protobuf
let roundtripped_proto = cedar_drt::proto::AuthorizationRequestMsg::decode(&buf[..])
.expect("Failed to deserialize AuthorizationRequestMsg from proto");
let roundtripped_proto = cedar_drt::proto::AuthorizationRequest::decode(&buf[..])
.expect("Failed to deserialize AuthorizationRequest from proto");

// Protobuf -> AST
let roundtripped = OwnedAuthorizationRequestMsg::from(roundtripped_proto);
let roundtripped = OwnedAuthorizationRequest::from(roundtripped_proto);

// Checking request equality (ignores loc field)
assert_eq!(
Expand Down
14 changes: 6 additions & 8 deletions cedar-drt/fuzz/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,9 @@ pub fn run_req_val_test(
});
info!("{}{}", RUST_REQ_VALIDATION_MSG, rust_auth_dur.as_nanos());

let definitional_res = custom_impl.validate_request(&schema, &request);
match definitional_res {
TestResult::Failure(_) => {
panic!("request validation test: failed to parse");
match custom_impl.validate_request(&schema, &request) {
TestResult::Failure(e) => {
panic!("failed to execute request validation: {e}");
}
TestResult::Success(definitional_res) => {
if rust_res.is_ok() {
Expand Down Expand Up @@ -354,10 +353,9 @@ pub fn run_ent_val_test(
)
});
info!("{}{}", RUST_ENT_VALIDATION_MSG, rust_auth_dur.as_nanos());
let definitional_res = custom_impl.validate_entities(&schema, &entities);
match definitional_res {
TestResult::Failure(_) => {
panic!("entity validation test: failed to parse");
match custom_impl.validate_entities(&schema, &entities) {
TestResult::Failure(e) => {
panic!("failed to execute entity validation: {e}");
}
TestResult::Success(definitional_res) => {
if rust_res.is_ok() {
Expand Down
21 changes: 19 additions & 2 deletions cedar-drt/protobuf_schema/Messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@ package cedar_drt;
import "core.proto";
import "validator.proto";

message AuthorizationRequestMsg {
message AuthorizationRequest {
cedar_policy_core.Request request = 1;
cedar_policy_core.PolicySet policies = 2;
cedar_policy_core.Entities entities = 3;
}

message ValidationRequestMsg {
message ValidationRequest {
cedar_policy_validator.Schema schema = 1;
cedar_policy_core.PolicySet policies = 2;
cedar_policy_validator.ValidationMode mode = 3;
}

message EvaluationRequest {
cedar_policy_core.Expr expr = 1;
cedar_policy_core.Request request = 2;
cedar_policy_core.Entities entities = 3;
cedar_policy_core.Expr expected = 4;
}

message EntityValidationRequest {
cedar_policy_validator.Schema schema = 1;
cedar_policy_core.Entities entities = 2;
}

message RequestValidationRequest {
cedar_policy_validator.Schema schema = 1;
cedar_policy_core.Request request = 2;
}
77 changes: 46 additions & 31 deletions cedar-drt/src/definitional_request_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ pub mod proto {
include!(concat!(env!("OUT_DIR"), "/cedar_drt.rs"));
}

#[derive(Clone, Debug, Serialize)]
pub struct AuthorizationRequestMsg<'a> {
#[derive(Clone, Debug)]
pub struct AuthorizationRequest<'a> {
pub request: &'a ast::Request,
pub policies: &'a ast::PolicySet,
pub entities: &'a Entities,
}

impl From<&AuthorizationRequestMsg<'_>> for proto::AuthorizationRequestMsg {
fn from(v: &AuthorizationRequestMsg<'_>) -> Self {
impl From<&AuthorizationRequest<'_>> for proto::AuthorizationRequest {
fn from(v: &AuthorizationRequest<'_>) -> Self {
Self {
request: Some(cedar_policy::proto::models::Request::from(v.request)),
policies: Some(cedar_policy::proto::models::PolicySet::from(v.policies)),
Expand All @@ -42,17 +42,17 @@ impl From<&AuthorizationRequestMsg<'_>> for proto::AuthorizationRequestMsg {
}
}

// Converting `AuthorizationRequestMsg` from proto to non-proto structures is
// Converting `AuthorizationRequest` from proto to non-proto structures is
// only required for some roundtrip tests
#[derive(Clone, Debug)]
pub struct OwnedAuthorizationRequestMsg {
pub struct OwnedAuthorizationRequest {
pub request: ast::Request,
pub policies: ast::PolicySet,
pub entities: Entities,
}

impl From<proto::AuthorizationRequestMsg> for OwnedAuthorizationRequestMsg {
fn from(v: proto::AuthorizationRequestMsg) -> Self {
impl From<proto::AuthorizationRequest> for OwnedAuthorizationRequest {
fn from(v: proto::AuthorizationRequest) -> Self {
Self {
request: ast::Request::from(&v.request.unwrap_or_default()),
policies: ast::PolicySet::try_from(&v.policies.unwrap_or_default())
Expand All @@ -62,15 +62,15 @@ impl From<proto::AuthorizationRequestMsg> for OwnedAuthorizationRequestMsg {
}
}

#[derive(Clone, Debug, Serialize)]
pub struct ValidationRequestMsg<'a> {
#[derive(Clone, Debug)]
pub struct ValidationRequest<'a> {
pub schema: &'a ValidatorSchema,
pub policies: &'a ast::PolicySet,
pub mode: ValidationMode,
}

impl From<&ValidationRequestMsg<'_>> for proto::ValidationRequestMsg {
fn from(v: &ValidationRequestMsg<'_>) -> Self {
impl From<&ValidationRequest<'_>> for proto::ValidationRequest {
fn from(v: &ValidationRequest<'_>) -> Self {
Self {
schema: Some(cedar_policy::proto::models::Schema::from(v.schema)),
policies: Some(cedar_policy::proto::models::PolicySet::from(v.policies)),
Expand All @@ -79,51 +79,66 @@ impl From<&ValidationRequestMsg<'_>> for proto::ValidationRequestMsg {
}
}

#[derive(Debug, Serialize)]
pub struct AuthorizationRequest<'a> {
pub request: &'a ast::Request,
pub policies: &'a ast::PolicySet,
pub entities: &'a Entities,
}

#[derive(Debug, Serialize)]
#[derive(Clone, Debug)]
pub struct EvaluationRequest<'a> {
pub request: &'a ast::Request,
pub entities: &'a Entities,
pub expr: &'a ast::Expr,
pub expected: Option<&'a ast::Expr>,
}

#[derive(Debug, Serialize)]
impl From<&EvaluationRequest<'_>> for proto::EvaluationRequest {
fn from(v: &EvaluationRequest<'_>) -> Self {
Self {
expr: Some(cedar_policy::proto::models::Expr::from(v.expr)),
request: Some(cedar_policy::proto::models::Request::from(v.request)),
entities: Some(cedar_policy::proto::models::Entities::from(v.entities)),
expected: v.expected.map(cedar_policy::proto::models::Expr::from),
}
}
}

#[derive(Debug, Serialize, Clone)]
pub struct PartialEvaluationRequest<'a> {
pub request: &'a ast::Request,
pub entities: &'a Entities,
pub expr: &'a ast::Expr,
pub expected: Option<ExprOrValue>,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct PartialAuthorizationRequest<'a> {
pub request: &'a ast::Request,
pub entities: &'a Entities,
pub policies: &'a ast::PolicySet,
}

#[derive(Debug, Serialize)]
pub struct ValidationRequest<'a> {
pub schema: &'a ValidatorSchema,
pub policies: &'a ast::PolicySet,
pub mode: ValidationMode,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Clone)]
pub struct RequestValidationRequest<'a> {
pub schema: &'a ValidatorSchema,
pub request: &'a ast::Request,
}

#[derive(Debug, Serialize)]
impl From<&RequestValidationRequest<'_>> for proto::RequestValidationRequest {
fn from(v: &RequestValidationRequest<'_>) -> Self {
Self {
schema: Some(cedar_policy::proto::models::Schema::from(v.schema)),
request: Some(cedar_policy::proto::models::Request::from(v.request)),
}
}
}

#[derive(Debug, Clone)]
pub struct EntityValidationRequest<'a> {
pub schema: &'a ValidatorSchema,
pub entities: &'a Entities,
}

impl From<&EntityValidationRequest<'_>> for proto::EntityValidationRequest {
fn from(v: &EntityValidationRequest<'_>) -> Self {
Self {
schema: Some(cedar_policy::proto::models::Schema::from(v.schema)),
entities: Some(cedar_policy::proto::models::Entities::from(v.entities)),
}
}
}
38 changes: 21 additions & 17 deletions cedar-drt/src/lean_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ impl LeanDefinitionalEngine {
policies: &ast::PolicySet,
entities: &Entities,
) -> TestResult<TestResponse> {
let auth_request = AuthorizationRequestMsg {
let auth_request = AuthorizationRequest {
request,
policies,
entities,
};
let auth_request_proto = proto::AuthorizationRequestMsg::from(&auth_request);
let auth_request_proto = proto::AuthorizationRequest::from(&auth_request);
let buf = auth_request_proto.encode_to_vec();
let req = buf_to_lean_obj(&buf);
// Lean will decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h
Expand Down Expand Up @@ -336,16 +336,16 @@ impl LeanDefinitionalEngine {
expected: Option<Value>,
) -> TestResult<bool> {
let expected_as_expr: Option<Expr> = expected.map(|v| v.into());
let request: String = serde_json::to_string(&EvaluationRequest {
let req = EvaluationRequest {
request,
entities,
expr,
expected: expected_as_expr.as_ref(),
})
.expect("failed to serialize request, expression, or entities");
let cstring = CString::new(request).expect("`CString::new` failed");
};
let req_proto = proto::EvaluationRequest::from(&req);
let buf = req_proto.encode_to_vec();
let req = buf_to_lean_obj(&buf);
// Lean will decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h
let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) };
let response = unsafe { evaluateDRT(req) };
// req can no longer be assumed to exist
let response_string = lean_obj_p_to_rust_string(response);
Expand Down Expand Up @@ -380,12 +380,12 @@ impl LeanDefinitionalEngine {
schema: &ValidatorSchema,
policies: &ast::PolicySet,
) -> TestResult<TestValidationResult> {
let val_request = ValidationRequestMsg {
let val_request = ValidationRequest {
schema,
policies,
mode: cedar_policy_validator::ValidationMode::default(),
};
let val_request_proto = proto::ValidationRequestMsg::from(&val_request);
let val_request_proto = proto::ValidationRequest::from(&val_request);
let buf = val_request_proto.encode_to_vec();
let req = buf_to_lean_obj(&buf);
// Lean will decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h
Expand All @@ -400,11 +400,13 @@ impl LeanDefinitionalEngine {
schema: &ValidatorSchema,
request: &ast::Request,
) -> TestResult<TestValidationResult> {
let request: String = serde_json::to_string(&RequestValidationRequest { schema, request })
.expect("failed to serialize request");
let cstring = CString::new(request).expect("CString::new failed");
let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) };
let req = RequestValidationRequest { schema, request };
let req_proto = proto::RequestValidationRequest::from(&req);
let buf = req_proto.encode_to_vec();
let req = buf_to_lean_obj(&buf);
// Lean will decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h
let response = unsafe { validateRequestDRT(req) };
// req can no longer be assumed to exist
let response_string = lean_obj_p_to_rust_string(response);
Self::deserialize_validation_response(response_string)
}
Expand All @@ -414,11 +416,13 @@ impl LeanDefinitionalEngine {
schema: &ValidatorSchema,
entities: &Entities,
) -> TestResult<TestValidationResult> {
let request: String = serde_json::to_string(&EntityValidationRequest { schema, entities })
.expect("failed to serialize request");
let cstring = CString::new(request).expect("CString::new failed");
let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) };
let req = EntityValidationRequest { schema, entities };
let req_proto = proto::EntityValidationRequest::from(&req);
let buf = req_proto.encode_to_vec();
let req = buf_to_lean_obj(&buf);
// Lean will decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h
let response = unsafe { validateEntitiesDRT(req) };
// req can no longer be assumed to exist
let response_string = lean_obj_p_to_rust_string(response);
Self::deserialize_validation_response(response_string)
}
Expand Down
Loading

0 comments on commit 434f09f

Please sign in to comment.