From 16c03e114bf2d8f4e0b2f3db5645eb1f3c5a37cd Mon Sep 17 00:00:00 2001 From: Longyue Li Date: Tue, 11 Jun 2024 12:57:36 +0800 Subject: [PATCH] =?UTF-8?q?refactor(merger):=20=E9=87=8D=E6=9E=84AVG?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=AE=9E=E7=8E=B0,=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=89=80=E6=9C=89rows.Rows=E5=AE=9E=E7=8E=B0=E7=9A=84ConlumnTy?= =?UTF-8?q?pe=E6=96=B9=E6=B3=95=E5=B9=B6=E6=B7=BB=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=20(#223)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + internal/merger/factory/factory.go | 25 +- internal/merger/factory/factory_test.go | 439 ++++++++------ .../aggregatemerger/aggregator/avg.go | 16 +- .../aggregatemerger/aggregator/avg_test.go | 18 +- .../aggregatemerger/aggregator/count.go | 10 +- .../aggregatemerger/aggregator/count_test.go | 5 +- .../aggregatemerger/aggregator/max.go | 10 +- .../aggregatemerger/aggregator/max_test.go | 7 +- .../aggregatemerger/aggregator/min.go | 10 +- .../aggregatemerger/aggregator/min_test.go | 7 +- .../aggregatemerger/aggregator/sum.go | 10 +- .../aggregatemerger/aggregator/sum_test.go | 5 +- .../aggregatemerger/aggregator/type.go | 8 +- .../merger/internal/aggregatemerger/merger.go | 41 +- .../internal/aggregatemerger/merger_test.go | 546 +++++++++++++----- .../merger/internal/batchmerger/merger.go | 6 + .../internal/batchmerger/merger_test.go | 151 +++-- .../groupbymerger/aggregator_merger.go | 38 +- .../groupbymerger/aggregator_merger_test.go | 275 +++++---- .../merger/internal/pagedmerger/merger.go | 6 + .../internal/pagedmerger/merger_test.go | 133 +++-- internal/merger/internal/sortmerger/merger.go | 22 +- .../merger/internal/sortmerger/merger_test.go | 192 +++--- internal/merger/type.go | 18 +- 25 files changed, 1350 insertions(+), 649 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index b24c5f3..442d0eb 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -38,6 +38,7 @@ - [script: 注释掉无用命令及代码、固定ci中golangci-lint的版本使其与setup.sh中版本保持一致](https://github.com/ecodeclub/eorm/pull/220) - [doc: 修复README中不可用的贡献者指南链接](https://github.com/ecodeclub/eorm/pull/221) - [feat(merger): 定义中立的特征表达数据、定义工厂方法根据特征数据来获取具体的merger](https://github.com/ecodeclub/eorm/pull/222) +- [refactor(merger): 重构AVG函数实现,重构所有rows.Rows实现的ConlumnType方法并添加测试](https://github.com/ecodeclub/eorm/pull/223) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) - [Selector Definition](https://github.com/ecodeclub/eorm/pull/2) diff --git a/internal/merger/factory/factory.go b/internal/merger/factory/factory.go index 2d68f23..d2526ac 100644 --- a/internal/merger/factory/factory.go +++ b/internal/merger/factory/factory.go @@ -105,7 +105,7 @@ func (q QuerySpec) validateGroupBy() error { return fmt.Errorf("%w: groupby %v", ErrInvalidColumnInfo, c.Name) } // 清除ASC - c.ASC = false + c.Order = merger.DESC if !slice.Contains(q.Select, c) { return fmt.Errorf("%w: groupby %v", ErrColumnNotFoundInSelectList, c.Name) } @@ -134,7 +134,7 @@ func (q QuerySpec) validateOrderBy() error { return fmt.Errorf("%w: orderby %v", ErrInvalidColumnInfo, c.Name) } // 清除ASC - c.ASC = false + c.Order = merger.DESC if !slice.Contains(q.Select, c) { return fmt.Errorf("%w: orderby %v", ErrColumnNotFoundInSelectList, c.Name) } @@ -164,7 +164,7 @@ func newAggregateMerger(origin, target QuerySpec) (merger.Merger, error) { return aggregatemerger.NewMerger(aggregators...), nil } -func getAggregators(origin QuerySpec, target QuerySpec) []aggregator.Aggregator { +func getAggregators(_, target QuerySpec) []aggregator.Aggregator { var aggregators []aggregator.Aggregator for i := 0; i < len(target.Select); i++ { c := target.Select[i] @@ -175,12 +175,11 @@ func getAggregators(origin QuerySpec, target QuerySpec) []aggregator.Aggregator case "MAX": aggregators = append(aggregators, aggregator.NewMax(c)) log.Printf("max index = %d\n", c.Index) + case "AVG": + aggregators = append(aggregators, aggregator.NewAVG(c, target.Select[i+1], target.Select[i+2])) + i += 2 + log.Printf("avg index = %d\n", c.Index) case "SUM": - if i < len(origin.Select) && strings.ToUpper(origin.Select[i].AggregateFunc) == "AVG" { - aggregators = append(aggregators, aggregator.NewAVG(c, target.Select[i+1], origin.Select[i].SelectName())) - i += 1 - continue - } aggregators = append(aggregators, aggregator.NewSum(c)) log.Printf("sum index = %d\n", c.Index) case "COUNT": @@ -202,22 +201,22 @@ func newOrderByMerger(origin, target QuerySpec) (merger.Merger, error) { for i := 0; i < len(target.OrderBy); i++ { c := target.OrderBy[i] if i < len(origin.OrderBy) && strings.ToUpper(origin.OrderBy[i].AggregateFunc) == "AVG" { - s := sortmerger.NewSortColumn(origin.OrderBy[i].SelectName(), sortmerger.Order(origin.OrderBy[i].ASC)) + s := sortmerger.NewSortColumn(origin.OrderBy[i].SelectName(), sortmerger.Order(origin.OrderBy[i].Order)) columns = append(columns, s) i++ continue } - s := sortmerger.NewSortColumn(c.SelectName(), sortmerger.Order(c.ASC)) + s := sortmerger.NewSortColumn(c.SelectName(), sortmerger.Order(c.Order)) columns = append(columns, s) } - var isScanAll bool + var isPreScanAll bool if slice.Contains(target.Features, query.GroupBy) { - isScanAll = true + isPreScanAll = true } log.Printf("sortColumns = %#v\n", columns) - return sortmerger.NewMerger(isScanAll, columns...) + return sortmerger.NewMerger(isPreScanAll, columns...) } func New(origin, target QuerySpec) (merger.Merger, error) { diff --git a/internal/merger/factory/factory_test.go b/internal/merger/factory/factory_test.go index 6a6ff73..42acfcb 100644 --- a/internal/merger/factory/factory_test.go +++ b/internal/merger/factory/factory_test.go @@ -153,7 +153,7 @@ func TestNew(t *testing.T) { { Index: 0, // 索引排序? amount没有出现在SELECT子句,出现在orderBy子句中 Name: "amount", - ASC: true, + Order: merger.ASC, }, }, }, @@ -418,7 +418,7 @@ func (s *factoryTestSuite) TestSELECT() { requireErrFunc require.ErrorAssertionFunc after func(t *testing.T, rows rows.Rows, expectedColumnNames []string) }{ - // 非法情况 + // SELECT { sql: "应该报错_QuerySpec.Select列为空", before: func(t *testing.T, sql string) ([]rows.Rows, []string) { @@ -592,11 +592,11 @@ func (s *factoryTestSuite) TestSELECT() { sql: "SELECT MIN(`amount`),MAX(`amount`),AVG(`amount`),SUM(`amount`),COUNT(`amount`) FROM `orders` WHERE (`order_id` > 10 AND `amount` > 20) OR `order_id` > 100 OR `amount` > 30", before: func(t *testing.T, sql string) ([]rows.Rows, []string) { t.Helper() - targetSQL := "SELECT MIN(`amount`),MAX(`amount`),SUM(`amount`), COUNT(`amount`), SUM(`amount`), COUNT(`amount`) FROM `orders`" - cols := []string{"MIN(`amount`)", "MAX(`amount`)", "SUM(`amount`)", "COUNT(`amount`)", "SUM(`amount`)", "COUNT(`amount`)"} - s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 200, 400, 2, 400, 2)) - s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(150, 150, 450, 3, 450, 3)) - s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(50, 50, 50, 1, 50, 1)) + targetSQL := "SELECT MIN(`amount`),MAX(`amount`),AVG(`amount`),SUM(`amount`), COUNT(`amount`), SUM(`amount`), COUNT(`amount`) FROM `orders`" + cols := []string{"MIN(`amount`)", "MAX(`amount`)", "AVG(`amount`)", "SUM(`amount`)", "COUNT(`amount`)", "SUM(`amount`)", "COUNT(`amount`)"} + s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 200, 200, 400, 2, 400, 2)) + s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(150, 150, 150, 450, 3, 450, 3)) + s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(50, 50, 50, 50, 1, 50, 1)) return getResultSet(t, targetSQL, s.db01, s.db02, s.db03), cols }, originSpec: QuerySpec{ @@ -645,21 +645,26 @@ func (s *factoryTestSuite) TestSELECT() { { Index: 2, Name: "`amount`", - AggregateFunc: "SUM", + AggregateFunc: "AVG", }, { Index: 3, Name: "`amount`", - AggregateFunc: "COUNT", + AggregateFunc: "SUM", }, { Index: 4, Name: "`amount`", - AggregateFunc: "SUM", + AggregateFunc: "COUNT", }, { Index: 5, Name: "`amount`", + AggregateFunc: "SUM", + }, + { + Index: 6, + Name: "`amount`", AggregateFunc: "COUNT", }, }, @@ -685,7 +690,7 @@ func (s *factoryTestSuite) TestSELECT() { sum := 200*2 + 150*3 + 50 cnt := 6 - avg := float64(sum / cnt) + avg := float64(sum) / float64(cnt) require.Equal(t, []any{ []any{50, 200, avg, sum, cnt}, }, getRowValues(t, r, scanFunc)) @@ -743,7 +748,7 @@ func (s *factoryTestSuite) TestSELECT() { { Index: 0, Name: "`ctime`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -759,7 +764,7 @@ func (s *factoryTestSuite) TestSELECT() { { Index: 0, Name: "`ctime`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -798,13 +803,13 @@ func (s *factoryTestSuite) TestSELECT() { Index: 0, Name: "`user_id`", Alias: "`uid`", - ASC: true, + Order: merger.ASC, }, { Index: 1, Name: "`order_id`", Alias: "`oid`", - ASC: false, + Order: merger.DESC, }, }, }, @@ -827,13 +832,13 @@ func (s *factoryTestSuite) TestSELECT() { Index: 0, Name: "`user_id`", Alias: "`uid`", - ASC: true, + Order: merger.ASC, }, { Index: 1, Name: "`order_id`", Alias: "`oid`", - ASC: false, + Order: merger.DESC, }, }, }, @@ -866,92 +871,91 @@ func (s *factoryTestSuite) TestSELECT() { }, getRowValues(t, r, scanFunc)) }, }, - // TODO: ORDER BY 和 与聚合列组合,原始SQL中ORDER BY中用别名`avg_amt`,目标SQL的ORDER BY该如何该写? - // { - // sql: "SELECT AVG(`amount`) AS `avg_amt` FROM `orders` ORDER BY `avg_amt`", - // - // before: func(t *testing.T, sql string) ([]rows.Rows, []string) { - // t.Helper() - // targetSQL := "SELECT SUM(`amount`), COUNT(`amount`) FROM `orders` ORDER BY SUM(`amount`), COUNT(`amount`)" - // cols := []string{"SUM(`amount`)", "COUNT(`amount`)"} - // s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 4)) - // s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(150, 2)) - // s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(40, 1)) - // return s.getResultSet(t, targetSQL, s.db01, s.db02, s.db03), cols - // }, - // originSpec: QuerySpec{ - // Features: []query.Feature{query.AggregateFunc, query.OrderBy}, - // Select: []merger.ColumnInfo{ - // { - // Index: 0, - // Name: "`amount`", - // AggregateFunc: "AVG", - // Alias: "`avg_amt`", - // }, - // }, - // OrderBy: []merger.ColumnInfo{ - // { - // Index: 0, - // Name: "`amount`", - // AggregateFunc: "AVG", - // Alias: "`avg_amt`", - // ASC: true, - // }, - // }, - // }, - // targetSpec: QuerySpec{ - // Features: []query.Feature{query.AggregateFunc, query.OrderBy}, - // Select: []merger.ColumnInfo{ - // { - // Index: 0, - // Name: "`amount`", - // AggregateFunc: "SUM", - // }, - // { - // Index: 1, - // Name: "`amount`", - // AggregateFunc: "COUNT", - // }, - // }, - // OrderBy: []merger.ColumnInfo{ - // // pipline中的后者,需要根据原SQL中的Orderby - // { - // Index: 0, - // Name: "`amount`", - // AggregateFunc: "SUM", - // ASC: true, - // }, - // { - // Index: 1, - // Name: "`amount`", - // AggregateFunc: "COUNT", - // ASC: true, - // }, - // }, - // }, - // requireErrFunc: require.NoError, - // after: func(t *testing.T, r rows.Rows, _ []string) { - // t.Helper() - // cols := []string{"`avg_amt`"} - // columnsNames, err := r.Columns() - // require.NoError(t, err) - // require.Equal(t, cols, columnsNames) - // - // scanFunc := func(rr rows.Rows, valSet *[]any) error { - // var avg float64 - // if err := rr.Scan(&avg); err != nil { - // return err - // } - // *valSet = append(*valSet, []any{avg}) - // return nil - // } - // - // avg := float64(200+150+40) / float64(4+2+1) - // require.Equal(t, []any{ - // []any{avg}, - // }, s.getRowValues(t, r, scanFunc)) - // }, - // }, + // 聚合函数 + ORDER BY + { + sql: "SELECT AVG(`amount`) AS `avg_amt` FROM `orders` ORDER BY `avg_amt`", + + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := "SELECT AVG(`amount`) AS `avg_amt`, SUM(`amount`), COUNT(`amount`) FROM `orders` ORDER BY SUM(`amount`), COUNT(`amount`)" + cols := []string{"`avg_amt`", "SUM(`amount`)", "COUNT(`amount`)"} + s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(50, 200, 4)) + s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(75, 150, 2)) + s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(40, 40, 1)) + return getResultSet(t, targetSQL, s.db01, s.db02, s.db03), cols + }, + originSpec: QuerySpec{ + Features: []query.Feature{query.AggregateFunc, query.OrderBy}, + Select: []merger.ColumnInfo{ + { + Index: 0, + Name: "`amount`", + AggregateFunc: "AVG", + Alias: "`avg_amt`", + }, + }, + OrderBy: []merger.ColumnInfo{ + { + Index: 0, + Name: "`amount`", + AggregateFunc: "AVG", + Alias: "`avg_amt`", + Order: true, + }, + }, + }, + targetSpec: QuerySpec{ + Features: []query.Feature{query.AggregateFunc, query.OrderBy}, + Select: []merger.ColumnInfo{ + { + Index: 0, + Name: "`amount`", + AggregateFunc: "AVG", + Alias: "`avg_amt`", + }, + { + Index: 1, + Name: "`amount`", + AggregateFunc: "SUM", + }, + { + Index: 2, + Name: "`amount`", + AggregateFunc: "COUNT", + }, + }, + OrderBy: []merger.ColumnInfo{ + { + Index: 0, + Name: "`amount`", + AggregateFunc: "AVG", + Alias: "`avg_amt`", + }, + }, + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + cols := []string{"`avg_amt`"} + columnsNames, err := r.Columns() + require.NoError(t, err) + require.Equal(t, cols, columnsNames) + + scanFunc := func(rr rows.Rows, valSet *[]any) error { + var avg float64 + if err := rr.Scan(&avg); err != nil { + return err + } + *valSet = append(*valSet, []any{avg}) + return nil + } + + avg := float64(200+150+40) / float64(4+2+1) + require.Equal(t, []any{ + []any{avg}, + }, getRowValues(t, r, scanFunc)) + }, + }, { // TODO: 暂时用该测试用例替换上方avg案例,当avg问题修复后,该测试用例应该删除 sql: "SELECT COUNT(`amount`) AS `cnt_amt` FROM `orders` ORDER BY `cnt_amt`", @@ -981,7 +985,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "COUNT", Alias: "`cnt_amt`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -1001,7 +1005,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "COUNT", Alias: "`cnt_amt`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -1073,7 +1077,7 @@ func (s *factoryTestSuite) TestSELECT() { { Index: 1, Name: "`ctime`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -1089,7 +1093,7 @@ func (s *factoryTestSuite) TestSELECT() { { Index: 1, Name: "`ctime`", - ASC: true, + Order: merger.ASC, }, }, }, @@ -1584,13 +1588,13 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: true, + Order: merger.ASC, }, { Index: 0, Name: "`user_id`", Alias: "`uid`", - ASC: false, + Order: merger.DESC, }, }, }, @@ -1632,13 +1636,13 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: true, + Order: merger.ASC, }, { Index: 0, Name: "`user_id`", Alias: "`uid`", - ASC: false, + Order: merger.DESC, }, }, }, @@ -1838,7 +1842,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: false, + Order: merger.DESC, }, }, Limit: 2, @@ -1872,7 +1876,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: false, + Order: merger.DESC, }, }, Limit: 2, @@ -1951,7 +1955,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: true, + Order: merger.ASC, }, }, Limit: 6, @@ -1995,7 +1999,7 @@ func (s *factoryTestSuite) TestSELECT() { Name: "`amount`", AggregateFunc: "SUM", Alias: "`total_amt`", - ASC: true, + Order: merger.ASC, }, }, Limit: 6, @@ -2029,76 +2033,141 @@ func (s *factoryTestSuite) TestSELECT() { }, getRowValues(t, r, scanFunc)) }, }, - // { - // TODO: 聚合 + 非聚合 + groupby + orderby + limit - - // sql: "SELECT `user_id`, COUNT(`amount`) AS `order_count`, AVG(`amount`) FROM `orders` GROUP BY `user_id` ORDER BY `order_count` DESC, `user_id` DESC Limit 3 OFFSET 0", - // before: func(t *testing.T, sql string) []rows.Rows { - // t.Helper() - // targetSQL := sql - // cols := []string{"`user_id`", "AVG(`amount`)", "COUNT(*)"} - // s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 100, 4).AddRow(3, 150, 2)) - // s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(4, 200, 1)) - // s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 450, 3)) - // return s.getResultSet(t, targetSQL, s.db01, s.db02, s.db03) - // }, - // spec: QuerySpec{ - // Features: []query.Feature{query.GroupBy, query.OrderBy, query.Limit}, - // Select: []merger.ColumnInfo{ - // { - // Index: 0, - // Name: "`user_id`", - // }, - // { - // Index: 1, - // Name: "AVG(`amount`)", - // AggregateFunc: "AVG", - // }, - // { - // Index: 2, - // Name: "COUNT(*)", - // AggregateFunc: "COUNT", - // }, - // }, - // GroupBy: []merger.ColumnInfo{ - // { - // Index: 0, - // Name: "user_id", - // }, - // }, - // OrderBy: []merger.ColumnInfo{ - // { - // Index: 1, - // Name: "COUNT(*)", - // IsASCOrder: true, - // }, - // }, - // Limit: 2, - // Offset: 0, - // }, - // requireErrFunc: require.NoError, - // after: func(t *testing.T, r rows.Rows) { - // t.Helper() - // scanFunc := func(rr rows.Rows, valSet *[]any) error { - // log.Printf("before rr = %#vscan = %#v", rr, *valSet) - // var uid, cnt int - // var avgAmt float64 - // if err := rr.Scan(&uid, &avgAmt, &cnt); err != nil { - // return err - // } - // *valSet = append(*valSet, []any{uid, avgAmt, cnt}) - // return nil - // } - // // 4, 200, 1 - // // 3, 150, 2 - // // 2, 450, 3, - // // 1, 100, 4, - // require.Equal(t, []any{ - // []any{4, float64(200), 1}, - // []any{3, float64(150), 2}, - // }, s.getRowValues(t, r, scanFunc)) - // }, - // }, + // 聚合 + 非聚合 + GROUP BY + ORDER BY + LIMIT + { + sql: "SELECT `user_id`, COUNT(`amount`) AS `order_count`, AVG(`amount`) FROM `orders` GROUP BY `user_id` ORDER BY `order_count` DESC, `user_id` DESC Limit 4 OFFSET 0", + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := "SELECT `user_id`, COUNT(`amount`) AS `order_count`, AVG(`amount`), SUM(`amount`), COUNT(`amount`) FROM `orders` GROUP BY `user_id` ORDER BY `order_count` DESC, `user_id` DESC Limit 3 OFFSET 0" + cols := []string{"`user_id`", "`order_count`", "AVG(`amount`)", "SUM(`amount`)", "COUNT(`amount`)"} + s.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 4, 100, 400, 4).AddRow(3, 2, 150, 300, 2)) + s.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(4, 1, 200, 200, 1).AddRow(3, 1, 150, 150, 1)) + s.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 3, 450, 1350, 3).AddRow(5, 1, 50, 50, 1)) + return getResultSet(t, targetSQL, s.db01, s.db02, s.db03), cols + }, + originSpec: QuerySpec{ + Features: []query.Feature{query.GroupBy, query.OrderBy, query.Limit}, + Select: []merger.ColumnInfo{ + { + Index: 0, + Name: "`user_id`", + }, + { + Index: 1, + Name: "`amount`", + AggregateFunc: "COUNT", + Alias: "`order_count`", + }, + { + Index: 2, + Name: "`amount`", + AggregateFunc: "AVG", + }, + }, + GroupBy: []merger.ColumnInfo{ + { + Index: 0, + Name: "`user_id`", + }, + }, + OrderBy: []merger.ColumnInfo{ + { + Index: 1, + Name: "`amount`", + AggregateFunc: "COUNT", + Alias: "`order_count`", + }, + { + Index: 0, + Name: "`user_id`", + }, + }, + Limit: 4, + Offset: 0, + }, + targetSpec: QuerySpec{ + Features: []query.Feature{query.GroupBy, query.OrderBy, query.Limit}, + Select: []merger.ColumnInfo{ + { + Index: 0, + Name: "`user_id`", + }, + { + Index: 1, + Name: "`amount`", + AggregateFunc: "COUNT", + Alias: "`order_count`", + }, + { + Index: 2, + Name: "`amount`", + AggregateFunc: "AVG", + }, + { + Index: 3, + Name: "`amount`", + AggregateFunc: "SUM", + }, + { + Index: 4, + Name: "`amount`", + AggregateFunc: "COUNT", + }, + }, + GroupBy: []merger.ColumnInfo{ + { + Index: 0, + Name: "`user_id`", + }, + }, + OrderBy: []merger.ColumnInfo{ + { + Index: 1, + Name: "`amount`", + AggregateFunc: "COUNT", + Alias: "`order_count`", + }, + { + Index: 0, + Name: "`user_id`", + }, + }, + Limit: 4, + Offset: 0, + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + expectedColumnNames := []string{"`user_id`", "`order_count`", "AVG(`amount`)"} + columnsNames, err := r.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, columnsNames) + + types, err := r.ColumnTypes() + require.NoError(t, err) + typeNames := make([]string, 0, len(types)) + for _, typ := range types { + typeNames = append(typeNames, typ.Name()) + } + require.Equal(t, expectedColumnNames, typeNames) + + scanFunc := func(rr rows.Rows, valSet *[]any) error { + var uid, cnt int + var avgAmt float64 + if err := rr.Scan(&uid, &cnt, &avgAmt); err != nil { + return err + } + *valSet = append(*valSet, []any{uid, cnt, avgAmt}) + return nil + } + require.Equal(t, []any{ + []any{1, 7, float64(250)}, + []any{3, 3, float64(150)}, + []any{5, 1, float64(50)}, + []any{4, 1, float64(200)}, + }, getRowValues(t, r, scanFunc)) + }, + }, } for _, tt := range tests { t.Run(tt.sql, func(t *testing.T) { diff --git a/internal/merger/internal/aggregatemerger/aggregator/avg.go b/internal/merger/internal/aggregatemerger/aggregator/avg.go index d6aabf0..c6140a9 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/avg.go +++ b/internal/merger/internal/aggregatemerger/aggregator/avg.go @@ -25,17 +25,19 @@ import ( // AVG 用于求平均值,通过sum/count求得。 // AVG 我们并不能预期在不同的数据库上,精度会不会损失,以及损失的话会有多少的损失。这很大程度上跟数据库类型,数据库驱动实现都有关 type AVG struct { + name string + avgColumnInfo merger.ColumnInfo sumColumnInfo merger.ColumnInfo countColumnInfo merger.ColumnInfo - avgName string } // NewAVG sumInfo是sum的信息,countInfo是count的信息,avgName用于Column方法 -func NewAVG(sumInfo merger.ColumnInfo, countInfo merger.ColumnInfo, avgName string) *AVG { +func NewAVG(avgInfo, sumInfo, countInfo merger.ColumnInfo) *AVG { return &AVG{ + name: "AVG", + avgColumnInfo: avgInfo, sumColumnInfo: sumInfo, countColumnInfo: countInfo, - avgName: avgName, } } @@ -63,8 +65,12 @@ func (a *AVG) findAvgFunc(col []any) (func([][]any, int, int) (float64, error), return val, nil } -func (a *AVG) ColumnName() string { - return a.avgName +func (a *AVG) ColumnInfo() merger.ColumnInfo { + return a.avgColumnInfo +} + +func (a *AVG) Name() string { + return a.name } // avgAggregator cols就是上面Aggregate的入参cols可以参Aggregate的描述 diff --git a/internal/merger/internal/aggregatemerger/aggregator/avg_test.go b/internal/merger/internal/aggregatemerger/aggregator/avg_test.go index 9790171..abbf539 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/avg_test.go +++ b/internal/merger/internal/aggregatemerger/aggregator/avg_test.go @@ -36,34 +36,39 @@ func TestAvg_Aggregate(t *testing.T) { name: "avg正常合并", input: [][]any{ { + float64(10) / float64(2), int64(10), int64(2), }, { + float64(20) / float64(2), int64(20), int64(2), }, { + float64(30) / float64(2), int64(30), int64(2), }, }, - index: []int{0, 1}, + index: []int{0, 1, 2}, wantVal: float64(10), }, { name: "传入的参数非AggregateElement类型", input: [][]any{ { + "1.5", "1", "2", }, { + "0.75", "3", "4", }, }, - index: []int{0, 1}, + index: []int{0, 1, 2}, wantErr: errs.ErrMergerAggregateFuncNotFound, }, { @@ -78,20 +83,23 @@ func TestAvg_Aggregate(t *testing.T) { int64(2), }, }, - index: []int{0, 10}, + index: []int{0, 3, 10}, wantErr: errs.ErrMergerInvalidAggregateColumnIndex, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - avg := NewAVG(merger.NewColumnInfo(tc.index[0], "SUM(grade)"), merger.NewColumnInfo(tc.index[1], "COUNT(grade)"), "AVG(grade)") + avgColumnInfo := merger.ColumnInfo{Index: tc.index[0], Name: "`grade`", AggregateFunc: "AVG"} + avg := NewAVG(avgColumnInfo, + merger.ColumnInfo{Index: tc.index[1], Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: tc.index[2], Name: "`grade`", AggregateFunc: "COUNT"}) val, err := avg.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { return } assert.Equal(t, tc.wantVal, val) - assert.Equal(t, "AVG(grade)", avg.ColumnName()) + assert.Equal(t, avgColumnInfo, avg.ColumnInfo()) }) } diff --git a/internal/merger/internal/aggregatemerger/aggregator/count.go b/internal/merger/internal/aggregatemerger/aggregator/count.go index 373a489..a0d75b0 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/count.go +++ b/internal/merger/internal/aggregatemerger/aggregator/count.go @@ -23,6 +23,7 @@ import ( ) type Count struct { + name string countInfo merger.ColumnInfo } @@ -47,12 +48,17 @@ func (s *Count) findCountFunc(col []any) (func([][]any, int) (any, error), error return countFunc, nil } -func (s *Count) ColumnName() string { - return s.countInfo.SelectName() +func (s *Count) ColumnInfo() merger.ColumnInfo { + return s.countInfo +} + +func (s *Count) Name() string { + return s.name } func NewCount(info merger.ColumnInfo) *Count { return &Count{ + name: "COUNT", countInfo: info, } } diff --git a/internal/merger/internal/aggregatemerger/aggregator/count_test.go b/internal/merger/internal/aggregatemerger/aggregator/count_test.go index 81583c3..30d77ae 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/count_test.go +++ b/internal/merger/internal/aggregatemerger/aggregator/count_test.go @@ -80,14 +80,15 @@ func TestCount_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - count := NewCount(merger.NewColumnInfo(tc.countIndex, "COUNT(id)")) + info := merger.ColumnInfo{Index: tc.countIndex, Name: "id", AggregateFunc: "COUNT"} + count := NewCount(info) val, err := count.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { return } assert.Equal(t, tc.wantVal, val) - assert.Equal(t, "COUNT(id)", count.ColumnName()) + assert.Equal(t, info, count.ColumnInfo()) }) } diff --git a/internal/merger/internal/aggregatemerger/aggregator/max.go b/internal/merger/internal/aggregatemerger/aggregator/max.go index b37fcca..15410de 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/max.go +++ b/internal/merger/internal/aggregatemerger/aggregator/max.go @@ -23,6 +23,7 @@ import ( ) type Max struct { + name string maxColumnInfo merger.ColumnInfo } @@ -47,12 +48,17 @@ func (m *Max) findMaxFunc(col []any) (func([][]any, int) (any, error), error) { return countFunc, nil } -func (m *Max) ColumnName() string { - return m.maxColumnInfo.SelectName() +func (m *Max) ColumnInfo() merger.ColumnInfo { + return m.maxColumnInfo +} + +func (m *Max) Name() string { + return m.name } func NewMax(info merger.ColumnInfo) *Max { return &Max{ + name: "MAX", maxColumnInfo: info, } } diff --git a/internal/merger/internal/aggregatemerger/aggregator/max_test.go b/internal/merger/internal/aggregatemerger/aggregator/max_test.go index 54071e4..d5e3399 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/max_test.go +++ b/internal/merger/internal/aggregatemerger/aggregator/max_test.go @@ -79,14 +79,15 @@ func TestMax_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - max := NewMax(merger.NewColumnInfo(tc.maxIndex, "MAX(id)")) - val, err := max.Aggregate(tc.input) + info := merger.ColumnInfo{Index: tc.maxIndex, Name: "id", AggregateFunc: "MAX"} + m := NewMax(info) + val, err := m.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { return } assert.Equal(t, tc.wantVal, val) - assert.Equal(t, "MAX(id)", max.ColumnName()) + assert.Equal(t, info, m.maxColumnInfo) }) } diff --git a/internal/merger/internal/aggregatemerger/aggregator/min.go b/internal/merger/internal/aggregatemerger/aggregator/min.go index 3d2f8fa..99812e7 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/min.go +++ b/internal/merger/internal/aggregatemerger/aggregator/min.go @@ -23,6 +23,7 @@ import ( ) type Min struct { + name string minColumnInfo merger.ColumnInfo } @@ -48,12 +49,17 @@ func (m *Min) findMinFunc(col []any) (func([][]any, int) (any, error), error) { return minFunc, nil } -func (m *Min) ColumnName() string { - return m.minColumnInfo.SelectName() +func (m *Min) ColumnInfo() merger.ColumnInfo { + return m.minColumnInfo +} + +func (m *Min) Name() string { + return m.name } func NewMin(info merger.ColumnInfo) *Min { return &Min{ + name: "MIN", minColumnInfo: info, } } diff --git a/internal/merger/internal/aggregatemerger/aggregator/min_test.go b/internal/merger/internal/aggregatemerger/aggregator/min_test.go index 917c584..943e957 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/min_test.go +++ b/internal/merger/internal/aggregatemerger/aggregator/min_test.go @@ -79,14 +79,15 @@ func TestMin_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - min := NewMin(merger.NewColumnInfo(tc.minIndex, "MIN(id)")) - val, err := min.Aggregate(tc.input) + info := merger.ColumnInfo{Index: tc.minIndex, Name: "id", AggregateFunc: "MIN"} + m := NewMin(info) + val, err := m.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { return } assert.Equal(t, tc.wantVal, val) - assert.Equal(t, "MIN(id)", min.ColumnName()) + assert.Equal(t, info, m.ColumnInfo()) }) } diff --git a/internal/merger/internal/aggregatemerger/aggregator/sum.go b/internal/merger/internal/aggregatemerger/aggregator/sum.go index 048ec69..04c6430 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/sum.go +++ b/internal/merger/internal/aggregatemerger/aggregator/sum.go @@ -23,6 +23,7 @@ import ( ) type Sum struct { + name string sumColumnInfo merger.ColumnInfo } @@ -48,12 +49,17 @@ func (s *Sum) findSumFunc(col []any) (func([][]any, int) (any, error), error) { return sumFunc, nil } -func (s *Sum) ColumnName() string { - return s.sumColumnInfo.SelectName() +func (s *Sum) ColumnInfo() merger.ColumnInfo { + return s.sumColumnInfo +} + +func (s *Sum) Name() string { + return s.name } func NewSum(info merger.ColumnInfo) *Sum { return &Sum{ + name: "SUM", sumColumnInfo: info, } } diff --git a/internal/merger/internal/aggregatemerger/aggregator/sum_test.go b/internal/merger/internal/aggregatemerger/aggregator/sum_test.go index c9edf81..e7e14ec 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/sum_test.go +++ b/internal/merger/internal/aggregatemerger/aggregator/sum_test.go @@ -81,14 +81,15 @@ func TestSum_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - sum := NewSum(merger.NewColumnInfo(tc.sumIndex, "SUM(id)")) + info := merger.ColumnInfo{Index: tc.sumIndex, Name: "id", AggregateFunc: "SUM"} + sum := NewSum(info) val, err := sum.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { return } assert.Equal(t, tc.wantVal, val) - assert.Equal(t, "SUM(id)", sum.ColumnName()) + assert.Equal(t, info, sum.ColumnInfo()) }) } diff --git a/internal/merger/internal/aggregatemerger/aggregator/type.go b/internal/merger/internal/aggregatemerger/aggregator/type.go index 8a29c1e..bd80623 100644 --- a/internal/merger/internal/aggregatemerger/aggregator/type.go +++ b/internal/merger/internal/aggregatemerger/aggregator/type.go @@ -14,6 +14,8 @@ package aggregator +import "github.com/ecodeclub/eorm/internal/merger" + type AggregateElement interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 } @@ -21,6 +23,8 @@ type AggregateElement interface { type Aggregator interface { // Aggregate 将多个列聚合 cols表示sqlRows列表里的数据,聚合函数通过下标拿到需要的列 Aggregate(cols [][]any) (any, error) - // ColumnName 聚合函数的别名 - ColumnName() string + // ColumnInfo 聚合列的信息 + ColumnInfo() merger.ColumnInfo + // Name 聚合函数本身的名称, MIN/MAX/SUM/COUNT/AVG + Name() string } diff --git a/internal/merger/internal/aggregatemerger/merger.go b/internal/merger/internal/aggregatemerger/merger.go index 19e32c7..b4eac9c 100644 --- a/internal/merger/internal/aggregatemerger/merger.go +++ b/internal/merger/internal/aggregatemerger/merger.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "errors" + "fmt" "sync" _ "unsafe" @@ -34,16 +35,23 @@ import ( type Merger struct { aggregators []aggregator.Aggregator colNames []string + avgIndexes []int } func NewMerger(aggregators ...aggregator.Aggregator) *Merger { cols := make([]string, 0, len(aggregators)) + idx := make([]int, 0, len(aggregators)) for _, agg := range aggregators { - cols = append(cols, agg.ColumnName()) + info := agg.ColumnInfo() + if agg.Name() == "AVG" { + idx = append(idx, info.Index) + } + cols = append(cols, info.SelectName()) } return &Merger{ aggregators: aggregators, colNames: cols, + avgIndexes: idx, } } @@ -63,9 +71,11 @@ func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, err return &Rows{ rowsList: results, aggregators: m.aggregators, + avgIndexes: m.avgIndexes, mu: &sync.RWMutex{}, - // 聚合函数AVG传递到各个sql.Rows时会被转化为SUM和COUNT,这是一个对外不可见的转化。 - // 所以merger.Rows的列名及顺序是由上方aggregator出现的顺序及ColumnName()的返回值决定的而不是sql.Rows。 + // 原SQL中的聚合函数AVG传递到各个sql.Rows时会被转化为目标SQL中的AVG,SUM和COUNT三个列,这是一个对外不可见的转化。 + // 其中AVG仅用于获取ColumnType,真正的结果值是基于SUM和COUNT计算得到的 + // 所以设置aggregators要与目标SQL对齐, 而得到的merger.Rows应该与原SQL对齐的而不是目标SQL. columns: m.colNames, }, nil @@ -74,6 +84,7 @@ func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, err type Rows struct { rowsList []rows.Rows aggregators []aggregator.Aggregator + avgIndexes []int closed bool mu *sync.RWMutex lastErr error @@ -83,9 +94,26 @@ type Rows struct { } func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { - // TOTO: 应该返回 AVG 对应的名字和类型 - // rowsList[0].ColumnTypes 返回 SUM, COUNT 是我们该写后的, 抽象有破口 - return r.rowsList[0].ColumnTypes() + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) + } + ts, err := r.rowsList[0].ColumnTypes() + if err != nil { + return nil, err + } + if len(r.avgIndexes) == 0 { + return ts, nil + } + v := make([]*sql.ColumnType, 0, len(ts)) + var prev int + for i := 0; i < len(r.avgIndexes); i++ { + idx := r.avgIndexes[i] + v = append(v, ts[prev:idx+1]...) + prev = idx + 3 + } + return v, nil } func (*Rows) NextResultSet() bool { @@ -153,6 +181,7 @@ func (r *Rows) getSqlRowsData() ([][]any, error) { } return rowsData, nil } + func (*Rows) getSqlRowData(row rows.Rows) ([]any, error) { var colsData []any var err error diff --git a/internal/merger/internal/aggregatemerger/merger_test.go b/internal/merger/internal/aggregatemerger/merger_test.go index ded3f32..7f9a490 100644 --- a/internal/merger/internal/aggregatemerger/merger_test.go +++ b/internal/merger/internal/aggregatemerger/merger_test.go @@ -63,6 +63,11 @@ func (ms *MergerSuite) SetupTest() { } func (ms *MergerSuite) TearDownTest() { + + ms.NoError(ms.mock01.ExpectationsWereMet()) + ms.NoError(ms.mock02.ExpectationsWereMet()) + ms.NoError(ms.mock03.ExpectationsWereMet()) + _ = ms.mockDB01.Close() _ = ms.mockDB02.Close() _ = ms.mockDB03.Close() @@ -77,19 +82,19 @@ func (ms *MergerSuite) initMock(t *testing.T) { " id INT PRIMARY KEY NOT NULL," + " grade INT NOT NULL" + ");" - ms.mockDB01, ms.mock01, err = sqlmock.New() + ms.mockDB01, ms.mock01, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB02, ms.mock02, err = sqlmock.New() + ms.mockDB02, ms.mock02, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB03, ms.mock03, err = sqlmock.New() + ms.mockDB03, ms.mock03, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB04, ms.mock04, err = sqlmock.New() + ms.mockDB04, ms.mock04, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } @@ -125,9 +130,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { require.NoError(ms.T(), err) cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.db05} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -145,7 +150,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -154,9 +159,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -174,7 +179,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -184,9 +189,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"MAX(id)"} query := "SELECT MAX(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -204,7 +209,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")), + aggregator.NewMax(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "MAX"}), } }, }, @@ -213,9 +218,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"MIN(id)"} query := "SELECT MIN(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -233,7 +238,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")), + aggregator.NewMin(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "MIN"}), } }, }, @@ -242,9 +247,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -262,18 +267,18 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}), } }, }, { name: "AVG(grade)", sqlRows: func() []rows.Rows { - cols := []string{"SUM(grade)", "COUNT(grade)"} - query := "SELECT SUM(`grade`),COUNT(`grade`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) + cols := []string{"AVG(`grade`)", "SUM(`grade`)", "COUNT(`grade`)"} + query := "SELECT AVG(`grade`) AS `avg_grade`, SUM(`grade`),COUNT(`grade`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 2000, 10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100, 2000, 20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 2000, 10)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -293,7 +298,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: `grade`, AggregateFunc: "AVG", Alias: "`avg_grade`"}, + merger.ColumnInfo{Index: 1, Name: `grade`, AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: `grade`, AggregateFunc: "COUNT"}, + ), } }, }, @@ -303,11 +312,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { name: "COUNT(id),MAX(id),MIN(id),SUM(id),AVG(grade)", sqlRows: func() []rows.Rows { - cols := []string{"COUNT(id)", "MAX(id)", "MIN(id)", "SUM(id)", "SUM(grade)", "COUNT(grade)"} - query := "SELECT COUNT(`id`),MAX(`id`),MIN(`id`),SUM(`id`),SUM(`grade`),COUNT(`student`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 20, 1, 100, 2000, 20)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20, 30, 0, 200, 800, 10)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 40, 2, 300, 1800, 20)) + cols := []string{"COUNT(id)", "MAX(id)", "MIN(id)", "SUM(id)", "AVG(`grade`)", "SUM(grade)", "COUNT(grade)"} + query := "SELECT COUNT(`id`),MAX(`id`),MIN(`id`),SUM(`id`),AVG(`grade`),SUM(`grade`),COUNT(`grade`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 20, 1, 100, 100, 2000, 20)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20, 30, 0, 200, 80, 800, 10)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 40, 2, 300, 90, 1800, 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -327,11 +336,15 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), - aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")), - aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")), - aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")), - aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}), + aggregator.NewMax(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "MAX"}), + aggregator.NewMin(merger.ColumnInfo{Index: 2, Name: "id", AggregateFunc: "MIN"}), + aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "id", AggregateFunc: "SUM"}), + aggregator.NewAVG( + merger.ColumnInfo{Index: 4, Name: `grade`, AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 5, Name: `grade`, AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 6, Name: `grade`, AggregateFunc: "COUNT"}, + ), } }, }, @@ -340,11 +353,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { name: "AVG(grade),SUM(grade),AVG(grade),MIN(id),MIN(userid),MAX(id),COUNT(id)", sqlRows: func() []rows.Rows { - cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "MIN(id)", "MIN(userid)", "MAX(id)", "COUNT(id)"} - query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`id`),MIN(`userid`),MAX(`id`),COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20, 2000, 20, 2000, 10, 20, 200, 200)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1000, 10, 1000, 10, 1000, 20, 30, 300, 300)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(800, 10, 800, 10, 800, 5, 6, 100, 200)) + cols := []string{"AVG(grade)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "AVG(grade)", "COUNT(grade)", "SUM(grade)", "MIN(id)", "MIN(userid)", "MAX(id)", "COUNT(id)"} + query := "SELECT AVG(`grade`), SUM(`grade`),COUNT(`grade`),SUM(`grade`),AVG(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`id`),MIN(`userid`),MAX(`id`),COUNT(`id`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100, 2000, 20, 2000, 100, 20, 2000, 10, 20, 200, 200)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100, 1000, 10, 1000, 100, 10, 1000, 20, 30, 300, 300)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(80, 800, 10, 800, 80, 10, 800, 5, 6, 100, 200)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -364,13 +377,21 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(merger.NewColumnInfo(2, "SUM(grade)")), - aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewMin(merger.NewColumnInfo(5, "MIN(id)")), - aggregator.NewMin(merger.NewColumnInfo(6, "MIN(userid)")), - aggregator.NewMax(merger.NewColumnInfo(7, "MAX(id)")), - aggregator.NewCount(merger.NewColumnInfo(8, "COUNT(id)")), + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: `grade`, AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 1, Name: `grade`, AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: `grade`, AggregateFunc: "COUNT"}, + ), + aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "grade", AggregateFunc: "SUM"}), + aggregator.NewAVG( + merger.ColumnInfo{Index: 4, Name: `grade`, AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 6, Name: `grade`, AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 5, Name: `grade`, AggregateFunc: "COUNT"}, + ), + aggregator.NewMin(merger.ColumnInfo{Index: 7, Name: "id", AggregateFunc: "MIN"}), + aggregator.NewMin(merger.ColumnInfo{Index: 8, Name: "userid", AggregateFunc: "MIN"}), + aggregator.NewMax(merger.ColumnInfo{Index: 9, Name: "id", AggregateFunc: "MAX"}), + aggregator.NewCount(merger.ColumnInfo{Index: 10, Name: "id", AggregateFunc: "COUNT"}), } }, }, @@ -383,10 +404,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -405,7 +426,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -415,10 +436,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB04, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -437,7 +458,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -447,10 +468,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -469,7 +490,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -479,10 +500,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"SUM(id)"} query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -495,7 +516,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}), } }, }, @@ -503,18 +524,18 @@ func (ms *MergerSuite) TestRows_NextAndScan() { for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { m := NewMerger(tc.aggregators()...) - rows, err := m.Merge(context.Background(), tc.sqlRows()) + r, err := m.Merge(context.Background(), tc.sqlRows()) require.NoError(t, err) - for rows.Next() { + for r.Next() { kk := make([]any, 0, len(tc.gotVal)) for i := 0; i < len(tc.gotVal); i++ { kk = append(kk, &tc.gotVal[i]) } - err = rows.Scan(kk...) + err = r.Scan(kk...) require.NoError(t, err) } - assert.Equal(t, tc.wantErr, rows.Err()) - if rows.Err() != nil { + assert.Equal(t, tc.wantErr, r.Err()) + if r.Err() != nil { return } assert.Equal(t, tc.wantVal, tc.gotVal) @@ -534,10 +555,10 @@ func (ms *MergerSuite) TestRows_NextAndErr() { rowsList: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4).RowError(0, nextMockErr)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4).RowError(0, nextMockErr)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -549,7 +570,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { }, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}), } }(), wantErr: nextMockErr, @@ -559,10 +580,10 @@ func (ms *MergerSuite) TestRows_NextAndErr() { rowsList: func() []rows.Rows { cols := []string{"COUNT(id)"} query := "SELECT COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -582,81 +603,85 @@ func (ms *MergerSuite) TestRows_NextAndErr() { } for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { - merger := NewMerger(tc.aggregators...) - rows, err := merger.Merge(context.Background(), tc.rowsList()) + m := NewMerger(tc.aggregators...) + r, err := m.Merge(context.Background(), tc.rowsList()) require.NoError(t, err) - for rows.Next() { + for r.Next() { } count := int64(0) - err = rows.Scan(&count) + err = r.Scan(&count) assert.Equal(t, tc.wantErr, err) - assert.Equal(t, tc.wantErr, rows.Err()) + assert.Equal(t, tc.wantErr, r.Err()) }) } } func (ms *MergerSuite) TestRows_Close() { cols := []string{"SUM(id)"} - query := "SELECT SUM(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02"))) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) - merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) + targetSQL := "SELECT SUM(`id`) FROM `t1`" + ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02"))) + ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) + m := NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"})) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { - row, err := db.QueryContext(context.Background(), query) + row, err := db.QueryContext(context.Background(), targetSQL) require.NoError(ms.T(), err) rowsList = append(rowsList, row) } - rows, err := merger.Merge(context.Background(), rowsList) + r, err := m.Merge(context.Background(), rowsList) require.NoError(ms.T(), err) // 判断当前是可以正常读取的 - require.True(ms.T(), rows.Next()) + require.True(ms.T(), r.Next()) var id int - err = rows.Scan(&id) + err = r.Scan(&id) require.NoError(ms.T(), err) - err = rows.Close() - ms.T().Run("close返回multierror", func(t *testing.T) { + err = r.Close() + ms.T().Run("close返回multiError", func(t *testing.T) { assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) }) ms.T().Run("close之后Next返回false", func(t *testing.T) { for i := 0; i < len(rowsList); i++ { require.False(ms.T(), rowsList[i].Next()) } - require.False(ms.T(), rows.Next()) + require.False(ms.T(), r.Next()) }) ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { var id int - err := rows.Scan(&id) + err := r.Scan(&id) assert.Equal(t, errs.ErrMergerRowsClosed, err) }) ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { - _, err := rows.Columns() + _, err := r.Columns() require.Error(t, err) }) ms.T().Run("close多次是等效的", func(t *testing.T) { for i := 0; i < 4; i++ { - err = rows.Close() + err = r.Close() require.NoError(t, err) } }) } func (ms *MergerSuite) TestRows_Columns() { - cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} - query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12)) + cols := []string{"AVG(grade)", "SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} + query := "SELECT AVG(`grade`), SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 1, 2, 1, 3, 10)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 2, 1, 3, 2, 4, 11)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 3, 1, 4, 3, 5, 12)) aggregators := []aggregator.Aggregator{ - aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), - aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), - aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), - aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: `grade`, AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 1, Name: `grade`, AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: `grade`, AggregateFunc: "COUNT"}, + ), + aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "id", AggregateFunc: "SUM"}), + aggregator.NewMin(merger.ColumnInfo{Index: 4, Name: "id", AggregateFunc: "MIN"}), + aggregator.NewMax(merger.ColumnInfo{Index: 5, Name: "id", AggregateFunc: "MAX"}), + aggregator.NewCount(merger.ColumnInfo{Index: 6, Name: "id", AggregateFunc: "COUNT"}), } - merger := NewMerger(aggregators...) + m := NewMerger(aggregators...) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -665,21 +690,21 @@ func (ms *MergerSuite) TestRows_Columns() { rowsList = append(rowsList, row) } - rows, err := merger.Merge(context.Background(), rowsList) + r, err := m.Merge(context.Background(), rowsList) require.NoError(ms.T(), err) wantCols := []string{"AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} ms.T().Run("Next没有迭代完", func(t *testing.T) { - for rows.Next() { - columns, err := rows.Columns() + for r.Next() { + columns, err := r.Columns() require.NoError(t, err) assert.Equal(t, wantCols, columns) } - require.NoError(t, rows.Err()) + require.NoError(t, r.Err()) }) ms.T().Run("Next迭代完", func(t *testing.T) { - require.False(t, rows.Next()) - require.NoError(t, rows.Err()) - _, err := rows.Columns() + require.False(t, r.Next()) + require.NoError(t, r.Err()) + _, err := r.Columns() assert.Equal(t, errs.ErrMergerRowsClosed, err) }) } @@ -695,7 +720,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "超时", merger: func() *Merger { - return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"})) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), 0) @@ -706,16 +731,16 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT SUM(`id`) FROM `t1`;" cols := []string{"SUM(id)"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) - rows, _ := ms.mockDB01.QueryContext(context.Background(), query) - res = append(res, rows) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + r, _ := ms.mockDB01.QueryContext(context.Background(), query) + res = append(res, r) return res }, }, { name: "sqlRows列表元素个数为0", merger: func() *Merger { - return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"})) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -729,7 +754,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "sqlRows列表有nil", merger: func() *Merger { - return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"})) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -756,6 +781,243 @@ func (ms *MergerSuite) TestMerger_Merge() { } } +func (ms *MergerSuite) TestMerger_ColumnTypes() { + t := ms.T() + + tests := []struct { + sql string + before func(t *testing.T, sql string) ([]rows.Rows, []string) + columns []aggregator.Aggregator + requireErrFunc require.ErrorAssertionFunc + after func(t *testing.T, r rows.Rows, expectedColumnNames []string) + }{ + { + sql: "SELECT SUM(`grade`) FROM `t1`", + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := sql + cols := []string{"SUM(`grade`)"} + ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(400)) + ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(120)) + ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(80)) + return getResultSet(t, targetSQL, ms.mockDB01, ms.mockDB02, ms.mockDB03), cols + + }, + columns: []aggregator.Aggregator{ + aggregator.NewSum( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "SUM"}, + ), + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + + expectedColumnNames := []string{"SUM(`grade`)"} + columns, err := r.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, columns) + + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, expectedColumnNames, names) + + scanFunc := func(rr rows.Rows, valSet *[]any) error { + var sumGrade int + if err := rr.Scan(&sumGrade); err != nil { + return err + } + *valSet = append(*valSet, []any{sumGrade}) + return nil + } + + require.Equal(t, []any{ + []any{600}, + }, getRowValues(t, r, scanFunc)) + }, + }, + { + sql: "SELECT AVG(`grade`) AS `avg_grade` FROM `t1`", + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := "SELECT AVG(`grade`) AS `avg_grade`, SUM(`grade`), COUNT(`grade`) FROM `t1`" + cols := []string{"`avg_grade`", "SUM(`grade`)", "COUNT(`grade`)"} + ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 400, 2)) + ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(40, 120, 3)) + ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(80, 80, 1)) + return getResultSet(t, targetSQL, ms.mockDB01, ms.mockDB02, ms.mockDB03), cols + + }, + columns: []aggregator.Aggregator{ + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "AVG", Alias: "`avg_grade`"}, + merger.ColumnInfo{Index: 1, Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: "`grade`", AggregateFunc: "COUNT"}, + ), + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + + expectedColumnNames := []string{"`avg_grade`"} + columns, err := r.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, columns) + + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, expectedColumnNames, names) + + scanFunc := func(rr rows.Rows, valSet *[]any) error { + var avgGrade float64 + if err := rr.Scan(&avgGrade); err != nil { + return err + } + *valSet = append(*valSet, []any{avgGrade}) + return nil + } + + cnt := 6 + sum := 600 + require.Equal(t, []any{ + []any{float64(sum) / float64(cnt)}, + }, getRowValues(t, r, scanFunc)) + }, + }, + { + sql: "SELECT AVG(`grade`) AS `avg_grade`, AVG(`age`), AVG(`height`) FROM `t1`", + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := "SELECT AVG(`grade`) AS `avg_grade`, SUM(`grade`), COUNT(`grade`), AVG(`age`), SUM(`age`), COUNT(`age`) , AVG(`height`), SUM(`height`), COUNT(`height`) FROM `t1`" + cols := []string{"`avg_grade`", "SUM(`grade`)", "COUNT(`grade`)", "AVG(`age`)", "SUM(`age`)", "COUNT(`age`)", "AVG(`height`)", "SUM(`height`)", "COUNT(`height`)"} + ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 400, 2, 18, 36, 2, 160, 320, 2)) + ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(40, 120, 3, 18, 54, 3, 170, 510, 3)) + ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(80, 80, 1, 18, 18, 1, 180, 180, 1)) + return getResultSet(t, targetSQL, ms.mockDB01, ms.mockDB02, ms.mockDB03), cols + }, + columns: []aggregator.Aggregator{ + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "AVG", Alias: "`avg_grade`"}, + merger.ColumnInfo{Index: 1, Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: "`grade`", AggregateFunc: "COUNT"}, + ), + aggregator.NewAVG( + merger.ColumnInfo{Index: 3, Name: "`age`", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 4, Name: "`age`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 5, Name: "`age`", AggregateFunc: "COUNT"}, + ), + aggregator.NewAVG( + merger.ColumnInfo{Index: 6, Name: "`height`", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 7, Name: "`height`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 8, Name: "`height`", AggregateFunc: "COUNT"}, + ), + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + + expectedColumnNames := []string{"`avg_grade`", "AVG(`age`)", "AVG(`height`)"} + columns, err := r.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, columns) + + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, expectedColumnNames, names) + + scanFunc := func(rr rows.Rows, valSet *[]any) error { + var avgGrade, avgAge, avgHeight float64 + if err := rr.Scan(&avgGrade, &avgAge, &avgHeight); err != nil { + return err + } + *valSet = append(*valSet, []any{avgGrade, avgAge, avgHeight}) + return nil + } + + require.Equal(t, []any{ + []any{float64(100), float64(18), float64(1010) / float64(6)}, + }, getRowValues(t, r, scanFunc)) + }, + }, + { + sql: "SELECT AVG(`grade`),AVG(`age`), AVG(`height`) FROM `t1`", + before: func(t *testing.T, sql string) ([]rows.Rows, []string) { + t.Helper() + targetSQL := "SELECT AVG(`grade`), SUM(`grade`), COUNT(`grade`), AVG(`age`), SUM(`age`), COUNT(`age`) , AVG(`height`), SUM(`height`), COUNT(`height`) FROM `t1`" + cols := []string{"AVG(`grade`)", "SUM(`grade`)", "COUNT(`grade`)", "AVG(`age`)", "SUM(`age`)", "COUNT(`age`)", "AVG(`height`)", "SUM(`height`)", "COUNT(`height`)"} + ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(200, 400, 2, 18, 36, 2, 160, 320, 2)) + ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(40, 120, 3, 18, 54, 3, 170, 510, 3)) + ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(80, 80, 1, 18, 18, 1, 180, 180, 1)) + return getResultSet(t, targetSQL, ms.mockDB01, ms.mockDB02, ms.mockDB03), cols + }, + columns: []aggregator.Aggregator{ + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 1, Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: "`grade`", AggregateFunc: "COUNT"}, + ), + aggregator.NewAVG( + merger.ColumnInfo{Index: 3, Name: "`age`", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 4, Name: "`age`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 5, Name: "`age`", AggregateFunc: "COUNT"}, + ), + aggregator.NewAVG( + merger.ColumnInfo{Index: 6, Name: "`height`", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 7, Name: "`height`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 8, Name: "`height`", AggregateFunc: "COUNT"}, + ), + }, + requireErrFunc: require.NoError, + after: func(t *testing.T, r rows.Rows, _ []string) { + t.Helper() + + expectedColumnNames := []string{"AVG(`grade`)", "AVG(`age`)", "AVG(`height`)"} + columns, err := r.Columns() + require.NoError(t, err) + require.Equal(t, expectedColumnNames, columns) + + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, expectedColumnNames, names) + + require.NoError(t, r.Close()) + + _, err = r.ColumnTypes() + require.ErrorIs(t, err, errs.ErrMergerRowsClosed) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.sql, func(t *testing.T) { + res, cols := tt.before(t, tt.sql) + m := NewMerger(tt.columns...) + r, err := m.Merge(context.Background(), res) + require.NoError(t, err) + tt.after(t, r, cols) + }) + } +} + type mockAggregate struct { cols [][]any } @@ -765,10 +1027,32 @@ func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { return nil, aggregatorErr } -func (*mockAggregate) ColumnName() string { +func (*mockAggregate) ColumnInfo() merger.ColumnInfo { + return merger.ColumnInfo{Name: "mockAggregateColumn"} +} + +func (*mockAggregate) Name() string { return "mockAggregate" } func TestRows_NextResultSet(t *testing.T) { assert.False(t, (&Rows{}).NextResultSet()) } + +func getRowValues(t *testing.T, r rows.Rows, scanFunc func(r rows.Rows, valSet *[]any) error) []any { + var res []any + for r.Next() { + require.NoError(t, scanFunc(r, &res)) + } + return res +} + +func getResultSet(t *testing.T, sql string, dbs ...*sql.DB) []rows.Rows { + resultSet := make([]rows.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.Query(sql) + require.NoError(t, err) + resultSet = append(resultSet, row) + } + return resultSet +} diff --git a/internal/merger/internal/batchmerger/merger.go b/internal/merger/internal/batchmerger/merger.go index c017cb3..9d46e41 100644 --- a/internal/merger/internal/batchmerger/merger.go +++ b/internal/merger/internal/batchmerger/merger.go @@ -17,6 +17,7 @@ package batchmerger import ( "context" "database/sql" + "fmt" "sync" "github.com/ecodeclub/eorm/internal/rows" @@ -88,6 +89,11 @@ type Rows struct { } func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) + } return r.rowsList[0].ColumnTypes() } diff --git a/internal/merger/internal/batchmerger/merger_test.go b/internal/merger/internal/batchmerger/merger_test.go index 74fb375..0a93ee0 100644 --- a/internal/merger/internal/batchmerger/merger_test.go +++ b/internal/merger/internal/batchmerger/merger_test.go @@ -66,19 +66,19 @@ func (ms *MergerSuite) TearDownTest() { func (ms *MergerSuite) initMock(t *testing.T) { var err error - ms.mockDB01, ms.mock01, err = sqlmock.New() + ms.mockDB01, ms.mock01, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB02, ms.mock02, err = sqlmock.New() + ms.mockDB02, ms.mock02, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB03, ms.mock03, err = sqlmock.New() + ms.mockDB03, ms.mock03, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB04, ms.mock04, err = sqlmock.New() + ms.mockDB04, ms.mock04, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } @@ -118,8 +118,8 @@ func (ms *MergerSuite) TestMerger_Merge() { }, rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -138,8 +138,8 @@ func (ms *MergerSuite) TestMerger_Merge() { }, rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -155,8 +155,8 @@ func (ms *MergerSuite) TestMerger_Merge() { name: "正常的案例", rowsList: func() []rows.Rows { query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -181,7 +181,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -217,9 +217,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1").AddRow("2")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1").AddRow("2")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -235,10 +235,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -254,10 +254,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -273,9 +273,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB02, ms.mockDB01, ms.mockDB03} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -291,10 +291,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB02, ms.mockDB01, ms.mockDB04, ms.mockDB03} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -310,9 +310,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) dbs := []*sql.DB{ms.mockDB02, ms.mockDB03, ms.mockDB01} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -328,10 +328,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB02, ms.mockDB03, ms.mockDB01, ms.mockDB04} res := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -345,15 +345,16 @@ func (ms *MergerSuite) TestRows_NextAndScan() { { name: "sqlRows列表中的元素均返回空行", sqlRows: func() []rows.Rows { - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id"})) + query := "SELECT * FROM `t1`;" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id"})) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id"})) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id"})) res := make([]rows.Rows, 0, 3) - row01, _ := ms.mockDB01.QueryContext(context.Background(), "SELECT * FROM `t1`;") + row01, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, row01) - row02, _ := ms.mockDB02.QueryContext(context.Background(), "SELECT * FROM `t1`;") + row02, _ := ms.mockDB02.QueryContext(context.Background(), query) res = append(res, row02) - row03, _ := ms.mockDB03.QueryContext(context.Background(), "SELECT * FROM `t1`;") + row03, _ := ms.mockDB03.QueryContext(context.Background(), query) res = append(res, row03) return res }, @@ -396,10 +397,10 @@ func (ms *MergerSuite) TestRows_NextAndErr() { rowsList: func() []rows.Rows { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -428,7 +429,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -442,7 +443,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).RowError(0, nextMockErr)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1).RowError(0, nextMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -461,9 +462,9 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { func (ms *MergerSuite) TestRows_Close() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02"))) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02"))) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) merger := NewMerger() dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) @@ -509,9 +510,9 @@ func (ms *MergerSuite) TestRows_Close() { func (ms *MergerSuite) TestRows_Columns() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger := NewMerger() dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) @@ -540,6 +541,48 @@ func (ms *MergerSuite) TestRows_Columns() { }) } +func (ms *MergerSuite) TestRows_ColumnTypes() { + t := ms.T() + + cols := []string{"id"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + r, err := NewMerger().Merge(context.Background(), getResultSet(t, query, ms.mockDB01, ms.mockDB02, ms.mockDB03)) + require.NoError(t, err) + + t.Run("rows未关闭", func(t *testing.T) { + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, cols, names) + }) + + t.Run("rows已关闭", func(t *testing.T) { + + require.NoError(t, r.Close()) + + _, err = r.ColumnTypes() + require.ErrorIs(t, err, errs.ErrMergerRowsClosed) + }) + +} + +func getResultSet(t *testing.T, sql string, dbs ...*sql.DB) []rows.Rows { + resultSet := make([]rows.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.Query(sql) + require.NoError(t, err) + resultSet = append(resultSet, row) + } + return resultSet +} + func TestMerger(t *testing.T) { suite.Run(t, &MergerSuite{}) } diff --git a/internal/merger/internal/groupbymerger/aggregator_merger.go b/internal/merger/internal/groupbymerger/aggregator_merger.go index 7f20e08..b469ba3 100644 --- a/internal/merger/internal/groupbymerger/aggregator_merger.go +++ b/internal/merger/internal/groupbymerger/aggregator_merger.go @@ -17,6 +17,7 @@ package groupbymerger import ( "context" "database/sql" + "fmt" "reflect" "sync" _ "unsafe" @@ -37,21 +38,27 @@ import ( type AggregatorMerger struct { aggregators []aggregator.Aggregator + avgIndexes []int groupColumns []merger.ColumnInfo columnsName []string } func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []merger.ColumnInfo) *AggregatorMerger { cols := make([]string, 0, len(aggregators)+len(groupColumns)) - for _, groubyCol := range groupColumns { - cols = append(cols, groubyCol.SelectName()) + idx := make([]int, 0, len(aggregators)) + for _, c := range groupColumns { + cols = append(cols, c.SelectName()) } for _, agg := range aggregators { - cols = append(cols, agg.ColumnName()) + if agg.Name() == "AVG" { + idx = append(idx, agg.ColumnInfo().Index) + } + cols = append(cols, agg.ColumnInfo().SelectName()) } return &AggregatorMerger{ aggregators: aggregators, + avgIndexes: idx, groupColumns: groupColumns, columnsName: cols, } @@ -69,7 +76,7 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []rows.Rows) (rows if slice.Contains[rows.Rows](results, nil) { return nil, errs.ErrMergerRowsIsNull } - // TODO: 无奈之举, 下方getCols会ScanAll然后出问题, 需要写测试覆盖 + // 下方getCols会ScanAll然后将results中的sql.Rows全部关闭,所以需要在关闭前保留列类型信息 columnTypes, err := results[0].ColumnTypes() if err != nil { return nil, err @@ -83,6 +90,7 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []rows.Rows) (rows rowsList: results, columnTypes: columnTypes, aggregators: a.aggregators, + avgIndexes: a.avgIndexes, groupColumns: a.groupColumns, mu: &sync.RWMutex{}, dataMap: dataMap, @@ -135,6 +143,7 @@ type AggregatorRows struct { rowsList []rows.Rows columnTypes []*sql.ColumnType aggregators []aggregator.Aggregator + avgIndexes []int groupColumns []merger.ColumnInfo dataMap *mapx.TreeMap[Key, [][]any] cur int @@ -147,9 +156,24 @@ type AggregatorRows struct { } func (a *AggregatorRows) ColumnTypes() ([]*sql.ColumnType, error) { - // TODO: 这里是为了让测试通过的临时处理方法,貌似merger会先将 - // 正常应该先判断closed是否为true, 然后再a.rowsList[0].ColumnTypes() - return a.columnTypes, nil + a.mu.Lock() + defer a.mu.Unlock() + if a.closed { + return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) + } + + if len(a.avgIndexes) == 0 { + return a.columnTypes, nil + } + + v := make([]*sql.ColumnType, 0, len(a.columnTypes)) + var prev int + for i := 0; i < len(a.avgIndexes); i++ { + idx := a.avgIndexes[i] + v = append(v, a.columnTypes[prev:idx+1]...) + prev = idx + 3 + } + return v, nil } func (*AggregatorRows) NextResultSet() bool { diff --git a/internal/merger/internal/groupbymerger/aggregator_merger_test.go b/internal/merger/internal/groupbymerger/aggregator_merger_test.go index 5d68bef..28163e1 100644 --- a/internal/merger/internal/groupbymerger/aggregator_merger_test.go +++ b/internal/merger/internal/groupbymerger/aggregator_merger_test.go @@ -64,19 +64,19 @@ func (ms *MergerSuite) TearDownTest() { func (ms *MergerSuite) initMock(t *testing.T) { var err error - ms.mockDB01, ms.mock01, err = sqlmock.New() + ms.mockDB01, ms.mock01, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB02, ms.mock02, err = sqlmock.New() + ms.mockDB02, ms.mock02, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB03, ms.mock03, err = sqlmock.New() + ms.mockDB03, ms.mock03, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB04, ms.mock04, err = sqlmock.New() + ms.mockDB04, ms.mock04, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } @@ -98,18 +98,18 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "正常案例", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(2, "id")), + aggregator.NewCount(merger.ColumnInfo{Index: 2, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "county"), - merger.NewColumnInfo(1, "gender"), + {Index: 0, Name: "county"}, + {Index: 1, Name: "gender"}, }, rowsList: func() []rows.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" cols := []string{"county", "gender", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -129,17 +129,17 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "超时", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + {Index: 0, Name: "user_name"}, }, rowsList: func() []rows.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -158,10 +158,13 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList为空", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + { + Index: 0, + Name: "user_name", + }, }, rowsList: func() []rows.Rows { return []rows.Rows{} @@ -175,10 +178,10 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList中有nil", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + {Index: 0, Name: "user_name"}, }, rowsList: func() []rows.Rows { return []rows.Rows{nil} @@ -192,17 +195,17 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList中有sql.Rows返回错误", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + {Index: 0, Name: "user_name"}, }, rowsList: func() []rows.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -221,9 +224,9 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { } for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { - merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + m := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) ctx, cancel := tc.ctx() - groupByRows, err := merger.Merge(ctx, tc.rowsList) + groupByRows, err := m.Merge(ctx, tc.rowsList) cancel() assert.Equal(t, tc.wantErr, err) if err != nil { @@ -247,17 +250,17 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "同一组数据在不同的sql.Rows中", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + {Index: 0, Name: "user_name"}, }, rowsList: func() []rows.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -281,17 +284,17 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "同一组数据在同一个sql.Rows中", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), + aggregator.NewCount(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "COUNT"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "user_name"), + {Index: 0, Name: "user_name"}, }, rowsList: func() []rows.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" cols := []string{"user_name", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -319,18 +322,18 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "多个分组列", aggregators: []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewSum(merger.ColumnInfo{Index: 2, Name: "id", AggregateFunc: "SUM"}), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "county"), - merger.NewColumnInfo(1, "gender"), + {Index: 0, Name: "county"}, + {Index: 1, Name: "gender"}, }, rowsList: func() []rows.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" cols := []string{"county", "gender", "SUM(id)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -372,20 +375,24 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "多个聚合函数", aggregators: []aggregator.Aggregator{ - aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), - aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"), + aggregator.NewSum(merger.ColumnInfo{Index: 2, Name: "id", AggregateFunc: "SUM"}), + aggregator.NewAVG( + merger.ColumnInfo{Index: 3, Name: "age", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 4, Name: "age", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 5, Name: "age", AggregateFunc: "COUNT"}, + ), }, GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "county"), - merger.NewColumnInfo(1, "gender"), + {Index: 0, Name: "county"}, + {Index: 1, Name: "gender"}, }, rowsList: func() []rows.Rows { - query := "SELECT `county`,`gender`,SUM(`id`),SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`" - cols := []string{"county", "gender", "SUM(id)", "SUM(age)", "COUNT(age)"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 100, 2).AddRow("hangzhou", "female", 20, 120, 3).AddRow("shanghai", "female", 30, 90, 3)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 120, 5).AddRow("shanghai", "female", 50, 120, 4).AddRow("hangzhou", "female", 60, 150, 3)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 100, 5).AddRow("shanghai", "female", 80, 150, 5)) + query := "SELECT `county`,`gender`,SUM(`id`), AVG(`age`), SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`" + cols := []string{"county", "gender", "SUM(id)", "AVG(`age`)", "SUM(age)", "COUNT(age)"} + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 50, 100, 2).AddRow("hangzhou", "female", 20, 40, 120, 3).AddRow("shanghai", "female", 30, 30, 90, 3)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 24, 120, 5).AddRow("shanghai", "female", 50, 30, 120, 4).AddRow("hangzhou", "female", 60, 50, 150, 3)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 20, 100, 5).AddRow("shanghai", "female", 80, 30, 150, 5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -431,8 +438,8 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { } for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { - merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) - groupByRows, err := merger.Merge(context.Background(), tc.rowsList) + m := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + groupByRows, err := m.Merge(context.Background(), tc.rowsList) require.NoError(t, err) idx := 0 @@ -457,33 +464,35 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { cols := []string{"userid", "SUM(id)"} - query := "SELECT userid,SUM(id) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) - r, err := ms.mockDB01.QueryContext(context.Background(), query) + query := "SELECT `userid`,SUM(`id`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) + sqlRows, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []rows.Rows{r} - merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) - rows, err := merger.Merge(context.Background(), rowsList) + rowsList := []rows.Rows{sqlRows} + m := NewAggregatorMerger([]aggregator.Aggregator{ + aggregator.NewSum(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "SUM"})}, + []merger.ColumnInfo{{Index: 0, Name: "userid"}}) + r, err := m.Merge(context.Background(), rowsList) require.NoError(t, err) userid := 0 sumId := 0 - err = rows.Scan(&userid, &sumId) + err = r.Scan(&userid, &sumId) assert.Equal(t, errs.ErrMergerScanNotNext, err) }) ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { cols := []string{"userid", "SUM(id)"} - query := "SELECT userid,SUM(id) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) - r, err := ms.mockDB01.QueryContext(context.Background(), query) + query := "SELECT `userid`,SUM(`id`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) + sqlRows, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) - rowsList := []rows.Rows{r} - merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) - rows, err := merger.Merge(context.Background(), rowsList) + rowsList := []rows.Rows{sqlRows} + m := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{{Index: 0, Name: "userid"}}) + r, err := m.Merge(context.Background(), rowsList) require.NoError(t, err) userid := 0 sumId := 0 - rows.Next() - err = rows.Scan(&userid, &sumId) + r.Next() + err = r.Scan(&userid, &sumId) assert.Equal(t, aggregatorErr, err) }) @@ -501,11 +510,11 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { name: "有一个aggregator返回error", rowsList: func() []rows.Rows { cols := []string{"username", "COUNT(id)"} - query := "SELECT username,COUNT(`id`) FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("daming", 2)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5)) + query := "SELECT `username`,COUNT(`id`) FROM `t1`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("david", 2)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -521,44 +530,48 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { } }(), GroupByColumns: []merger.ColumnInfo{ - merger.NewColumnInfo(0, "username"), + {Index: 0, Name: "username"}, }, wantErr: aggregatorErr, }, } for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { - merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) - rows, err := merger.Merge(context.Background(), tc.rowsList()) + m := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + r, err := m.Merge(context.Background(), tc.rowsList()) require.NoError(t, err) - for rows.Next() { + for r.Next() { } count := int64(0) name := "" - err = rows.Scan(&name, &count) + err = r.Scan(&name, &count) assert.Equal(t, tc.wantErr, err) - assert.Equal(t, tc.wantErr, rows.Err()) + assert.Equal(t, tc.wantErr, r.Err()) }) } } func (ms *MergerSuite) TestAggregatorRows_Columns() { - cols := []string{"userid", "SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} - query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`),`userid` FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10, "zwl")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11, "dm")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12, "xm")) + cols := []string{"user_id", "AVG(`grade`)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "MIN(grade)", "MAX(grade)", "COUNT(grade)"} + query := "SELECT `user_id`, AVG(`grade`), SUM(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`grade`),MAX(`grade`),COUNT(`grade`) FROM `t1` GROUP BY`user_id`" + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 50, 150, 3, 150, 30, 100, 3)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 100, 200, 2, 200, 50, 150, 2)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 150, 1, 1, 150, 150, 150, 1)) aggregators := []aggregator.Aggregator{ - aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), - aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), - aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), - aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), + aggregator.NewAVG( + merger.ColumnInfo{Index: 1, Name: "grade", AggregateFunc: "AVG"}, + merger.ColumnInfo{Index: 2, Name: "grade", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 3, Name: "grade", AggregateFunc: "COUNT"}, + ), + aggregator.NewSum(merger.ColumnInfo{Index: 4, Name: "grade", AggregateFunc: "SUM"}), + aggregator.NewMin(merger.ColumnInfo{Index: 5, Name: "grade", AggregateFunc: "MIN"}), + aggregator.NewMax(merger.ColumnInfo{Index: 6, Name: "grade", AggregateFunc: "MAX"}), + aggregator.NewCount(merger.ColumnInfo{Index: 7, Name: "grade", AggregateFunc: "COUNT"}), } - groupbyColumns := []merger.ColumnInfo{ - merger.NewColumnInfo(6, "userid"), + groupByColumns := []merger.ColumnInfo{ + {Index: 0, Name: "user_id"}, } - merger := NewAggregatorMerger(aggregators, groupbyColumns) + m := NewAggregatorMerger(aggregators, groupByColumns) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -567,25 +580,81 @@ func (ms *MergerSuite) TestAggregatorRows_Columns() { rowsList = append(rowsList, row) } - rows, err := merger.Merge(context.Background(), rowsList) + r, err := m.Merge(context.Background(), rowsList) require.NoError(ms.T(), err) - wantCols := []string{"userid", "AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} + wantCols := []string{"user_id", "AVG(grade)", "SUM(grade)", "MIN(grade)", "MAX(grade)", "COUNT(grade)"} ms.T().Run("Next没有迭代完", func(t *testing.T) { - for rows.Next() { - columns, err := rows.Columns() + for r.Next() { + columns, err := r.Columns() require.NoError(t, err) assert.Equal(t, wantCols, columns) } - require.NoError(t, rows.Err()) + require.NoError(t, r.Err()) }) ms.T().Run("Next迭代完", func(t *testing.T) { - require.False(t, rows.Next()) - require.NoError(t, rows.Err()) - _, err := rows.Columns() + require.False(t, r.Next()) + require.NoError(t, r.Err()) + _, err := r.Columns() assert.Equal(t, errs.ErrMergerRowsClosed, err) }) } +func (ms *MergerSuite) TestRows_ColumnTypes() { + t := ms.T() + + query := "SELECT AVG(`grade`) AS `avg_grade` FROM `t1`" + cols := []string{"`avg_grade`", "SUM(`grade`)", "COUNT(`grade`)"} + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100, 200, 2)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(90, 270, 3)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(110, 440, 4)) + aggregators := []aggregator.Aggregator{ + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "AVG", Alias: "`avg_grade`"}, + merger.ColumnInfo{Index: 1, Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: "`grade`", AggregateFunc: "COUNT"}, + ), + } + + groupByColumns := []merger.ColumnInfo{ + { + Index: 0, + Name: "`grade`", + AggregateFunc: "AVG", + Alias: "`avg_grade`", + }, + } + r, err := NewAggregatorMerger(aggregators, groupByColumns).Merge(context.Background(), getResultSet(t, query, ms.mockDB01, ms.mockDB02, ms.mockDB03)) + require.NoError(t, err) + + t.Run("rows未关闭", func(t *testing.T) { + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, []string{"`avg_grade`"}, names) + }) + + t.Run("rows已关闭", func(t *testing.T) { + require.NoError(t, r.Close()) + + _, err = r.ColumnTypes() + require.ErrorIs(t, err, errs.ErrMergerRowsClosed) + }) +} + +func getResultSet(t *testing.T, sql string, dbs ...*sql.DB) []rows.Rows { + resultSet := make([]rows.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.Query(sql) + require.NoError(t, err) + resultSet = append(resultSet, row) + } + return resultSet +} + type mockAggregate struct { cols [][]any } @@ -595,7 +664,11 @@ func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { return nil, aggregatorErr } -func (*mockAggregate) ColumnName() string { +func (*mockAggregate) ColumnInfo() merger.ColumnInfo { + return merger.ColumnInfo{Name: "mockAggregate"} +} + +func (*mockAggregate) Name() string { return "mockAggregate" } diff --git a/internal/merger/internal/pagedmerger/merger.go b/internal/merger/internal/pagedmerger/merger.go index 4f6c338..2004df4 100644 --- a/internal/merger/internal/pagedmerger/merger.go +++ b/internal/merger/internal/pagedmerger/merger.go @@ -17,6 +17,7 @@ package pagedmerger import ( "context" "database/sql" + "fmt" "sync" "github.com/ecodeclub/eorm/internal/rows" @@ -144,6 +145,11 @@ func (r *Rows) Close() error { } func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) + } return r.rows.ColumnTypes() } func (r *Rows) Columns() ([]string, error) { diff --git a/internal/merger/internal/pagedmerger/merger_test.go b/internal/merger/internal/pagedmerger/merger_test.go index 27600c0..e5c48c9 100644 --- a/internal/merger/internal/pagedmerger/merger_test.go +++ b/internal/merger/internal/pagedmerger/merger_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/merger/internal/aggregatemerger/aggregator" + "github.com/ecodeclub/eorm/internal/merger/internal/groupbymerger" "github.com/ecodeclub/eorm/internal/rows" "github.com/DATA-DOG/go-sqlmock" @@ -68,19 +70,19 @@ func (ms *MergerSuite) TearDownTest() { func (ms *MergerSuite) initMock(t *testing.T) { var err error - ms.mockDB01, ms.mock01, err = sqlmock.New() + ms.mockDB01, ms.mock01, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB02, ms.mock02, err = sqlmock.New() + ms.mockDB02, ms.mock02, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB03, ms.mock03, err = sqlmock.New() + ms.mockDB03, ms.mock03, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB04, ms.mock04, err = sqlmock.New() + ms.mockDB04, ms.mock04, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } @@ -163,9 +165,9 @@ func (ms *MergerSuite) TestMerger_Merge() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn").RowError(1, offsetMockErr)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn").RowError(1, offsetMockErr)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -190,9 +192,9 @@ func (ms *MergerSuite) TestMerger_Merge() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -216,9 +218,9 @@ func (ms *MergerSuite) TestMerger_Merge() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -273,9 +275,9 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -323,9 +325,9 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -358,9 +360,9 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -382,9 +384,9 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -469,9 +471,9 @@ func (ms *MergerSuite) TestRows_NextAndErr() { GetRowsList: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn").RowError(1, limitMockErr)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn").RowError(1, limitMockErr)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -506,7 +508,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -523,7 +525,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(2).RowError(1, limitMockErr)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(2).RowError(1, limitMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -544,9 +546,9 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { func (ms *MergerSuite) TestRows_Close() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("5").CloseError(newCloseMockErr("db02"))) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("5").CloseError(newCloseMockErr("db02"))) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) merger, err := sortmerger.NewMerger(false, sortmerger.NewSortColumn("id", sortmerger.ASC)) require.NoError(ms.T(), err) limitMerger, err := NewMerger(merger, 1, 6) @@ -595,9 +597,9 @@ func (ms *MergerSuite) TestRows_Close() { func (ms *MergerSuite) TestRows_Columns() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger, err := sortmerger.NewMerger(false, sortmerger.NewSortColumn("id", sortmerger.ASC)) require.NoError(ms.T(), err) limitMerger, err := NewMerger(merger, 0, 10) @@ -616,6 +618,65 @@ func (ms *MergerSuite) TestRows_Columns() { assert.Equal(ms.T(), cols, columns) } +func (ms *MergerSuite) TestRows_ColumnTypes() { + t := ms.T() + + query := "SELECT AVG(`grade`) AS `avg_grade` FROM `t1`" + cols := []string{"`avg_grade`", "SUM(`grade`)", "COUNT(`grade`)"} + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100, 200, 2)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(90, 270, 3)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(110, 440, 4)) + aggregators := []aggregator.Aggregator{ + aggregator.NewAVG( + merger.ColumnInfo{Index: 0, Name: "`grade`", AggregateFunc: "AVG", Alias: "`avg_grade`"}, + merger.ColumnInfo{Index: 1, Name: "`grade`", AggregateFunc: "SUM"}, + merger.ColumnInfo{Index: 2, Name: "`grade`", AggregateFunc: "COUNT"}, + ), + } + + groupByColumns := []merger.ColumnInfo{ + { + Index: 0, + Name: "`grade`", + AggregateFunc: "AVG", + Alias: "`avg_grade`", + }, + } + m, err := NewMerger(groupbymerger.NewAggregatorMerger(aggregators, groupByColumns), 0, 3) + require.NoError(t, err) + + r, err := m.Merge(context.Background(), getResultSet(t, query, ms.mockDB01, ms.mockDB02, ms.mockDB03)) + require.NoError(t, err) + + t.Run("rows未关闭", func(t *testing.T) { + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, []string{"`avg_grade`"}, names) + }) + + t.Run("rows已关闭", func(t *testing.T) { + require.NoError(t, r.Close()) + + _, err = r.ColumnTypes() + require.ErrorIs(t, err, errs.ErrMergerRowsClosed) + }) +} + +func getResultSet(t *testing.T, sql string, dbs ...*sql.DB) []rows.Rows { + resultSet := make([]rows.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.Query(sql) + require.NoError(t, err) + resultSet = append(resultSet, row) + } + return resultSet +} + func TestMerger(t *testing.T) { suite.Run(t, &MergerSuite{}) } diff --git a/internal/merger/internal/sortmerger/merger.go b/internal/merger/internal/sortmerger/merger.go index e48f80d..0d07e6a 100644 --- a/internal/merger/internal/sortmerger/merger.go +++ b/internal/merger/internal/sortmerger/merger.go @@ -84,7 +84,8 @@ type Merger struct { preScanAll bool } -// NewMerger preScanAll 表示是否预先扫描出结果集中的所有到内存 +// NewMerger 根据preScanAll及排序列的列信息来创建一个排序Merger +// 其中preScanAll为true 表示需要预先扫描出结果集中的所有数据到内存才能得到正确结果,为false每次只需要扫描一行即可得到正确结果 func NewMerger(preScanAll bool, sortCols ...SortColumn) (*Merger, error) { scs, err := newSortColumns(sortCols...) if err != nil { @@ -147,6 +148,12 @@ func (m *Merger) initRows(results []rows.Rows) (*Rows, error) { } rs.hp = h var err error + // 下方preScanAll会把rowsList中所有数据扫描到内存然后关闭其中所有rows.Rows,所以要提前缓存住列类型信息 + columnTypes, err := rs.rowsList[0].ColumnTypes() + if err != nil { + return nil, err + } + rs.columnTypes = columnTypes for i := 0; i < len(rs.rowsList); i++ { if m.preScanAll { err = rs.preScanAll(rs.rowsList[i], i) @@ -235,6 +242,7 @@ func newNode(row rows.Rows, sortCols sortColumns, index int) (*node, error) { type Rows struct { rowsList []rows.Rows + columnTypes []*sql.ColumnType sortColumns sortColumns hp *Heap cur *node @@ -246,7 +254,12 @@ type Rows struct { } func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { - return r.rowsList[0].ColumnTypes() + r.mu.Lock() + defer r.mu.Unlock() + if r.closed { + return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) + } + return r.columnTypes, nil } func (*Rows) NextResultSet() bool { @@ -265,7 +278,6 @@ func (r *Rows) Next() bool { return false } r.cur = heap.Pop(r.hp).(*node) - log.Printf("heap node = %#v\n", r.cur) if !r.isPreScanAll { row := r.rowsList[r.cur.index] err := r.preScanOne(row, r.cur.index) @@ -282,10 +294,6 @@ func (r *Rows) Next() bool { } func (r *Rows) preScanAll(row rows.Rows, index int) error { - // TODO Rows抽象之前的假设 rowList中每个sql.Rows中的数据都是已经排序过的 - // 所以只需要读取每个sql.Rows的第一行数据,进行比较就可以得到正确答案 - // 但当使用在pipline中时,就可能需要读取全部sql.Rows中的数据进行排序才能得到正确答案 - // 当然可以进行针对性的优化——两种读模式,一次读一行,一次读全部 for row.Next() { n, err := newNode(row, r.sortColumns, index) if err != nil { diff --git a/internal/merger/internal/sortmerger/merger_test.go b/internal/merger/internal/sortmerger/merger_test.go index 3606de6..ee3dc03 100644 --- a/internal/merger/internal/sortmerger/merger_test.go +++ b/internal/merger/internal/sortmerger/merger_test.go @@ -67,19 +67,19 @@ func (ms *MergerSuite) TearDownTest() { func (ms *MergerSuite) initMock(t *testing.T) { var err error - ms.mockDB01, ms.mock01, err = sqlmock.New() + ms.mockDB01, ms.mock01, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB02, ms.mock02, err = sqlmock.New() + ms.mockDB02, ms.mock02, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB03, ms.mock03, err = sqlmock.New() + ms.mockDB03, ms.mock03, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } - ms.mockDB04, ms.mock04, err = sqlmock.New() + ms.mockDB04, ms.mock04, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatal(err) } @@ -147,8 +147,8 @@ func (ms *MergerSuite) TestMerger_Merge() { }, sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -170,8 +170,8 @@ func (ms *MergerSuite) TestMerger_Merge() { }, sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -197,7 +197,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -238,7 +238,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -257,7 +257,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -276,7 +276,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -295,7 +295,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -314,7 +314,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -332,9 +332,9 @@ func (ms *MergerSuite) TestMerger_Merge() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).AddRow(3)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4).AddRow(6)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2).AddRow(3)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4).AddRow(6)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -383,7 +383,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id", "age"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 18)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 18)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -403,9 +403,9 @@ func (ms *MergerSuite) TestMerger_Merge() { sqlRows: func() []rows.Rows { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "curry", "cn").AddRow(1, "zwl", "sh")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "alex", "cn").AddRow(3, "curry", "jp")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "bob", "tw").AddRow(6, "david", "hk")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "curry", "cn").AddRow(1, "zwl", "sh")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "alex", "cn").AddRow(3, "curry", "jp")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "bob", "tw").AddRow(6, "david", "hk")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -455,7 +455,7 @@ func (ms *MergerSuite) TestMerger_Merge() { query := "SELECT * FROM `t1`;" cols := []string{"id", "name", "address"} res := make([]rows.Rows, 0, 1) - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh").RowError(0, nextMockErr)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh").RowError(0, nextMockErr)) rows, _ := ms.mockDB01.QueryContext(context.Background(), query) res = append(res, rows) return res @@ -505,9 +505,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -558,9 +558,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(6, "x", "cn").AddRow(1, "x", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(8, "alex", "cn").AddRow(4, "bruce", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(9, "a", "cn").AddRow(5, "abex", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(6, "x", "cn").AddRow(1, "x", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(8, "alex", "cn").AddRow(4, "bruce", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(9, "a", "cn").AddRow(5, "abex", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -617,10 +617,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "c", "cn").AddRow(2, "bruce", "cn").AddRow(2, "zwl", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "alex", "cn").AddRow(3, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "c", "cn").AddRow(3, "b", "cn").AddRow(5, "c", "cn").AddRow(7, "c", "cn")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "c", "cn").AddRow(2, "bruce", "cn").AddRow(2, "zwl", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "alex", "cn").AddRow(3, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "c", "cn").AddRow(3, "b", "cn").AddRow(5, "c", "cn").AddRow(7, "c", "cn")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -687,9 +687,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(5, "bruce", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(5, "bruce", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -740,9 +740,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(8, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(8, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -799,10 +799,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -853,9 +853,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -906,9 +906,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -956,10 +956,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(3, "alex", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(3, "alex", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -1011,10 +1011,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols)) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -1035,9 +1035,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { sqlRows: func() []rows.Rows { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "hz").AddRow(3, "b", "hz").AddRow(2, "b", "cs")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "a", "cs").AddRow(1, "a", "cs").AddRow(3, "e", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "d", "hm").AddRow(5, "k", "xx").AddRow(4, "k", "xz")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "hz").AddRow(3, "b", "hz").AddRow(2, "b", "cs")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "a", "cs").AddRow(1, "a", "cs").AddRow(3, "e", "cn")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "d", "hm").AddRow(5, "k", "xx").AddRow(4, "k", "xz")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -1124,9 +1124,9 @@ func (ms *MergerSuite) TestRows_NextAndScan() { func (ms *MergerSuite) TestRows_Columns() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger, err := NewMerger(false, NewSortColumn("id", DESC)) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} @@ -1160,9 +1160,9 @@ func (ms *MergerSuite) TestRows_Columns() { func (ms *MergerSuite) TestRows_Close() { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02"))) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02"))) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) merger, err := NewMerger(false, NewSortColumn("id", DESC)) require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} @@ -1219,10 +1219,10 @@ func (ms *MergerSuite) TestRows_NextAndErr() { rowsList: func() []rows.Rows { cols := []string{"id"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) - ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) + ms.mock04.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} rowsList := make([]rows.Rows, 0, len(dbs)) for _, db := range dbs { @@ -1256,7 +1256,7 @@ func (ms *MergerSuite) TestRows_ScanErr() { ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -1271,7 +1271,7 @@ func (ms *MergerSuite) TestRows_ScanErr() { ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { cols := []string{"id", "name", "address"} query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn").RowError(1, nextMockErr)) + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn").RowError(1, nextMockErr)) r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []rows.Rows{r} @@ -1288,6 +1288,56 @@ func (ms *MergerSuite) TestRows_ScanErr() { } +func (ms *MergerSuite) TestRows_ColumnTypes() { + t := ms.T() + + query := "SELECT `grade` FROM `t1` ORDER BY `grade` DESC" + cols := []string{"`grade`"} + ms.mock01.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(100)) + ms.mock02.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(90)) + ms.mock03.ExpectQuery(query).WillReturnRows(sqlmock.NewRows(cols).AddRow(110)) + + sortColumns := []SortColumn{ + { + name: "`grade`", + order: DESC, + }, + } + m, err := NewMerger(true, sortColumns...) + require.NoError(t, err) + + r, err := m.Merge(context.Background(), getResultSet(t, query, ms.mockDB01, ms.mockDB02, ms.mockDB03)) + require.NoError(t, err) + + t.Run("rows未关闭", func(t *testing.T) { + types, err := r.ColumnTypes() + require.NoError(t, err) + + names := make([]string, 0, len(types)) + for _, typ := range types { + names = append(names, typ.Name()) + } + require.Equal(t, []string{"`grade`"}, names) + }) + + t.Run("rows已关闭", func(t *testing.T) { + require.NoError(t, r.Close()) + + _, err = r.ColumnTypes() + require.ErrorIs(t, err, errs.ErrMergerRowsClosed) + }) +} + +func getResultSet(t *testing.T, sql string, dbs ...*sql.DB) []rows.Rows { + resultSet := make([]rows.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.Query(sql) + require.NoError(t, err) + resultSet = append(resultSet, row) + } + return resultSet +} + type TestModel struct { Id int Name string diff --git a/internal/merger/type.go b/internal/merger/type.go index 4b04914..377cb85 100644 --- a/internal/merger/type.go +++ b/internal/merger/type.go @@ -28,12 +28,21 @@ type Merger interface { Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) } +type Order bool + +const ( + // ASC 升序排序 + ASC Order = true + // DESC 降序排序 + DESC Order = false +) + type ColumnInfo struct { Index int Name string AggregateFunc string Alias string - ASC bool + Order Order } func (c ColumnInfo) SelectName() string { @@ -46,13 +55,6 @@ func (c ColumnInfo) SelectName() string { return c.Name } -func NewColumnInfo(index int, name string) ColumnInfo { - return ColumnInfo{ - Index: index, - Name: name, - } -} - func (c ColumnInfo) Validate() bool { // ColumnInfo.Name中不能包含括号,也就是聚合函数, name = `id`, 而不是name = count(`id`) // 聚合函数需要写在aggregateFunc字段中