From 4b8ada47e74177fe631cc038c632faf24d14ec19 Mon Sep 17 00:00:00 2001 From: aereal Date: Thu, 12 Nov 2020 14:44:10 +0900 Subject: [PATCH 1/4] feat: implements driver.RowsColumnTypeDatabaseTypeName --- connection_test.go | 54 ++++++++++++++++++++++++++++++++++++++++------ rows.go | 17 +++++++++++++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/connection_test.go b/connection_test.go index ef8cdab..14f2f07 100644 --- a/connection_test.go +++ b/connection_test.go @@ -5,10 +5,10 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "fmt" "math" "net/http" "net/http/httptest" - "reflect" "strconv" "testing" "time" @@ -64,13 +64,21 @@ func TestConn_QueryContext_Scalar(t *testing.T) { t.Fatal(err) } defer rows.Close() - cols, err := rows.Columns() - if err != nil { - t.Fatal(err) + expectedColumns := columnTypeExpectations{ + {name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger}, + {name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint}, + {name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble}, + {name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean}, + {name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar}, + {name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond}, + {name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth}, + {name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown}, + {name: "time", databaseTypeName: timestreamquery.ScalarTypeTime}, } - expectedColumns := []string{"int", "big", "percent", "bool", "str", "dur1", "dur2", "nullish", "time"} - if !reflect.DeepEqual(cols, expectedColumns) { - t.Errorf("Rows.Columns(): expected=%#v got=%#v", expectedColumns, cols) + if cts, err := rows.ColumnTypes(); err == nil { + expectedColumns.compare(t, cts) + } else { + t.Error(err) } rowsScanned := false for rows.Next() { @@ -160,3 +168,35 @@ func Test_interpolatesQuery(t *testing.T) { }) } } + +type columnTypeExpectation struct { + name string + databaseTypeName string +} + +func (e columnTypeExpectation) compare(ct *sql.ColumnType) error { + if actual := ct.Name(); e.name != actual { + return fmt.Errorf("Name: actual=%q expected=%q", actual, e.name) + } + if actual := ct.DatabaseTypeName(); e.databaseTypeName != actual { + return fmt.Errorf("DatabaseTypeName: actual=%q expected=%q", actual, e.databaseTypeName) + } + return nil +} + +type columnTypeExpectations []columnTypeExpectation + +func (expectations columnTypeExpectations) compare(t *testing.T, columnTypes []*sql.ColumnType) bool { + if len(columnTypes) != len(expectations) { + t.Errorf("length mismatch: expected %d items; got %d items", len(expectations), len(columnTypes)) + return false + } + + for i, ce := range expectations { + actual := columnTypes[i] + if err := ce.compare(actual); err != nil { + t.Errorf("#%d: %s", i, err) + } + } + return true +} diff --git a/rows.go b/rows.go index 355914d..648810b 100644 --- a/rows.go +++ b/rows.go @@ -13,7 +13,9 @@ import ( "github.com/aws/aws-sdk-go/service/timestreamquery" ) -var tsTimeLayout = "2006-01-02 15:04:05.999999999" +var ( + tsTimeLayout = "2006-01-02 15:04:05.999999999" +) type resultSet struct { columns []*timestreamquery.ColumnInfo @@ -27,9 +29,20 @@ type rows struct { } var _ interface { - driver.Rows + driver.RowsColumnTypeDatabaseTypeName } = &rows{} +func (r *rows) ColumnTypeDatabaseTypeName(index int) string { + if len(r.rs.columns) <= index { + return "" + } + ci := r.rs.columns[index] + if ci.Type.ScalarType != nil { + return *ci.Type.ScalarType + } + return "UNKNOWN" +} + func (r *rows) Columns() []string { if r.columnNames != nil { return r.columnNames From ba1f56883a9fe75fbc651931ea0c6e982c7f835f Mon Sep 17 00:00:00 2001 From: aereal Date: Thu, 12 Nov 2020 16:48:47 +0900 Subject: [PATCH 2/4] refactor: add getColumn method for safely get column --- rows.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/rows.go b/rows.go index 648810b..9131cd8 100644 --- a/rows.go +++ b/rows.go @@ -32,11 +32,15 @@ var _ interface { driver.RowsColumnTypeDatabaseTypeName } = &rows{} -func (r *rows) ColumnTypeDatabaseTypeName(index int) string { +func (r *rows) getColumn(index int) *timestreamquery.ColumnInfo { if len(r.rs.columns) <= index { - return "" + return nil } - ci := r.rs.columns[index] + return r.rs.columns[index] +} + +func (r *rows) ColumnTypeDatabaseTypeName(index int) string { + ci := r.getColumn(index) if ci.Type.ScalarType != nil { return *ci.Type.ScalarType } @@ -63,7 +67,7 @@ func (r *rows) Next(dest []driver.Value) error { return io.EOF } for i, datum := range r.rows[r.pos].Data { - columnInfo := r.rs.columns[i] + columnInfo := r.getColumn(i) var err error dest[i], err = scanColumn(datum, columnInfo) if err != nil { From 903bfb4367f5ba432f2209837b021a83e7bb8d0e Mon Sep 17 00:00:00 2001 From: aereal Date: Thu, 12 Nov 2020 18:39:23 +0900 Subject: [PATCH 3/4] refactor: extract getTSDataType --- rows.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/rows.go b/rows.go index 9131cd8..f59236e 100644 --- a/rows.go +++ b/rows.go @@ -40,11 +40,7 @@ func (r *rows) getColumn(index int) *timestreamquery.ColumnInfo { } func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - ci := r.getColumn(index) - if ci.Type.ScalarType != nil { - return *ci.Type.ScalarType - } - return "UNKNOWN" + return getTSDataType(r.getColumn(index)) } func (r *rows) Columns() []string { @@ -163,3 +159,13 @@ func parseTime(datum *timestreamquery.Datum) (time.Time, error) { } return parsed, nil } + +func getTSDataType(ci *timestreamquery.ColumnInfo) string { + if ci == nil { + return typeNameUnknown + } + if ci.Type.ScalarType != nil { + return *ci.Type.ScalarType + } + return typeNameUnknown +} From 0f5533dc80c2a150a92cf3f8225474367d358423 Mon Sep 17 00:00:00 2001 From: aereal Date: Thu, 12 Nov 2020 18:40:30 +0900 Subject: [PATCH 4/4] feat: implements RowsColumnTypeScanType --- connection_test.go | 23 ++++++++++++++--------- rows.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/connection_test.go b/connection_test.go index 14f2f07..1fdf55f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -9,6 +9,7 @@ import ( "math" "net/http" "net/http/httptest" + "reflect" "strconv" "testing" "time" @@ -65,15 +66,15 @@ func TestConn_QueryContext_Scalar(t *testing.T) { } defer rows.Close() expectedColumns := columnTypeExpectations{ - {name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger}, - {name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint}, - {name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble}, - {name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean}, - {name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar}, - {name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond}, - {name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth}, - {name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown}, - {name: "time", databaseTypeName: timestreamquery.ScalarTypeTime}, + {name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))}, + {name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))}, + {name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))}, + {name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)}, + {name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")}, + {name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")}, + {name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")}, + {name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)}, + {name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})}, } if cts, err := rows.ColumnTypes(); err == nil { expectedColumns.compare(t, cts) @@ -172,6 +173,7 @@ func Test_interpolatesQuery(t *testing.T) { type columnTypeExpectation struct { name string databaseTypeName string + scanType reflect.Type } func (e columnTypeExpectation) compare(ct *sql.ColumnType) error { @@ -181,6 +183,9 @@ func (e columnTypeExpectation) compare(ct *sql.ColumnType) error { if actual := ct.DatabaseTypeName(); e.databaseTypeName != actual { return fmt.Errorf("DatabaseTypeName: actual=%q expected=%q", actual, e.databaseTypeName) } + if actual := ct.ScanType(); actual != e.scanType { + return fmt.Errorf("ScanType: actual=%s expected=%s", actual, e.scanType) + } return nil } diff --git a/rows.go b/rows.go index f59236e..d62f884 100644 --- a/rows.go +++ b/rows.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "math/big" + "reflect" "strconv" "time" @@ -14,7 +15,16 @@ import ( ) var ( - tsTimeLayout = "2006-01-02 15:04:05.999999999" + tsTimeLayout = "2006-01-02 15:04:05.999999999" + typeNameUnknown = timestreamquery.ScalarTypeUnknown + anyType = reflect.TypeOf(new(interface{})).Elem() + intType = reflect.TypeOf(int(0)) + bigintType = reflect.TypeOf(int64(0)) + doubleType = reflect.TypeOf(float64(0)) + boolType = reflect.TypeOf(true) + stringType = reflect.TypeOf("") + nullType = reflect.TypeOf(nil) + timeType = reflect.TypeOf(time.Time{}) ) type resultSet struct { @@ -30,6 +40,7 @@ type rows struct { var _ interface { driver.RowsColumnTypeDatabaseTypeName + driver.RowsColumnTypeScanType } = &rows{} func (r *rows) getColumn(index int) *timestreamquery.ColumnInfo { @@ -39,6 +50,39 @@ func (r *rows) getColumn(index int) *timestreamquery.ColumnInfo { return r.rs.columns[index] } +func (r *rows) ColumnTypeScanType(index int) reflect.Type { + ci := r.getColumn(index) + if ci == nil { + return anyType + } + switch dt := getTSDataType(ci); dt { + case timestreamquery.ScalarTypeBigint: + return bigintType + case timestreamquery.ScalarTypeBoolean: + return boolType + case timestreamquery.ScalarTypeDate: + return timeType + case timestreamquery.ScalarTypeDouble: + return doubleType + case timestreamquery.ScalarTypeInteger: + return intType + case timestreamquery.ScalarTypeIntervalDayToSecond: + return stringType + case timestreamquery.ScalarTypeIntervalYearToMonth: + return stringType + case timestreamquery.ScalarTypeTime: + return timeType + case timestreamquery.ScalarTypeTimestamp: + return timeType + case timestreamquery.ScalarTypeVarchar: + return stringType + case timestreamquery.ScalarTypeUnknown: + return nullType + default: + return anyType + } +} + func (r *rows) ColumnTypeDatabaseTypeName(index int) string { return getTSDataType(r.getColumn(index)) }