Skip to content

Commit

Permalink
add Queryable interface
Browse files Browse the repository at this point in the history
  • Loading branch information
alimy committed Aug 23, 2023
1 parent 06d07f4 commit 609aefe
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
43 changes: 43 additions & 0 deletions core.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package sqlx

import (
"context"
"database/sql"
)

var (
_ Queryable = (*DB)(nil)
_ Queryable = (*Tx)(nil)
)

// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing
// either type to be used interchangeably.
type Queryable interface {
Ext
ExecIn
QueryIn
ExecerContext
PreparerContext
QueryerContext
Preparer

GetContext(context.Context, interface{}, string, ...interface{}) error
SelectContext(context.Context, interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
MustExecContext(context.Context, string, ...interface{}) sql.Result
PreparexContext(context.Context, string) (*Stmt, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
Select(interface{}, string, ...interface{}) error
QueryRow(string, ...interface{}) *sql.Row
PrepareNamedContext(context.Context, string) (*NamedStmt, error)
PrepareNamed(string) (*NamedStmt, error)
Preparex(string) (*Stmt, error)
NamedExec(string, interface{}) (sql.Result, error)
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
MustExec(string, ...interface{}) sql.Result
NamedQuery(string, interface{}) (*Rows, error)
InGet(any, string, ...any) error
InSelect(any, string, ...any) error
InExec(query string, args ...any) (sql.Result, error)
MustInExec(string, ...any) sql.Result
}
78 changes: 78 additions & 0 deletions core_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package sqlx

import (
"database/sql"
"reflect"
"testing"
)

func TestQueryable(t *testing.T) {
sqlDBType := reflect.TypeOf(&sql.DB{})
dbType := reflect.TypeOf(&DB{})
sqlTxType := reflect.TypeOf(&sql.Tx{})
txType := reflect.TypeOf(&Tx{})

dbMethods := exportableMethods(sqlDBType)
for k, v := range exportableMethods(dbType) {
dbMethods[k] = v
}

txMethods := exportableMethods(sqlTxType)
for k, v := range exportableMethods(txType) {
txMethods[k] = v
}

sharedMethods := make([]string, 0)

for name, dbMethod := range dbMethods {
if txMethod, ok := txMethods[name]; ok {
if methodsEqual(dbMethod.Type, txMethod.Type) {
sharedMethods = append(sharedMethods, name)
}
}
}

queryableType := reflect.TypeOf((*Queryable)(nil)).Elem()
queryableMethods := exportableMethods(queryableType)

for _, sharedMethodName := range sharedMethods {
if _, ok := queryableMethods[sharedMethodName]; !ok {
t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName)
}
}
}

func exportableMethods(t reflect.Type) map[string]reflect.Method {
methods := make(map[string]reflect.Method)

for i := 0; i < t.NumMethod(); i++ {
method := t.Method(i)

if method.IsExported() {
methods[method.Name] = method
}
}

return methods
}

func methodsEqual(t reflect.Type, ot reflect.Type) bool {
if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() {
return false
}

// Start at 1 to avoid comparing receiver argument
for i := 1; i < t.NumIn(); i++ {
if t.In(i) != ot.In(i) {
return false
}
}

for i := 0; i < t.NumOut(); i++ {
if t.Out(i) != ot.Out(i) {
return false
}
}

return true
}

0 comments on commit 609aefe

Please sign in to comment.