Skip to content

Commit

Permalink
refactor out some panics in CedarProto (#557)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <cdiss@amazon.com>
  • Loading branch information
cdisselkoen authored Feb 28, 2025
1 parent e60af35 commit e2dafd3
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 52 deletions.
6 changes: 6 additions & 0 deletions cedar-lean/Cedar/Validation/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def map {α β} (f : α → β) : Qualified α → Qualified β
| optional a => optional (f a)
| required a => required (f a)

def transpose {α ε} : Qualified (Except ε α) → Except ε (Qualified α)
| optional (.ok a) => .ok (optional a)
| required (.ok a) => .ok (required a)
| optional (.error e) => .error e
| required (.error e) => .error e

end Qualified

inductive CedarType where
Expand Down
40 changes: 25 additions & 15 deletions cedar-lean/CedarProto/Schema.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,42 @@ private def descendantsToAncestors [LT α] [DecidableEq α] [DecidableLT α] (de

namespace Schema

def toSchema (schema : Schema) : Validation.Schema :=
private def attrsToCedarType (attrs : Proto.Map String (Qualified ProtoType)) : Except String (Data.Map Spec.Attr (Qualified CedarType)) := do
let attrs ← attrs.toList.mapM λ (k,v) => do
let v ← v.map ProtoType.toCedarType |>.transpose
.ok (k, v)
.ok $ Data.Map.make attrs

/-- was surprised this isn't in the stdlib -/
def option_transpose : Option (Except ε α) → Except ε (Option α)
| none => .ok none
| some (.ok a) => .ok (some a)
| some (.error e) => .error e

def toSchema (schema : Schema) : Except String Validation.Schema := do
let ets := schema.ets.toList
let descendantMap := ets.map λ decl => (decl.name, Data.Set.make decl.descendants.toList)
let ancestorMap := descendantsToAncestors descendantMap
let ets := Data.Map.make $ ets.map λ decl =>
(decl.name,
if decl.enums.isEmpty then
.standard {
let ets ← ets.mapM λ decl => do
let ese : EntitySchemaEntry ←
if decl.enums.isEmpty then .ok $ .standard {
ancestors := ancestorMap.find! decl.name
attrs := Data.Map.make $ decl.attrs.toList.map λ (k,v) => (k, v.map ProtoType.toCedarType)
tags := decl.tags.map ProtoType.toCedarType
attrs := ← attrsToCedarType decl.attrs
tags := ← option_transpose $ decl.tags.map ProtoType.toCedarType
}
else
.enum $ Cedar.Data.Set.make decl.enums.toList
)
else .ok $ .enum $ Cedar.Data.Set.make decl.enums.toList
.ok (decl.name, ese)
let acts := schema.acts.toList
let descendantMap := acts.map λ decl => (decl.name, Data.Set.make decl.descendants.toList)
let ancestorMap := descendantsToAncestors descendantMap
let acts := Data.Map.make $ acts.map λ decl =>
(decl.name, {
let acts acts.mapM λ decl => do
.ok (decl.name, {
appliesToPrincipal := Data.Set.make decl.principalTypes.toList
appliesToResource := Data.Set.make decl.resourceTypes.toList
ancestors := ancestorMap.find! decl.name
context := Data.Map.make $ decl.context.toList.map λ (k,v) => (k, v.map ProtoType.toCedarType)
context := ← attrsToCedarType decl.context
})
{ ets, acts }
.ok { ets := Data.Map.make ets, acts := Data.Map.make acts }

@[inline]
def mergeEntityDecls (result : Schema) (x : Array EntityDecl) : Schema :=
Expand Down Expand Up @@ -143,6 +153,6 @@ def merge (x1 x2 : Schema) : Schema :=
}

deriving instance Inhabited for Schema
instance : Field Schema := Field.fromInterField Proto.Schema.toSchema merge
instance : Field Schema := Field.fromInterFieldFallible Proto.Schema.toSchema merge

end Cedar.Validation.Schema
22 changes: 12 additions & 10 deletions cedar-lean/CedarProto/Type.lean
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,19 @@ partial def merge (x1 x2 : ProtoType) : ProtoType :=
| _, _ => x2
end

partial def toCedarType : ProtoType → CedarType
| .prim .bool => .bool .anyBool
| .prim .long => .int
| .prim .string => .string
| .set t => .set t.toCedarType
| .entity e => .entity e
| .record r => .record $ Data.Map.make $ r.attrs.map λ (k,v) => (k, v.map toCedarType)
partial def toCedarType : ProtoType → Except String CedarType
| .prim .bool => .ok (.bool .anyBool)
| .prim .long => .ok .int
| .prim .string => .ok .string
| .set t => do .ok (.set (← t.toCedarType))
| .entity e => .ok (.entity e)
| .record r => do
let attrs ← r.attrs.mapM λ (k,v) => do .ok (k, ← v.map toCedarType |>.transpose)
.ok (.record $ Data.Map.make attrs)
| .ext n => match n.id with -- ignoring n.path because currently no extension types have nonempty namespaces
| "ipaddr" => .ext .ipAddr
| "decimal" => .ext .decimal
| _ => panic!(s!"unknown extension type name: {n}")
| "ipaddr" => .ok (.ext .ipAddr)
| "decimal" => .ok (.ext .decimal)
| _ => .error s!"unknown extension type name: {n}"

end ProtoType

Expand Down
30 changes: 17 additions & 13 deletions cedar-lean/CedarProto/Value.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,27 @@ def merge (v1 : Value) (v2 : Value) : Value :=
| .record m1, .record m2 => Cedar.Data.Map.mk (m1.kvs ++ m2.kvs)
| _, _ => v2 -- note this includes all the .ext cases

private def extExprToValue (xfn : ExtFun) (args : List Expr) : Value :=
private def extExprToValue (xfn : ExtFun) (args : List Expr) : Except String Value :=
match xfn, args with
| .decimal, [.lit (.string s)] => match Spec.Ext.Decimal.decimal s with
| .some v => .ext (.decimal v)
| .none => panic! s!"exprToValue: failed to parse decimal {s}"
| .some v => .ok $ .ext (.decimal v)
| .none => .error s!"exprToValue: failed to parse decimal {s}"
| .ip, [.lit (.string s)] => match Spec.Ext.IPAddr.ip s with
| .some v => .ext (.ipaddr v)
| .none => panic! s!"exprToValue: failed to parse ip {s}"
| _, _ => panic! ("exprToValue: unexpected extension value\n" ++ toString (repr (Expr.call xfn args)))

private partial def exprToValue : Expr → Value
| .lit p => .prim p
| .record r => .record (Cedar.Data.Map.make (r.map λ ⟨attr, e⟩ => ⟨attr, exprToValue e⟩))
| .set s => .set (Cedar.Data.Set.make (s.map exprToValue))
| .some v => .ok $ .ext (.ipaddr v)
| .none => .error s!"exprToValue: failed to parse ip {s}"
| _, _ => .error s!"exprToValue: unexpected extension value\n{repr (Expr.call xfn args)}"

private partial def exprToValue : Expr → Except String Value
| .lit p => .ok (.prim p)
| .record r => do
let attrs ← r.mapM λ ⟨attr, e⟩ => do .ok ⟨attr, ← exprToValue e⟩
.ok $ .record (Cedar.Data.Map.make attrs)
| .set s => do
let elts ← s.mapM exprToValue
.ok $ .set (Cedar.Data.Set.make elts)
| .call xfn args => extExprToValue xfn args
| _ => panic!("exprToValue: invalid input expression")
| e => .error s!"exprToValue: invalid input expression {repr e}"

instance : Field Value := Field.fromInterField exprToValue merge
instance : Field Value := Field.fromInterFieldFallible exprToValue merge

end Cedar.Spec.Value
5 changes: 5 additions & 0 deletions cedar-lean/Protobuf/BParsec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def pure (a : α) : BParsec α := λ pos => { pos, res := .ok a }
@[inline]
def fail (msg : String) : BParsec α := λ pos => { pos, res := .error msg }

@[inline]
def ofExcept : Except String α → BParsec α
| .ok a => pure a
| .error e => fail e

instance (α : Type) : Inhabited (BParsec α) := ⟨fail ""

@[inline]
Expand Down
9 changes: 9 additions & 0 deletions cedar-lean/Protobuf/Field.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def fromInterField {α β : Type} [Inhabited α] [Field α] (convert : α → β
merge := merge
}

@[inline]
def fromInterFieldFallible {α β : Type} [Inhabited α] [Field α] (convert : α → Except String β) (merge : β → β → β) : Field β := {
parse := do
let m : α ← Field.parse
ofExcept $ convert m
expectedWireType := Field.expectedWireType α
merge := merge
}

end Field

end Proto
32 changes: 18 additions & 14 deletions cedar-lean/UnitTest/CedarProto.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,29 @@ open Cedar.Data
/--
`filename` is expected to be the name of a file containing binary protobuf data.
This data is deserialized, and `f` is applied to it.
Then, this test ensures that the result is equal to the value `expected`.
Then, this test ensures that `f` succeeds and that the result is equal to the value `expected`.
-/
def testDeserializeProtodata' [Inhabited α] [DecidableEq β] [Repr β] [Proto.Message α]
(filename : String) (f : α → β) (expected : β) : TestCase IO :=
(filename : String) (f : α → Except String β) (expected : β) : TestCase IO :=
test s!"Deserialize {filename}" ⟨λ () => do
let buf ← IO.FS.readBinFile filename
let parsed : Except String α := Proto.Message.interpret? buf
match parsed with
| .ok req => checkEq (f req) expected
| .ok req => do
let actual ← IO.ofExcept (f req)
checkEq actual expected
| .error e => pure (.error e)

private def infallible (f : α → β) : α → Except ε β := λ a => .ok (f a)

/--
Convenience alias for `testDeserializeProtodata'` with `f := id`, that is, no
Convenience alias for `testDeserializeProtodata'` with `f := pure`, that is, no
transform necessary on the deserialized data before comparing to `expected`.
-/
def testDeserializeProtodata [Inhabited α] [DecidableEq α] [Repr α] [Proto.Message α]
(filename : String) (expected : α) : TestCase IO :=
testDeserializeProtodata' filename id expected
testDeserializeProtodata' filename pure expected

private def mkUid (path : List String) (ty : String) (eid : String) : Cedar.Spec.EntityUID :=
{ ty := { path, id := ty }, eid }
Expand Down Expand Up @@ -110,13 +114,13 @@ def tests := [
testDeserializeProtodata "UnitTest/CedarProto-test-data/emptyrecord.protodata"
(Cedar.Spec.Expr.record []),
testDeserializeProtodata' "UnitTest/CedarProto-test-data/record.protodata"
Cedar.Spec.Expr.mkWf
(infallible Cedar.Spec.Expr.mkWf)
(.record [
("eggs", .lit (.int (Int64.ofIntChecked 7 (by decide)))),
("ham", .lit (.int (Int64.ofIntChecked 3 (by decide)))),
]),
testDeserializeProtodata' "UnitTest/CedarProto-test-data/nested_record.protodata"
Cedar.Spec.Expr.mkWf
(infallible Cedar.Spec.Expr.mkWf)
(.record [
("eggs", .set [ .lit (.string "this is"), .lit (.string "a set") ]),
("ham", .record [
Expand Down Expand Up @@ -198,7 +202,7 @@ def tests := [
(.call .decimal [.lit (.string "3.1416")]),
]),
testDeserializeProtodata' "UnitTest/CedarProto-test-data/rbac.protodata"
Cedar.Spec.Policies.fromPolicySet
(infallible Cedar.Spec.Policies.fromPolicySet)
[{
id := "policy0"
effect := .permit
Expand All @@ -208,7 +212,7 @@ def tests := [
condition := [{ kind := .when, body := .lit (.bool true) }]
}],
testDeserializeProtodata' "UnitTest/CedarProto-test-data/abac.protodata"
Cedar.Spec.Policies.fromPolicySet
(infallible Cedar.Spec.Policies.fromPolicySet)
[{
id := "policy0"
effect := .permit
Expand All @@ -227,7 +231,7 @@ def tests := [
]
}],
testDeserializeProtodata' "UnitTest/CedarProto-test-data/policyset.protodata"
(Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
(infallible $ Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
[
{
id := "linkedpolicy"
Expand Down Expand Up @@ -288,11 +292,11 @@ def tests := [
},
],
testDeserializeProtodata' "UnitTest/CedarProto-test-data/policyset_just_templates.protodata"
(Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
(infallible $ Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
-- when it's just a template, it gets dropped in the Lean `Cedar.Spec.Policies` representation
[],
testDeserializeProtodata' "UnitTest/CedarProto-test-data/policyset_one_static_policy.protodata"
(Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
(infallible $ Cedar.Spec.Policies.sortByPolicyId ∘ Cedar.Spec.Policies.fromPolicySet)
[
{
id := ""
Expand All @@ -314,7 +318,7 @@ def tests := [
context := Map.make [ ("foo", .prim (.bool true)) ]
} : Cedar.Spec.Request),
testDeserializeProtodata' "UnitTest/CedarProto-test-data/entity.protodata"
Cedar.Spec.EntityProto.mkWf
(infallible Cedar.Spec.EntityProto.mkWf)
({
uid := mkUid ["A"] "B" "C"
data := {
Expand All @@ -336,7 +340,7 @@ def tests := [
}
}),
testDeserializeProtodata' "UnitTest/CedarProto-test-data/entities.protodata"
Cedar.Spec.EntitiesProto.toEntities
(infallible Cedar.Spec.EntitiesProto.toEntities)
((Map.make [
(mkUid [] "ABC" "123", { attrs := Map.empty, ancestors := Set.empty, tags := Map.empty }),
(mkUid [] "DEF" "234", { attrs := Map.empty, ancestors := Set.empty, tags := Map.empty }),
Expand Down

0 comments on commit e2dafd3

Please sign in to comment.