diff --git a/core.go b/core.go new file mode 100644 index 0000000..a1ccf7d --- /dev/null +++ b/core.go @@ -0,0 +1,43 @@ +package sqlx + +import ( + "context" + "database/sql" +) + +var ( + _ Queryable = (*DB)(nil) + _ Queryable = (*Tx)(nil) +) + +// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing +// either type to be used interchangeably. +type Queryable interface { + Ext + ExecIn + QueryIn + ExecerContext + PreparerContext + QueryerContext + Preparer + + GetContext(context.Context, interface{}, string, ...interface{}) error + SelectContext(context.Context, interface{}, string, ...interface{}) error + Get(interface{}, string, ...interface{}) error + MustExecContext(context.Context, string, ...interface{}) sql.Result + PreparexContext(context.Context, string) (*Stmt, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Select(interface{}, string, ...interface{}) error + QueryRow(string, ...interface{}) *sql.Row + PrepareNamedContext(context.Context, string) (*NamedStmt, error) + PrepareNamed(string) (*NamedStmt, error) + Preparex(string) (*Stmt, error) + NamedExec(string, interface{}) (sql.Result, error) + NamedExecContext(context.Context, string, interface{}) (sql.Result, error) + MustExec(string, ...interface{}) sql.Result + NamedQuery(string, interface{}) (*Rows, error) + InGet(any, string, ...any) error + InSelect(any, string, ...any) error + InExec(query string, args ...any) (sql.Result, error) + MustInExec(string, ...any) sql.Result +} diff --git a/core_test.go b/core_test.go new file mode 100644 index 0000000..1fbaef5 --- /dev/null +++ b/core_test.go @@ -0,0 +1,78 @@ +package sqlx + +import ( + "database/sql" + "reflect" + "testing" +) + +func TestQueryable(t *testing.T) { + sqlDBType := reflect.TypeOf(&sql.DB{}) + dbType := reflect.TypeOf(&DB{}) + sqlTxType := reflect.TypeOf(&sql.Tx{}) + txType := reflect.TypeOf(&Tx{}) + + dbMethods := exportableMethods(sqlDBType) + for k, v := range exportableMethods(dbType) { + dbMethods[k] = v + } + + txMethods := exportableMethods(sqlTxType) + for k, v := range exportableMethods(txType) { + txMethods[k] = v + } + + sharedMethods := make([]string, 0) + + for name, dbMethod := range dbMethods { + if txMethod, ok := txMethods[name]; ok { + if methodsEqual(dbMethod.Type, txMethod.Type) { + sharedMethods = append(sharedMethods, name) + } + } + } + + queryableType := reflect.TypeOf((*Queryable)(nil)).Elem() + queryableMethods := exportableMethods(queryableType) + + for _, sharedMethodName := range sharedMethods { + if _, ok := queryableMethods[sharedMethodName]; !ok { + t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName) + } + } +} + +func exportableMethods(t reflect.Type) map[string]reflect.Method { + methods := make(map[string]reflect.Method) + + for i := 0; i < t.NumMethod(); i++ { + method := t.Method(i) + + if method.IsExported() { + methods[method.Name] = method + } + } + + return methods +} + +func methodsEqual(t reflect.Type, ot reflect.Type) bool { + if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() { + return false + } + + // Start at 1 to avoid comparing receiver argument + for i := 1; i < t.NumIn(); i++ { + if t.In(i) != ot.In(i) { + return false + } + } + + for i := 0; i < t.NumOut(); i++ { + if t.Out(i) != ot.Out(i) { + return false + } + } + + return true +}