Skip to content

Commit

Permalink
[Authorization Logic :: Rust] Comparisons on Numbers (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
aferr authored Mar 10, 2022
1 parent 055572f commit b2069ff
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 63 deletions.
1 change: 1 addition & 0 deletions rust/tools/authorization-logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ filegroup(
"src/test/test_multimic_overrides.rs",
"src/test/test_multiverse_handling.rs",
"src/test/test_negation.rs",
"src/test/test_num_compare.rs",
"src/test/test_num_string_names.rs",
"src/test/test_queries.rs",
"src/test/test_relation_declarations.rs",
Expand Down
28 changes: 27 additions & 1 deletion rust/tools/authorization-logic/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ impl PartialEq for AstPredicate {
// See https://doc.rust-lang.org/std/cmp/trait.Eq.html
impl Eq for AstPredicate {}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstArithmeticComparison {
pub lnum: String ,
pub op: AstComparisonOperator,
pub rnum: String
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstVerbPhrase {
AstPredPhrase { p: AstPredicate },
Expand All @@ -72,10 +79,29 @@ pub enum AstFact {
AstCanSayFact { p: AstPrincipal, f: Box<AstFact> },
}

#[derive(Copy, Debug, Clone, Serialize, Deserialize)]
pub enum AstComparisonOperator {
LessThan,
GreaterThan,
Equals,
NotEquals,
LessOrEquals,
GreaterOrEquals
}

// RValues are the expressions that can appear on the right hand side of a
// conditional assertion. At the time of writing these include either
// AstFlatFacts or AstArithmeticComparisons.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstRValue {
FlatFactRValue { flat_fact: AstFlatFact},
ArithCompareRValue { arith_comp: AstArithmeticComparison }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstAssertion {
AstFactAssertion { f: AstFact },
AstCondAssertion { lhs: AstFact, rhs: Vec<AstFlatFact> },
AstCondAssertion { lhs: AstFact, rhs: Vec<AstRValue> },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
34 changes: 31 additions & 3 deletions rust/tools/authorization-logic/src/parsing/AuthLogic.g4
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ principal
: ID
;

pred_arg
: ID
| NUMLITERAL
;

predicate
: (NEG)? ID '(' ID (',' ID)* ')'
: (NEG)? ID '(' pred_arg (',' pred_arg )* ')'
;

verbphrase
Expand All @@ -50,14 +55,28 @@ flatFact
| predicate #predFact
;

binop
: LESSTHAN #ltbinop
| GRTHAN #grbinop
| EQUALS #eqbinop
| NEQUALS #nebinop
| LEQ #leqbinop
| GEQ #geqbinop
;

rvalue
: flatFact #flatFactRvalue
| pred_arg binop pred_arg #binopRvalue
;

fact
: flatFact #flatFactFact
| principal CANSAY fact #canSayFact
;

assertion
: fact '.' #factAssertion
| fact ':-' flatFact (',' flatFact )* '.' #hornClauseAssertion
| fact ':-' rvalue (',' rvalue )* '.' #hornClauseAssertion
;

// The IDs following "Export" are path names where JSON files containing
Expand Down Expand Up @@ -120,10 +139,19 @@ ATTRIBUTE: 'attribute';

// Identifiers wrapped in quotes are constants whereas
// identifiers without quotes are variables.
ID : ('"')? [_a-zA-Z0-9/.#:]* ('"')?;
ID : ('"')? [_a-zA-Z][_a-zA-Z0-9/.#:]* ('"')?;
NUMLITERAL : [0-9]+;

NEG: '!';

// BINOPS
LESSTHAN: '<';
GRTHAN: '>';
EQUALS: '=';
NEQUALS: '!=';
LEQ: '<=';
GEQ: '>=';

WHITESPACE_IGNORE
: [ \r\t\n]+ -> skip
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ fn construct_predicate(ctx: &PredicateContext) -> AstPredicate {
Some(_) => Sign::Negated,
None => Sign::Positive
};
// Note that ID_all() in the generated antlr-rust code is buggy
// (because all {LEX_RULE}_all() generations are buggy),
// so rather than using a more idomatic iterator, "while Some(...)" is
// used here.
let name_ = ctx.ID(0).unwrap().get_text();
let mut args_ = Vec::new();
let mut idx = 1;
while let Some(id) = ctx.ID(idx) {
args_.push(id.get_text());
idx += 1;
}
let name_ = ctx.ID().unwrap().get_text();
let args_ = (&ctx).pred_arg_all()
.iter()
.map(|arg_ctx| arg_ctx.get_text())
.collect();
AstPredicate {
sign: sign_,
name: name_,
Expand All @@ -70,9 +64,7 @@ fn construct_verbphrase(ctx: &VerbphraseContextAll) -> AstVerbPhrase {
match ctx {
VerbphraseContextAll::PredphraseContext(pctx) => construct_predphrase(pctx),
VerbphraseContextAll::ActsAsPhraseContext(actx) => construct_actsasphrase(actx),
_ => {
panic!("construct_verbphrase tried to build error");
}
_ => { panic!("construct_verbphrase tried to build error"); }
}
}

Expand All @@ -92,9 +84,7 @@ fn construct_flat_fact(ctx: &FlatFactContextAll) -> AstFlatFact {
match ctx {
FlatFactContextAll::PrinFactContext(fctx) => construct_prin_fact(fctx),
FlatFactContextAll::PredFactContext(pctx) => construct_pred_fact(pctx),
_ => {
panic!("construct_flat_fact tried to build error");
}
_ => { panic!("construct_flat_fact tried to build error"); }
}
}

Expand All @@ -116,9 +106,7 @@ fn construct_fact(ctx: &FactContextAll) -> AstFact {
match ctx {
FactContextAll::FlatFactFactContext(fctx) => construct_flat_fact_fact(fctx),
FactContextAll::CanSayFactContext(sctx) => construct_can_say_fact(sctx),
_ => {
panic!("construct_fact tried to build error");
}
_ => { panic!("construct_fact tried to build error"); }
}
}

Expand All @@ -136,13 +124,43 @@ fn construct_can_say_fact(ctx: &CanSayFactContext) -> AstFact {
}
}

fn construct_binop(ctx: &BinopContextAll) -> AstComparisonOperator {
match ctx {
BinopContextAll::LtbinopContext(_) => AstComparisonOperator::LessThan,
BinopContextAll::GrbinopContext(_) => AstComparisonOperator::GreaterThan,
BinopContextAll::EqbinopContext(_) => AstComparisonOperator::Equals,
BinopContextAll::NebinopContext(_) => AstComparisonOperator::NotEquals,
BinopContextAll::LeqbinopContext(_) => AstComparisonOperator::LessOrEquals,
BinopContextAll::GeqbinopContext(_) => AstComparisonOperator::GreaterOrEquals,
_ => { panic!("construct_binop tried to build error"); }
}
}

fn construct_rvalue(ctx: &RvalueContextAll) -> AstRValue {
match ctx {
RvalueContextAll::FlatFactRvalueContext(ffctx) => {
AstRValue::FlatFactRValue {
flat_fact: construct_flat_fact(&ffctx.flatFact().unwrap())
}
},
RvalueContextAll::BinopRvalueContext(bctx) => {
AstRValue::ArithCompareRValue {
arith_comp: AstArithmeticComparison {
lnum: bctx.pred_arg(0).unwrap().get_text(),
op: construct_binop(&bctx.binop().unwrap()),
rnum: bctx.pred_arg(1).unwrap().get_text()
}
}
}
_ => { panic!("construct_rvalue tried to build error"); }
}
}

fn construct_assertion(ctx: &AssertionContextAll) -> AstAssertion {
match ctx {
AssertionContextAll::FactAssertionContext(fctx) => construct_fact_assertion(fctx),
AssertionContextAll::HornClauseAssertionContext(hctx) => construct_hornclause(hctx),
_ => {
panic!("construct_assertion tried to build error");
}
_ => { panic!("construct_assertion tried to build error"); }
}
}

Expand All @@ -153,10 +171,10 @@ fn construct_fact_assertion(ctx: &FactAssertionContext) -> AstAssertion {

fn construct_hornclause(ctx: &HornClauseAssertionContext) -> AstAssertion {
let lhs = construct_fact(&ctx.fact().unwrap());
let mut rhs = Vec::new();
for flat_fact_ctx in ctx.flatFact_all() {
rhs.push(construct_flat_fact(&flat_fact_ctx));
}
let rhs = ctx.rvalue_all()
.iter()
.map(|rvalue_ctx| construct_rvalue(&rvalue_ctx))
.collect();
AstAssertion::AstCondAssertion { lhs, rhs }
}

Expand Down Expand Up @@ -186,9 +204,7 @@ fn construct_says_assertion(ctx: &SaysAssertionContextAll) -> AstSaysAssertion {
export_file,
}
}
_ => {
panic!("construct_says_assertion tried to build Error()");
}
_ => { panic!("construct_says_assertion tried to build Error()"); }
}
}

Expand All @@ -215,9 +231,7 @@ fn construct_keybinding(ctx: &KeyBindContextAll) -> AstKeybind {
principal: construct_principal(&ctx_prime.principal().unwrap()),
is_pub: true,
},
_ => {
panic!("construct_keybinding tried to build Error()");
}
_ => { panic!("construct_keybinding tried to build Error()"); }
}
}

Expand All @@ -237,9 +251,7 @@ fn construct_type(ctx: &AuthLogicTypeContextAll) -> AstType {
AuthLogicTypeContextAll::CustomTypeContext(ctx_prime) => {
AstType::CustomType { type_name: ctx_prime.ID().unwrap().get_text() }
}
_ => {
panic!("construct_type tried to build error");
}
_ => { panic!("construct_type tried to build error"); }
}
}

Expand Down
11 changes: 10 additions & 1 deletion rust/tools/authorization-logic/src/souffle/datalog_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,23 @@
//! this authorization logic into datalog simpler.
use crate::ast::*;

// RValues are the expressions that can appear on the right hand side of a
// conditional assertion. At the time of writing these include either
// Predicates or AstArithmeticComparisons.
#[derive(Clone)]
pub enum DLIRRValue {
PredicateRValue { predicate: AstPredicate },
ArithCompareRValue { arith_comp: AstArithmeticComparison }
}

#[derive(Clone)]
pub enum DLIRAssertion {
DLIRFactAssertion {
p: AstPredicate,
},
DLIRCondAssertion {
lhs: AstPredicate,
rhs: Vec<AstPredicate>,
rhs: Vec<DLIRRValue>,
},
}

Expand Down
51 changes: 39 additions & 12 deletions rust/tools/authorization-logic/src/souffle/lowering_ast_datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@
use crate::{ast::*, souffle::datalog_ir::*};
use std::collections::HashMap;

fn pred_to_dlir_rvalue(pred: &AstPredicate) -> DLIRRValue {
DLIRRValue::PredicateRValue { predicate: pred.clone() }
}

// Note that this puts args_ on the front of the list of arguments because
// this is the conveninet way for it to work in the contexts in which it
// is used.
Expand Down Expand Up @@ -230,7 +234,10 @@ impl LoweringToDatalogPass {

let gen = DLIRAssertion::DLIRCondAssertion {
lhs: gen_lhs,
rhs: [s_says_x_as_p, s_says_p_v].to_vec(),
rhs: [s_says_x_as_p, s_says_p_v]
.into_iter()
.map(|pred| pred_to_dlir_rvalue(pred))
.collect(),
};

(pred, [gen].to_vec())
Expand Down Expand Up @@ -291,9 +298,10 @@ impl LoweringToDatalogPass {
&push_prin(String::from("canSay_"), &x, &fact_plus_prime),
);
// This is `p says fpf :- x says fpf, p says x canSay fpf`.
let mut rhs = Vec::new();
rhs.push(x_says_term);
rhs.push(can_say_term);
let rhs = [x_says_term, can_say_term]
.into_iter()
.map(|pred| pred_to_dlir_rvalue(pred))
.collect();
let gen = DLIRAssertion::DLIRCondAssertion { lhs, rhs };

collected.push(gen);
Expand All @@ -314,6 +322,22 @@ impl LoweringToDatalogPass {
.flatten()
.collect()
}

fn rvalue_to_dlir(
&mut self,
speaker: &AstPrincipal,
rvalue: &AstRValue) -> DLIRRValue {
match rvalue {
AstRValue::FlatFactRValue { flat_fact } => {
let (flat_pred, _) = self.flat_fact_to_dlir(&flat_fact, &speaker);
pred_to_dlir_rvalue(&push_prin(String::from("says_"),
&speaker, &flat_pred))
}
AstRValue::ArithCompareRValue { arith_comp } => {
DLIRRValue::ArithCompareRValue { arith_comp: arith_comp.clone() }
}
}
}

fn says_assertion_to_dlir_inner(
&mut self,
Expand All @@ -328,16 +352,16 @@ impl LoweringToDatalogPass {
gen_assert
}
AstAssertion::AstCondAssertion { lhs, rhs } => {
let mut dlir_rhs = Vec::new();
for f in rhs {
let (flat, _) = self.flat_fact_to_dlir(&f, &speaker);
dlir_rhs.push(push_prin(String::from("says_"), &speaker, &flat));
}
let (lhs_prime, mut assertions) = self.fact_to_dlir(&lhs, &speaker);
let dlir_lhs = push_prin(String::from("says_"), &speaker, &lhs_prime);
let dlir_lhs =
push_prin(String::from("says_"), &speaker, &lhs_prime);
let this_assertion = DLIRAssertion::DLIRCondAssertion {
lhs: dlir_lhs,
rhs: dlir_rhs,
rhs: rhs.clone()
.into_iter()
.map(|ast_rvalue| self.rvalue_to_dlir(
&speaker, &ast_rvalue))
.collect()
};
assertions.push(this_assertion);
assertions
Expand Down Expand Up @@ -367,7 +391,10 @@ impl LoweringToDatalogPass {
};
DLIRAssertion::DLIRCondAssertion {
lhs: lhs,
rhs: vec![main_fact, LoweringToDatalogPass::dummy_fact()],
rhs: [main_fact, LoweringToDatalogPass::dummy_fact()]
.into_iter()
.map(|pred| pred_to_dlir_rvalue(pred))
.collect()
}
}

Expand Down
Loading

0 comments on commit b2069ff

Please sign in to comment.