Skip to content

Commit

Permalink
bugfix: fix compare subquery bug in aggregation compare (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
fucangfy authored Mar 3, 2025
1 parent 866f428 commit 5bd9a2f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
46 changes: 46 additions & 0 deletions pkg/interpreter/translator/translator_ccl_input_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,52 @@ var translateNumericTestCases = []sPair{
}

var translateWithCCLTestCases = []sPair{
{`SELECT ta.join_int_0 FROM alice.tbl_0 AS ta JOIN bob.tbl_0 AS tb ON ta.join_int_0 = tb.join_int_0 GROUP BY ta.join_int_0 HAVING SUM(ta.plain_int_0) = ANY(SELECT ta.plain_int_0 FROM alice.tbl_0 AS ta)`, `digraph G {
0 [label="runsql:{in:[],out:[Out:{t_0,t_1,},],attr:[sql:select ta.join_int_0,ta.plain_int_0 from alice.tbl_0 as ta;,table_refs:[alice.tbl_0],],party:[alice,]}"]
1 [label="runsql:{in:[],out:[Out:{t_2,},],attr:[sql:select tb.join_int_0 from bob.tbl_0 as tb;,table_refs:[bob.tbl_0],],party:[bob,]}"]
2 [label="join:{in:[Left:{t_0,},Right:{t_2,},],out:[LeftJoinIndex:{t_3,},RightJoinIndex:{t_4,},],attr:[input_party_codes:[alice bob],join_type:0,psi_algorithm:0,],party:[alice,bob,]}"]
3 [label="filter_by_index:{in:[Data:{t_0,t_1,},RowsIndexFilter:{t_3,},],out:[Out:{t_5,t_6,},],attr:[],party:[alice,]}"]
4 [label="group:{in:[Key:{t_5,},],out:[GroupId:{t_7,},GroupNum:{t_8,},],attr:[],party:[alice,]}"]
5 [label="sum:{in:[GroupId:{t_7,},GroupNum:{t_8,},In:{t_6,},],out:[Out:{t_9,},],attr:[],party:[alice,]}"]
6 [label="firstrow:{in:[GroupId:{t_7,},GroupNum:{t_8,},In:{t_5,},],out:[Out:{t_10,},],attr:[],party:[alice,]}"]
7 [label="count:{in:[GroupId:{t_7,},GroupNum:{t_8,},In:{t_7,},],out:[Out:{t_11,},],attr:[],party:[alice,]}"]
8 [label="make_constant:{in:[],out:[Out:{t_12,},],attr:[scalar:4,],party:[alice,bob,carol,]}"]
9 [label="broadcast:{in:[In:{t_12,},ShapeRefTensor:{t_11,},],out:[Out:{t_13,},],attr:[],party:[alice,]}"]
10 [label="GreaterEqual:{in:[Left:{t_11,},Right:{t_13,},],out:[Out:{t_14,},],attr:[],party:[alice,]}"]
11 [label="apply_filter:{in:[Filter:{t_14,},In:{t_10,t_9,t_11,},],out:[Out:{t_15,t_16,t_17,},],attr:[],party:[alice,]}"]
12 [label="runsql:{in:[],out:[Out:{t_18,},],attr:[sql:select ta.plain_int_0 from alice.tbl_0 as ta;,table_refs:[alice.tbl_0],],party:[alice,]}"]
13 [label="psi_in:{in:[Left:{t_16,},Right:{t_18,},],out:[Out:{t_19,},],attr:[in_type:2,input_party_codes:[alice alice],psi_algorithm:0,reveal_to:[alice],],party:[alice,alice,]}"]
14 [label="apply_filter:{in:[Filter:{t_19,},In:{t_15,},],out:[Out:{t_20,},],attr:[],party:[alice,]}"]
15 [label="publish:{in:[In:{t_20,},],out:[Out:{t_21,},],attr:[],party:[alice,]}"]
0 -> 2 [label = "t_0:{join_int_0:PRIVATE:INT64}"]
0 -> 3 [label = "t_0:{join_int_0:PRIVATE:INT64}"]
0 -> 3 [label = "t_1:{plain_int_0:PRIVATE:INT64}"]
1 -> 2 [label = "t_2:{join_int_0:PRIVATE:INT64}"]
10 -> 11 [label = "t_14:{GreaterEqual_out:PRIVATE:BOOL}"]
11 -> 13 [label = "t_16:{plain_int_0_sum:PRIVATE:INT64}"]
11 -> 14 [label = "t_15:{join_int_0_firstrow:PRIVATE:INT64}"]
12 -> 13 [label = "t_18:{plain_int_0:PRIVATE:INT64}"]
13 -> 14 [label = "t_19:{psi_in_out:PRIVATE:BOOL}"]
14 -> 15 [label = "t_20:{join_int_0_firstrow:PRIVATE:INT64}"]
2 -> 3 [label = "t_3:{join_int_0:PRIVATE:INT64}"]
3 -> 4 [label = "t_5:{join_int_0:PRIVATE:INT64}"]
3 -> 5 [label = "t_6:{plain_int_0:PRIVATE:INT64}"]
3 -> 6 [label = "t_5:{join_int_0:PRIVATE:INT64}"]
4 -> 5 [label = "t_7:{group_id:PRIVATE:INT64}"]
4 -> 5 [label = "t_8:{group_num:PRIVATE:INT64}"]
4 -> 6 [label = "t_7:{group_id:PRIVATE:INT64}"]
4 -> 6 [label = "t_8:{group_num:PRIVATE:INT64}"]
4 -> 7 [label = "t_7:{group_id:PRIVATE:INT64}"]
4 -> 7 [label = "t_7:{group_id:PRIVATE:INT64}"]
4 -> 7 [label = "t_8:{group_num:PRIVATE:INT64}"]
5 -> 11 [label = "t_9:{plain_int_0_sum:PRIVATE:INT64}"]
6 -> 11 [label = "t_10:{join_int_0_firstrow:PRIVATE:INT64}"]
7 -> 10 [label = "t_11:{group_id_count:PRIVATE:INT64}"]
7 -> 11 [label = "t_11:{group_id_count:PRIVATE:INT64}"]
7 -> 9 [label = "t_11:{group_id_count:PRIVATE:INT64}"]
8 -> 9 [label = "t_12:{constant_data:PUBLIC:INT64}"]
9 -> 10 [label = "t_13:{constant_data:PRIVATE:INT64}"]
}`, ``, testConf{groupThreshold: 0, batched: false}},
{`SELECT ta.join_int_0 FROM alice.tbl_0 AS ta JOIN bob.tbl_0 AS tb ON ta.join_int_0 = tb.join_int_0 GROUP BY ta.join_int_0 HAVING SUM(ta.plain_int_0) > ANY(SELECT ta.plain_int_0 FROM alice.tbl_0 AS ta)`, `digraph G {
0 [label="runsql:{in:[],out:[Out:{t_0,t_1,},],attr:[sql:select ta.join_int_0,ta.plain_int_0 from alice.tbl_0 as ta;,table_refs:[alice.tbl_0],],party:[alice,]}"]
1 [label="runsql:{in:[],out:[Out:{t_2,},],attr:[sql:select tb.join_int_0 from bob.tbl_0 as tb;,table_refs:[bob.tbl_0],],party:[bob,]}"]
Expand Down
1 change: 1 addition & 0 deletions pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
if er.err != nil {
return
}
er.asScalar = true // scql change
er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, not)
}

Expand Down

0 comments on commit 5bd9a2f

Please sign in to comment.