Skip to content

Commit

Permalink
Merge pull request #23 from aereal/impl-prepare
Browse files Browse the repository at this point in the history
feat: support Prepare()
  • Loading branch information
aereal authored Nov 17, 2020
2 parents 0af91cb + 5e0e89d commit 487028a
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 115 deletions.
130 changes: 130 additions & 0 deletions cases_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package timestreamdriver

import (
"database/sql"
"math"
"reflect"
"strconv"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/timestreamquery"
)

func testRowsQueryScalar(t *testing.T, rows *sql.Rows) {
defer rows.Close()
expectedColumns := columnTypeExpectations{
{name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))},
{name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))},
{name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))},
{name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)},
{name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")},
{name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")},
{name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")},
{name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)},
{name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})},
{name: "dt", databaseTypeName: timestreamquery.ScalarTypeDate, scanType: reflect.TypeOf(time.Time{})},
{name: "ts", databaseTypeName: timestreamquery.ScalarTypeTimestamp, scanType: reflect.TypeOf(time.Time{})},
{name: "nullableInt", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))},
}
if cts, err := rows.ColumnTypes(); err == nil {
expectedColumns.compare(t, cts)
} else {
t.Error(err)
}
rowsScanned := false
for rows.Next() {
rowsScanned = true
var (
c1 int
c2 uint64
c3 float64
c4 bool
c5 string
c6 string
c7 string
c8 interface{}
c9 time.Time
c10 time.Time
c11 time.Time
c12 *int
)
if err := rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8, &c9, &c10, &c11, &c12); err != nil {
t.Fatal(err)
}
if c1 != 1 {
t.Errorf("c1: expected=%v got=%v", 1, c1)
}
if c2 != math.MaxUint64 {
t.Errorf("c2: expected=%v got=%v", uint64(math.MaxUint64), c2)
}
if c3 != 0.5 {
t.Errorf("c3: expected=%v got=%v", 0.5, c3)
}
if c4 != true {
t.Errorf("c4: expected=%v got=%v", true, c4)
}
if c5 != "hi" {
t.Errorf("c5: expected=%v got=%v", "hi", c5)
}
if c6 != "0 01:00:00.000000000" {
t.Errorf("c6: expected=%v got=%v", "0 01:00:00.000000000", c6)
}
if c7 != "90 01:00:00.000000000" {
t.Errorf("c7: expected=%v got=%v", "90 01:00:00.000000000", c7)
}
expectedTime := time.Unix(1262349296, 0).UTC()
if !expectedTime.Equal(c9) {
t.Errorf("c9: expected=%s got=%s", expectedTime, c9)
}
expectedDate := time.Unix(1262304000, 0).UTC()
if !expectedDate.Equal(c10) {
t.Errorf("c10: expected=%s got=%s", expectedDate, c10)
}
if !expectedTime.Equal(c11) {
t.Errorf("c11: expected=%s got=%s", expectedDate, c11)
}
if c12 != nil {
t.Errorf("c12: expected=nil got=%#v", c12)
}
}
if !rowsScanned {
t.Error("No rows scanned")
}
}

