Skip to content

Commit

Permalink
bug fix and unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Shaobo He <shaobohe@amazon.com>
  • Loading branch information
shaobo-he-aws committed Feb 27, 2025
1 parent 4cb69eb commit e553a1d
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cedar-lean/Cedar/TPE/Evaluator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ inductive Error where
| invalidPolicy (err : TypeError)
| inValidEnvironment
| invalidRequestOrEntities
deriving Repr
deriving Repr, Inhabited, DecidableEq

instance : Coe Spec.Error Error where
coe := Error.evaluation
Expand All @@ -51,7 +51,7 @@ def ite (c t e : Residual)(ty : CedarType) : Residual :=
| _ =>
.ite c t e ty

def and (l r : Residual)(ty : CedarType) : Residual :=
def and (l r : Residual) (ty : CedarType) : Residual :=
match l, r with
| .val true _, _ => r
| .val false _, _ => false
Expand Down
2 changes: 1 addition & 1 deletion cedar-lean/Cedar/TPE/Input.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def PartialEntities.ancestors (es : PartialEntities) (uid : EntityUID) : Option

def PartialEntities.tags (es : PartialEntities) (uid : EntityUID) : Option (Map Tag Value) := es.get uid PartialEntityData.tags

def PartialEntities.attrs (es : PartialEntities) (uid : EntityUID) : Option (Map Tag Value) := es.get uid PartialEntityData.tags
def PartialEntities.attrs (es : PartialEntities) (uid : EntityUID) : Option (Map Tag Value) := es.get uid PartialEntityData.attrs

def partialIsValid {α} (o : Option α) (f : α → Bool) : Bool :=
(o.map f).getD true
Expand Down
70 changes: 69 additions & 1 deletion cedar-lean/Cedar/TPE/Residual.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ inductive Residual where
| record (map : List (Attr × Residual)) (ty : CedarType)
| call (xfn : ExtFun) (args : List Residual) (ty : CedarType)
| error (ty : CedarType)
deriving Repr
deriving Repr, Inhabited

instance : Coe Bool Residual where
coe b := .val (Prim.bool b) (.bool .anyBool)
Expand Down Expand Up @@ -113,4 +113,72 @@ decreasing_by
have := List.sizeOf_lt_of_mem h
omega

mutual

def decResidual (x y : Residual) : Decidable (x = y) := by
cases x <;> cases y <;>
try { apply isFalse ; intro h ; injection h }
case val.val x₁ tx y₁ ty | var.var x₁ tx y₁ ty =>
exact match decEq x₁ y₁, decEq tx ty with
| isTrue h₁, isTrue h₂ => isTrue (by rw [h₁, h₂])
| isFalse _, _ | _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case ite.ite x₁ x₂ x₃ tx y₁ y₂ y₃ ty =>
exact match decResidual x₁ y₁, decResidual x₂ y₂, decResidual x₃ y₃, decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃, isTrue h₄ => isTrue (by rw [h₁, h₂, h₃, h₄])
| isFalse _, _, _, _ | _, isFalse _, _, _ | _, _, isFalse _, _ | _, _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case and.and x₁ x₂ tx y₁ y₂ ty | or.or x₁ x₂ tx y₁ y₂ ty =>
exact match decResidual x₁ y₁, decResidual x₂ y₂, decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃ => isTrue (by rw [h₁, h₂, h₃])
| isFalse _, _, _ | _, isFalse _, _ | _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case unaryApp.unaryApp o x₁ tx o' y₁ ty =>
exact match decEq o o', decResidual x₁ y₁, decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃ => isTrue (by rw [h₁, h₂, h₃])
| isFalse _, _, _ | _, isFalse _, _ | _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case binaryApp.binaryApp o x₁ x₂ tx o' y₁ y₂ ty =>
exact match decEq o o', decResidual x₁ y₁, decResidual x₂ y₂, decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃, isTrue h₄ => isTrue (by rw [h₁, h₂, h₃, h₄])
| isFalse _, _, _, _ | _, isFalse _, _, _ | _, _, isFalse _, _ | _, _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case getAttr.getAttr x₁ a tx y₁ a' ty | hasAttr.hasAttr x₁ a tx y₁ a' ty =>
exact match decResidual x₁ y₁, decEq a a', decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃ => isTrue (by rw [h₁, h₂, h₃])
| isFalse _, _, _ | _, isFalse _, _ | _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case set.set xs tx ys ty =>
exact match decResidualList xs ys, decEq tx ty with
| isTrue h₁, isTrue h₂ => isTrue (by rw [h₁, h₂])
| isFalse _, _ | _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case record.record axs tx ays ty =>
exact match decProdAttrResidualList axs ays, decEq tx ty with
| isTrue h₁, isTrue h₂ => isTrue (by rw [h₁, h₂])
| isFalse _, _ | _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case call.call f xs tx f' ys ty =>
exact match decEq f f', decResidualList xs ys, decEq tx ty with
| isTrue h₁, isTrue h₂, isTrue h₃ => isTrue (by rw [h₁, h₂, h₃])
| isFalse _, _, _ | _, isFalse _, _ | _, _, isFalse _ => isFalse (by intro h; injection h; contradiction)
case error.error ty₁ ty₂ =>
exact match decEq ty₁ ty₂ with
| isTrue h₁ => isTrue (by rw [h₁])
| isFalse _ => isFalse (by intro h; injection h; contradiction)

