From f6855984c4ea46f5cbea66ef35d094dafffc3bc8 Mon Sep 17 00:00:00 2001 From: aereal Date: Tue, 17 Nov 2020 19:46:47 +0900 Subject: [PATCH 1/2] test: extract scalarOutput --- cases_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++++ connection_test.go | 115 +-------------------------------------- 2 files changed, 132 insertions(+), 113 deletions(-) create mode 100644 cases_test.go diff --git a/cases_test.go b/cases_test.go new file mode 100644 index 0000000..6ac77d9 --- /dev/null +++ b/cases_test.go @@ -0,0 +1,130 @@ +package timestreamdriver + +import ( + "database/sql" + "math" + "reflect" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/timestreamquery" +) + +func testRowsQueryScalar(t *testing.T, rows *sql.Rows) { + defer rows.Close() + expectedColumns := columnTypeExpectations{ + {name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))}, + {name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))}, + {name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))}, + {name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)}, + {name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")}, + {name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")}, + {name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")}, + {name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)}, + {name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})}, + {name: "dt", databaseTypeName: timestreamquery.ScalarTypeDate, scanType: reflect.TypeOf(time.Time{})}, + {name: "ts", databaseTypeName: timestreamquery.ScalarTypeTimestamp, scanType: reflect.TypeOf(time.Time{})}, + {name: "nullableInt", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))}, + } + if cts, err := rows.ColumnTypes(); err == nil { + expectedColumns.compare(t, cts) + } else { + t.Error(err) + } + rowsScanned := false + for rows.Next() { + rowsScanned = true + var ( + c1 int + c2 uint64 + c3 float64 + c4 bool + c5 string + c6 string + c7 string + c8 interface{} + c9 time.Time + c10 time.Time + c11 time.Time + c12 *int + ) + if err := rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8, &c9, &c10, &c11, &c12); err != nil { + t.Fatal(err) + } + if c1 != 1 { + t.Errorf("c1: expected=%v got=%v", 1, c1) + } + if c2 != math.MaxUint64 { + t.Errorf("c2: expected=%v got=%v", uint64(math.MaxUint64), c2) + } + if c3 != 0.5 { + t.Errorf("c3: expected=%v got=%v", 0.5, c3) + } + if c4 != true { + t.Errorf("c4: expected=%v got=%v", true, c4) + } + if c5 != "hi" { + t.Errorf("c5: expected=%v got=%v", "hi", c5) + } + if c6 != "0 01:00:00.000000000" { + t.Errorf("c6: expected=%v got=%v", "0 01:00:00.000000000", c6) + } + if c7 != "90 01:00:00.000000000" { + t.Errorf("c7: expected=%v got=%v", "90 01:00:00.000000000", c7) + } + expectedTime := time.Unix(1262349296, 0).UTC() + if !expectedTime.Equal(c9) { + t.Errorf("c9: expected=%s got=%s", expectedTime, c9) + } + expectedDate := time.Unix(1262304000, 0).UTC() + if !expectedDate.Equal(c10) { + t.Errorf("c10: expected=%s got=%s", expectedDate, c10) + } + if !expectedTime.Equal(c11) { + t.Errorf("c11: expected=%s got=%s", expectedDate, c11) + } + if c12 != nil { + t.Errorf("c12: expected=nil got=%#v", c12) + } + } + if !rowsScanned { + t.Error("No rows scanned") + } +} + +func scalarOutput() *timestreamquery.QueryOutput { + return ×treamquery.QueryOutput{ + ColumnInfo: []*timestreamquery.ColumnInfo{ + scalarColumn("int", timestreamquery.ScalarTypeInteger), + scalarColumn("big", timestreamquery.ScalarTypeBigint), + scalarColumn("percent", timestreamquery.ScalarTypeDouble), + scalarColumn("bool", timestreamquery.ScalarTypeBoolean), + scalarColumn("str", timestreamquery.ScalarTypeVarchar), + scalarColumn("dur1", timestreamquery.ScalarTypeIntervalDayToSecond), + scalarColumn("dur2", timestreamquery.ScalarTypeIntervalYearToMonth), + scalarColumn("nullish", timestreamquery.ScalarTypeUnknown), + scalarColumn("time", timestreamquery.ScalarTypeTime), + scalarColumn("dt", timestreamquery.ScalarTypeDate), + scalarColumn("ts", timestreamquery.ScalarTypeTimestamp), + scalarColumn("nullableInt", timestreamquery.ScalarTypeInteger), + }, + Rows: []*timestreamquery.Row{{ + Data: []*timestreamquery.Datum{ + {ScalarValue: aws.String("1")}, + {ScalarValue: aws.String(strconv.FormatUint(math.MaxUint64, 10))}, + {ScalarValue: aws.String("0.5")}, + {ScalarValue: aws.String("true")}, + {ScalarValue: aws.String("hi")}, + {ScalarValue: aws.String("0 01:00:00.000000000")}, + {ScalarValue: aws.String("90 01:00:00.000000000")}, + {}, + {ScalarValue: aws.String("2010-01-01 12:34:56.000000000")}, + {ScalarValue: aws.String("2010-01-01")}, + {ScalarValue: aws.String("2010-01-01 12:34:56.000000000")}, + {NullValue: aws.Bool(true)}, + }, + }}, + } +} diff --git a/connection_test.go b/connection_test.go index 2976286..135bf47 100644 --- a/connection_test.go +++ b/connection_test.go @@ -6,13 +6,11 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "math" "net/http" "net/http/httptest" "net/url" "os" "reflect" - "strconv" "sync" "testing" "time" @@ -28,38 +26,7 @@ import ( func TestConn_QueryContext_Scalar(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(×treamquery.QueryOutput{ - ColumnInfo: []*timestreamquery.ColumnInfo{ - scalarColumn("int", timestreamquery.ScalarTypeInteger), - scalarColumn("big", timestreamquery.ScalarTypeBigint), - scalarColumn("percent", timestreamquery.ScalarTypeDouble), - scalarColumn("bool", timestreamquery.ScalarTypeBoolean), - scalarColumn("str", timestreamquery.ScalarTypeVarchar), - scalarColumn("dur1", timestreamquery.ScalarTypeIntervalDayToSecond), - scalarColumn("dur2", timestreamquery.ScalarTypeIntervalYearToMonth), - scalarColumn("nullish", timestreamquery.ScalarTypeUnknown), - scalarColumn("time", timestreamquery.ScalarTypeTime), - scalarColumn("dt", timestreamquery.ScalarTypeDate), - scalarColumn("ts", timestreamquery.ScalarTypeTimestamp), - scalarColumn("nullableInt", timestreamquery.ScalarTypeInteger), - }, - Rows: []*timestreamquery.Row{{ - Data: []*timestreamquery.Datum{ - {ScalarValue: aws.String("1")}, - {ScalarValue: aws.String(strconv.FormatUint(math.MaxUint64, 10))}, - {ScalarValue: aws.String("0.5")}, - {ScalarValue: aws.String("true")}, - {ScalarValue: aws.String("hi")}, - {ScalarValue: aws.String("0 01:00:00.000000000")}, - {ScalarValue: aws.String("90 01:00:00.000000000")}, - {}, - {ScalarValue: aws.String("2010-01-01 12:34:56.000000000")}, - {ScalarValue: aws.String("2010-01-01")}, - {ScalarValue: aws.String("2010-01-01 12:34:56.000000000")}, - {NullValue: aws.Bool(true)}, - }, - }}, - }) + _ = json.NewEncoder(w).Encode(scalarOutput()) })) defer srv.Close() tsq := timestreamquery.New(session.Must(session.NewSessionWithOptions(session.Options{ @@ -76,85 +43,7 @@ func TestConn_QueryContext_Scalar(t *testing.T) { if err != nil { t.Fatal(err) } - defer rows.Close() - expectedColumns := columnTypeExpectations{ - {name: "int", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))}, - {name: "big", databaseTypeName: timestreamquery.ScalarTypeBigint, scanType: reflect.TypeOf(int64(0))}, - {name: "percent", databaseTypeName: timestreamquery.ScalarTypeDouble, scanType: reflect.TypeOf(float64(0))}, - {name: "bool", databaseTypeName: timestreamquery.ScalarTypeBoolean, scanType: reflect.TypeOf(true)}, - {name: "str", databaseTypeName: timestreamquery.ScalarTypeVarchar, scanType: reflect.TypeOf("")}, - {name: "dur1", databaseTypeName: timestreamquery.ScalarTypeIntervalDayToSecond, scanType: reflect.TypeOf("")}, - {name: "dur2", databaseTypeName: timestreamquery.ScalarTypeIntervalYearToMonth, scanType: reflect.TypeOf("")}, - {name: "nullish", databaseTypeName: timestreamquery.ScalarTypeUnknown, scanType: reflect.TypeOf(nil)}, - {name: "time", databaseTypeName: timestreamquery.ScalarTypeTime, scanType: reflect.TypeOf(time.Time{})}, - {name: "dt", databaseTypeName: timestreamquery.ScalarTypeDate, scanType: reflect.TypeOf(time.Time{})}, - {name: "ts", databaseTypeName: timestreamquery.ScalarTypeTimestamp, scanType: reflect.TypeOf(time.Time{})}, - {name: "nullableInt", databaseTypeName: timestreamquery.ScalarTypeInteger, scanType: reflect.TypeOf(int(0))}, - } - if cts, err := rows.ColumnTypes(); err == nil { - expectedColumns.compare(t, cts) - } else { - t.Error(err) - } - rowsScanned := false - for rows.Next() { - rowsScanned = true - var ( - c1 int - c2 uint64 - c3 float64 - c4 bool - c5 string - c6 string - c7 string - c8 interface{} - c9 time.Time - c10 time.Time - c11 time.Time - c12 *int - ) - if err := rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8, &c9, &c10, &c11, &c12); err != nil { - t.Fatal(err) - } - if c1 != 1 { - t.Errorf("c1: expected=%v got=%v", 1, c1) - } - if c2 != math.MaxUint64 { - t.Errorf("c2: expected=%v got=%v", uint64(math.MaxUint64), c2) - } - if c3 != 0.5 { - t.Errorf("c3: expected=%v got=%v", 0.5, c3) - } - if c4 != true { - t.Errorf("c4: expected=%v got=%v", true, c4) - } - if c5 != "hi" { - t.Errorf("c5: expected=%v got=%v", "hi", c5) - } - if c6 != "0 01:00:00.000000000" { - t.Errorf("c6: expected=%v got=%v", "0 01:00:00.000000000", c6) - } - if c7 != "90 01:00:00.000000000" { - t.Errorf("c7: expected=%v got=%v", "90 01:00:00.000000000", c7) - } - expectedTime := time.Unix(1262349296, 0).UTC() - if !expectedTime.Equal(c9) { - t.Errorf("c9: expected=%s got=%s", expectedTime, c9) - } - expectedDate := time.Unix(1262304000, 0).UTC() - if !expectedDate.Equal(c10) { - t.Errorf("c10: expected=%s got=%s", expectedDate, c10) - } - if !expectedTime.Equal(c11) { - t.Errorf("c11: expected=%s got=%s", expectedDate, c11) - } - if c12 != nil { - t.Errorf("c12: expected=nil got=%#v", c12) - } - } - if !rowsScanned { - t.Error("No rows scanned") - } + testRowsQueryScalar(t, rows) } type testLogger struct { From 5e0e89d39a4628fc7d9a44f8dfe60c2529d72e1e Mon Sep 17 00:00:00 2001 From: aereal Date: Tue, 17 Nov 2020 19:48:55 +0900 Subject: [PATCH 2/2] feat: support Prepare() --- connection.go | 4 ++-- statement.go | 40 ++++++++++++++++++++++++++++++++++++++++ statement_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 statement.go create mode 100644 statement_test.go diff --git a/connection.go b/connection.go index 799719c..06b3d37 100644 --- a/connection.go +++ b/connection.go @@ -34,8 +34,8 @@ func (conn) Begin() (driver.Tx, error) { return nil, ErrBeginNotSupported } -func (conn) Prepare(query string) (driver.Stmt, error) { - return nil, ErrPrepareNotSupported +func (c *conn) Prepare(query string) (driver.Stmt, error) { + return &stmt{query: query, cn: c}, nil } func (conn) Close() error { diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..297bd49 --- /dev/null +++ b/statement.go @@ -0,0 +1,40 @@ +package timestreamdriver + +import ( + "context" + "database/sql/driver" +) + +type stmt struct { + query string + cn *conn +} + +var _ interface { + driver.Stmt + driver.StmtQueryContext +} = &stmt{} + +func (s *stmt) Close() error { + return nil +} + +func (s *stmt) NumInput() int { + return -1 +} + +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + return nil, driver.ErrSkip +} + +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + vs := make([]driver.NamedValue, len(args)) + for i, a := range args { + vs[i] = driver.NamedValue{Ordinal: i + 1, Value: a} + } + return s.QueryContext(context.Background(), vs) +} + +func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return s.cn.QueryContext(ctx, s.query, args) +} diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 0000000..4ac79a3 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,46 @@ +package timestreamdriver + +import ( + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/timestreamquery" +) + +func TestStatement_Prepare_QueryContext_Scalar(t *testing.T) { + db, close := prepareTestDB() + defer close() + ctx := context.Background() + st, err := db.PrepareContext(ctx, `SELECT 1 FROM table1 WHERE name = ?`) + if err != nil { + t.Fatal(err) + } + defer st.Close() + rows, err := st.QueryContext(ctx, "me") + if err != nil { + t.Fatal(err) + } + testRowsQueryScalar(t, rows) +} + +func prepareTestDB() (*sql.DB, func()) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(scalarOutput()) + })) + tsq := timestreamquery.New(session.Must(session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(srv.URL), + Credentials: credentials.NewStaticCredentials("id", "secret", "token"), + }, + }))) + + return sql.OpenDB(&connector{tsq}), func() { srv.Close() } +}