func scalarOutput() *timestreamquery.QueryOutput {
return &timestreamquery.QueryOutput{
ColumnInfo: []*timestreamquery.ColumnInfo{
scalarColumn("int", timestreamquery.ScalarTypeInteger),
scalarColumn("big", timestreamquery.ScalarTypeBigint),
scalarColumn("percent", timestreamquery.ScalarTypeDouble),
scalarColumn("bool", timestreamquery.ScalarTypeBoolean),
scalarColumn("str", timestreamquery.ScalarTypeVarchar),
scalarColumn("dur1", timestreamquery.ScalarTypeIntervalDayToSecond),
scalarColumn("dur2", timestreamquery.ScalarTypeIntervalYearToMonth),
scalarColumn("nullish", timestreamquery.ScalarTypeUnknown),
scalarColumn("time", timestreamquery.ScalarTypeTime),
scalarColumn("dt", timestreamquery.ScalarTypeDate),
scalarColumn("ts", timestreamquery.ScalarTypeTimestamp),
scalarColumn("nullableInt", timestreamquery.ScalarTypeInteger),
},
Rows: []*timestreamquery.Row{{
Data: []*timestreamquery.Datum{
{ScalarValue: aws.String("1")},
{ScalarValue: aws.String(strconv.FormatUint(math.MaxUint64, 10))},
{ScalarValue: aws.String("0.5")},
{ScalarValue: aws.String("true")},
{ScalarValue: aws.String("hi")},
{ScalarValue: aws.String("0 01:00:00.000000000")},
{ScalarValue: aws.String("90 01:00:00.000000000")},
{},
{ScalarValue: aws.String("2010-01-01 12:34:56.000000000")},
{ScalarValue: aws.String("2010-01-01")},
{ScalarValue: aws.String("2010-01-01 12:34:56.000000000")},
{NullValue: aws.Bool(true)},
},
}},
}
}
4 changes: 2 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func (conn) Begin() (driver.Tx, error) {
return nil, ErrBeginNotSupported
}

func (conn) Prepare(query string) (driver.Stmt, error) {
return nil, ErrPrepareNotSupported
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return &stmt{query: query, cn: c}, nil
}

func (conn) Close() error {
Expand Down
115 changes: 2 additions & 113 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ import (
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strconv"
"sync"
"testing"
"time"
Expand All @@ -28,38 +26,7 @@ import (

func TestConn_QueryContext_Scalar(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(&timestreamquery.QueryOutput{
ColumnInfo: []*timestreamquery.ColumnInfo{
scalarColumn("int", timestreamquery.ScalarTypeInteger),
scalarColumn("big", timestreamquery.ScalarTypeBigint),
scalarColumn("percent", timestreamquery.ScalarTypeDouble),
scalarColumn("bool", timestreamquery.ScalarTypeBoolean),
scalarColumn("str", timestreamquery.ScalarTypeVarchar),
scalarColumn("dur1", timestreamquery.ScalarTypeIntervalDayToSecond),
scalarColumn("dur2", timestreamquery.ScalarTypeIntervalYearToMonth),
scalarColumn("nullish", timestreamquery.ScalarTypeUnknown),
scalarColumn("time", timestreamquery.ScalarTypeTime),
scalarColumn("dt", timestreamquery.ScalarTypeDate),
scalarColumn("ts", timestreamquery.ScalarTypeTimestamp),
scalarColumn("nullableInt", timestreamquery.ScalarTypeInteger),
},
Rows: []*timestreamquery.Row{{
Data: []*timestreamquery.Datum{
{ScalarValue: aws.String("1")},
{ScalarValue: aws.String(strconv.FormatUint(math.MaxUint64, 10))},
{ScalarValue: aws.String("0.5")},
{ScalarValue: aws.String("true")},
{ScalarValue: aws.String("hi")},
{ScalarValue: aws.String("0 01:00:00.000000000")},
{ScalarValue: aws.String("90 01:00:00.000000000")},
{},
{ScalarValue: aws.String("2010-01-01 12:34:56.000000000")},
{ScalarValue: aws.String("2010-01-01")},
{ScalarValue: aws.String("2010-01-01 12:34:56.000000000")},
{NullValue: aws.Bool(true)},
},
}},
})
_ = json.NewEncoder(w).Encode(scalarOutput())
}))
defer srv.Close()
tsq := timestreamquery.New(session.Must(session.NewSessionWithOptions(session.Options{
Expand All @@ -76,85 +43,7 @@ func TestConn_QueryContext_Scalar(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer rows.Close()
expectedColumns := columnTypeExpectations{
{name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))},
{name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))},
{name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))},
{name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)},
{name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")},
{name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")},
{name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")},
{name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)},
{name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})},
{name: "dt", databaseTypeName: timestreamquery.ScalarTypeDate, scanType: reflect.TypeOf(time.Time{})},
{name: "ts", databaseTypeName: timestreamquery.ScalarTypeTimestamp, scanType: reflect.TypeOf(time.Time{})},
{name: "nullableInt", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))},
}
if cts, err := rows.ColumnTypes(); err == nil {
expectedColumns.compare(t, cts)
} else {
t.Error(err)
}
rowsScanned := false
for rows.Next() {
rowsScanned = true
var (
c1 int
c2 uint64
c3 float64
c4 bool
c5 string
c6 string
c7 string
c8 interface{}
c9 time.Time
c10 time.Time
c11 time.Time
c12 *int
)
if err := rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8, &c9, &c10, &c11, &c12); err != nil {
t.Fatal(err)
}
if c1 != 1 {
t.Errorf("c1: expected=%v got=%v", 1, c1)
}
if c2 != math.MaxUint64 {
t.Errorf("c2: expected=%v got=%v", uint64(math.MaxUint64), c2)
}
if c3 != 0.5 {
t.Errorf("c3: expected=%v got=%v", 0.5, c3)
}
if c4 != true {
t.Errorf("c4: expected=%v got=%v", true, c4)
}
if c5 != "hi" {
t.Errorf("c5: expected=%v got=%v", "hi", c5)
}
if c6 != "0 01:00:00.000000000" {
t.Errorf("c6: expected=%v got=%v", "0 01:00:00.000000000", c6)
}
if c7 != "90 01:00:00.000000000" {
t.Errorf("c7: expected=%v got=%v", "90 01:00:00.000000000", c7)
}
expectedTime := time.Unix(1262349296, 0).UTC()
if !expectedTime.Equal(c9) {
t.Errorf("c9: expected=%s got=%s", expectedTime, c9)
}
expectedDate := time.Unix(1262304000, 0).UTC()
if !expectedDate.Equal(c10) {
t.Errorf("c10: expected=%s got=%s", expectedDate, c10)
}
if !expectedTime.Equal(c11) {
t.Errorf("c11: expected=%s got=%s", expectedDate, c11)
}
if c12 != nil {
t.Errorf("c12: expected=nil got=%#v", c12)
}
}
if !rowsScanned {
t.Error("No rows scanned")
}
testRowsQueryScalar(t, rows)
}

