Skip to content

Commit

Permalink
Merge pull request #11 from aereal/db-type-names
Browse files Browse the repository at this point in the history
feat: enrich Rows implementations
  • Loading branch information
aereal authored Nov 12, 2020
2 parents a202450 + 0f5533d commit 5938aee
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 9 deletions.
57 changes: 51 additions & 6 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -64,13 +65,21 @@ func TestConn_QueryContext_Scalar(t *testing.T) {
t.Fatal(err)
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
t.Fatal(err)
expectedColumns := columnTypeExpectations{
{name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))},
{name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))},
{name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))},
{name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)},
{name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")},
{name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")},
{name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")},
{name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)},
{name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})},
}
expectedColumns := []string{"int", "big", "percent", "bool", "str", "dur1", "dur2", "nullish", "time"}
if !reflect.DeepEqual(cols, expectedColumns) {
t.Errorf("Rows.Columns(): expected=%#v got=%#v", expectedColumns, cols)
if cts, err := rows.ColumnTypes(); err == nil {
expectedColumns.compare(t, cts)
} else {
t.Error(err)
}
rowsScanned := false
for rows.Next() {
Expand Down Expand Up @@ -160,3 +169,39 @@ func Test_interpolatesQuery(t *testing.T) {
})
}
}

type columnTypeExpectation struct {
name string
databaseTypeName string
scanType reflect.Type
}

func (e columnTypeExpectation) compare(ct *sql.ColumnType) error {
if actual := ct.Name(); e.name != actual {
return fmt.Errorf("Name: actual=%q expected=%q", actual, e.name)
}
if actual := ct.DatabaseTypeName(); e.databaseTypeName != actual {
return fmt.Errorf("DatabaseTypeName: actual=%q expected=%q", actual, e.databaseTypeName)
}
if actual := ct.ScanType(); actual != e.scanType {
return fmt.Errorf("ScanType: actual=%s expected=%s", actual, e.scanType)
}
return nil
}

type columnTypeExpectations []columnTypeExpectation

func (expectations columnTypeExpectations) compare(t *testing.T, columnTypes []*sql.ColumnType) bool {
if len(columnTypes) != len(expectations) {
t.Errorf("length mismatch: expected %d items; got %d items", len(expectations), len(columnTypes))
return false
}

for i, ce := range expectations {
actual := columnTypes[i]
if err := ce.compare(actual); err != nil {
t.Errorf("#%d: %s", i, err)
}
}
return true
}
73 changes: 70 additions & 3 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@ import (
"fmt"
"io"
"math/big"
"reflect"
"strconv"
"time"

"github.com/aws/aws-sdk-go/service/timestreamquery"
)

var tsTimeLayout = "2006-01-02 15:04:05.999999999"
var (
tsTimeLayout = "2006-01-02 15:04:05.999999999"
typeNameUnknown = timestreamquery.ScalarTypeUnknown
anyType = reflect.TypeOf(new(interface{})).Elem()
intType = reflect.TypeOf(int(0))
bigintType = reflect.TypeOf(int64(0))
doubleType = reflect.TypeOf(float64(0))
boolType = reflect.TypeOf(true)
stringType = reflect.TypeOf("")
nullType = reflect.TypeOf(nil)
timeType = reflect.TypeOf(time.Time{})
)

type resultSet struct {
columns []*timestreamquery.ColumnInfo
Expand All @@ -27,9 +39,54 @@ type rows struct {
}

var _ interface {
driver.Rows
driver.RowsColumnTypeDatabaseTypeName
driver.RowsColumnTypeScanType
} = &rows{}

func (r *rows) getColumn(index int) *timestreamquery.ColumnInfo {
if len(r.rs.columns) <= index {
return nil
}
return r.rs.columns[index]
}

func (r *rows) ColumnTypeScanType(index int) reflect.Type {
ci := r.getColumn(index)
if ci == nil {
return anyType
}
switch dt := getTSDataType(ci); dt {
case timestreamquery.ScalarTypeBigint:
return bigintType
case timestreamquery.ScalarTypeBoolean:
return boolType
case timestreamquery.ScalarTypeDate:
return timeType
case timestreamquery.ScalarTypeDouble:
return doubleType
case timestreamquery.ScalarTypeInteger:
return intType
case timestreamquery.ScalarTypeIntervalDayToSecond:
return stringType
case timestreamquery.ScalarTypeIntervalYearToMonth:
return stringType
case timestreamquery.ScalarTypeTime:
return timeType
case timestreamquery.ScalarTypeTimestamp:
return timeType
case timestreamquery.ScalarTypeVarchar:
return stringType
case timestreamquery.ScalarTypeUnknown:
return nullType
default:
return anyType
}
}

func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
return getTSDataType(r.getColumn(index))
}

func (r *rows) Columns() []string {
if r.columnNames != nil {
return r.columnNames
Expand All @@ -50,7 +107,7 @@ func (r *rows) Next(dest []driver.Value) error {
return io.EOF
}
for i, datum := range r.rows[r.pos].Data {
columnInfo := r.rs.columns[i]
columnInfo := r.getColumn(i)
var err error
dest[i], err = scanColumn(datum, columnInfo)
if err != nil {
Expand Down Expand Up @@ -146,3 +203,13 @@ func parseTime(datum *timestreamquery.Datum) (time.Time, error) {
}
return parsed, nil
}

func getTSDataType(ci *timestreamquery.ColumnInfo) string {
if ci == nil {
return typeNameUnknown
}
if ci.Type.ScalarType != nil {
return *ci.Type.ScalarType
}
return typeNameUnknown
}

0 comments on commit 5938aee

Please sign in to comment.