def decProdAttrResidualList (axs ays : List (Prod Attr Residual)) : Decidable (axs = ays) :=
match axs, ays with
| [], [] => isTrue rfl
| _::_, [] | [], _::_ => isFalse (by intro; contradiction)
| (a, x)::axs, (a', y)::ays =>
match decEq a a', decResidual x y, decProdAttrResidualList axs ays with
| isTrue h₁, isTrue h₂, isTrue h₃ => isTrue (by rw [h₁, h₂, h₃])
| isFalse _, _, _ | _, isFalse _, _ | _, _, isFalse _ =>
isFalse (by simp; intros; first | contradiction | assumption)

def decResidualList (xs ys : List Residual) : Decidable (xs = ys) :=
match xs, ys with
| [], [] => isTrue rfl
| _::_, [] | [], _::_ => isFalse (by intro; contradiction)
| x::xs, y::ys =>
match decResidual x y, decResidualList xs ys with
| isTrue h₁, isTrue h₂ => isTrue (by rw [h₁, h₂])
| isFalse _, _ | _, isFalse _ => isFalse (by intro h; injection h; contradiction)
end

instance : DecidableEq Residual := decResidual

end Cedar.TPE
4 changes: 3 additions & 1 deletion cedar-lean/UnitTest/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import UnitTest.Decimal
import UnitTest.IPAddr
import UnitTest.Proto
import UnitTest.Wildcard
import UnitTest.TPE

open UnitTest

Expand All @@ -29,7 +30,8 @@ def tests :=
IPAddr.tests ++
Wildcard.tests ++
Proto.tests ++
CedarProto.tests
CedarProto.tests ++
TPE.tests

def main : IO UInt32 := do
TestSuite.runAll tests
58 changes: 54 additions & 4 deletions cedar-lean/UnitTest/TPE.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Cedar.TPE.Evaluator
import Cedar.Spec.Expr
import Cedar.Validation.Types
import Cedar.Data.Map
import UnitTest.Run

namespace UnitTest.TPE.Basic

Expand Down Expand Up @@ -211,9 +212,30 @@ def es : PartialEntities :=
(⟨UserType, "Alice"⟩, ⟨.some default, .some default, default⟩)
]

#eval (tpePolicy schema policy₁ req es)
#eval (tpePolicy schema policy₂ req es)
#eval (tpePolicy schema policy₃ req es)
private def testResult (p : Policy) (r : Residual) : TestCase IO :=
test s!"policy {p.id}" ⟨λ _ => checkEq (evaluatePolicy schema p req es) (.ok r)⟩

def tests :=
suite "TPE results for the RFC basic example"
[
testResult policy₁
(.getAttr (.var .resource (.entity { id := "Document", path := [] }))
"isPublic"
(.bool .anyBool)),
testResult policy₂
(.binaryApp
.eq
(.getAttr
(.var .resource (.entity { id := "Document", path := [] }))
"owner"
(.entity { id := "User", path := [] }))
(.val
(.prim (.entityUID { ty := { id := "User", path := [] }, eid := "Alice" }))
(.entity { id := "User", path := [] }))
(.bool .anyBool)),
testResult policy₃ (.val false (.bool .ff))
]
--#eval TestSuite.runAll [tests]

end UnitTest.TPE.Basic

Expand Down Expand Up @@ -335,5 +357,33 @@ def es : PartialEntities :=
(⟨UserType, "Alice"⟩, ⟨.some $ Map.make [("address", .record $ Map.make [("street", "Sesame Street")])], .some default, default⟩)
]

#eval tpePolicy schema policy req es
private def testResult (p : Policy) (r : Residual) : TestCase IO :=
test s!"policy {p.id}" ⟨λ _ => checkEq (evaluatePolicy schema p req es) (.ok r)⟩

def tests :=
suite "TPE results for the RFC basic example"
[
testResult policy
(.binaryApp
.eq
(.val
(.record
(Map.mk [("street", .prim (.string "Sesame Street"))]))
(.record AddressType))
(.getAttr
(.var
.resource
(.entity { id := "Package", path := [] }))
"address"
(.record AddressType))
(.bool .anyBool))
]
-- #eval TestSuite.runAll [tests]

end UnitTest.TPE.Motivation

namespace UnitTest.TPE

def tests := [Basic.tests, Motivation.tests]

end UnitTest.TPE

0 comments on commit e553a1d

Please sign in to comment.