type testLogger struct {
Expand Down
40 changes: 40 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package timestreamdriver

import (
"context"
"database/sql/driver"
)

type stmt struct {
query string
cn *conn
}

var _ interface {
driver.Stmt
driver.StmtQueryContext
} = &stmt{}

func (s *stmt) Close() error {
return nil
}

func (s *stmt) NumInput() int {
return -1
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
vs := make([]driver.NamedValue, len(args))
for i, a := range args {
vs[i] = driver.NamedValue{Ordinal: i + 1, Value: a}
}
return s.QueryContext(context.Background(), vs)
}

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return s.cn.QueryContext(ctx, s.query, args)
}
46 changes: 46 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package timestreamdriver

import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/timestreamquery"
)

func TestStatement_Prepare_QueryContext_Scalar(t *testing.T) {
db, close := prepareTestDB()
defer close()
ctx := context.Background()
st, err := db.PrepareContext(ctx, `SELECT 1 FROM table1 WHERE name = ?`)
if err != nil {
t.Fatal(err)
}
defer st.Close()
rows, err := st.QueryContext(ctx, "me")
if err != nil {
t.Fatal(err)
}
testRowsQueryScalar(t, rows)
}

func prepareTestDB() (*sql.DB, func()) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(scalarOutput())
}))
tsq := timestreamquery.New(session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{
Region: aws.String("us-east-1"),
Endpoint: aws.String(srv.URL),
Credentials: credentials.NewStaticCredentials("id", "secret", "token"),
},
})))

return sql.OpenDB(&connector{tsq}), func() { srv.Close() }
}

0 comments on commit 487028a

Please sign in to comment.