Skip to content

Commit ca205c2

Browse files
committed
Fix bug generic adapter
1 parent b849ae8 commit ca205c2

10 files changed

+105
-50
lines changed

adapter/adapter.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func NewSqlAdapterWithVersionAndArray[T any](db *sql.DB, tableName string, versi
8080
return adapter, nil
8181
}
8282

83-
func (a *Adapter[T]) Create(ctx context.Context, model interface{}) (int64, error) {
83+
func (a *Adapter[T]) Create(ctx context.Context, model T) (int64, error) {
8484
tx := q.GetExec(ctx, a.DB, a.TxKey)
8585
query, args := q.BuildToInsertWithVersion(a.Table, model, a.versionIndex, a.BuildParam, a.BoolSupport, a.ToArray, a.Schema)
8686
res, err := tx.ExecContext(ctx, query, args...)
@@ -89,7 +89,7 @@ func (a *Adapter[T]) Create(ctx context.Context, model interface{}) (int64, erro
8989
}
9090
return res.RowsAffected()
9191
}
92-
func (a *Adapter[T]) Update(ctx context.Context, model interface{}) (int64, error) {
92+
func (a *Adapter[T]) Update(ctx context.Context, model T) (int64, error) {
9393
query, args := q.BuildToUpdateWithVersion(a.Table, model, a.versionIndex, a.BuildParam, a.BoolSupport, a.ToArray, a.Schema)
9494
tx := q.GetExec(ctx, a.DB, a.TxKey)
9595
res, err := tx.ExecContext(ctx, query, args...)
@@ -98,7 +98,7 @@ func (a *Adapter[T]) Update(ctx context.Context, model interface{}) (int64, erro
9898
}
9999
return res.RowsAffected()
100100
}
101-
func (a *Adapter[T]) Save(ctx context.Context, model interface{}) (int64, error) {
101+
func (a *Adapter[T]) Save(ctx context.Context, model T) (int64, error) {
102102
query, args, err := q.BuildToSaveWithSchema(a.Table, model, a.Driver, a.BuildParam, a.ToArray, a.Schema)
103103
if err != nil {
104104
return 0, err

adapter/generic_adapter.go

+55-7
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"encoding/json"
8+
"errors"
79
"fmt"
8-
q "github.com/core-go/sql"
910
"reflect"
11+
12+
q "github.com/core-go/sql"
1013
)
1114

1215
type GenericAdapter[T any, K any] struct {
1316
*Adapter[*T]
1417
Map map[string]int
1518
Fields string
19+
Keys []string
20+
IdMap bool
1621
}
1722

1823
func NewGenericAdapter[T any, K any](db *sql.DB, tableName string, opts ...func(int) string) (*GenericAdapter[T, K], error) {
@@ -29,17 +34,31 @@ func NewSqlGenericAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName
2934
if err != nil {
3035
return nil, err
3136
}
37+
3238
var t T
3339
modelType := reflect.TypeOf(t)
34-
if modelType.Kind() == reflect.Ptr {
35-
modelType = modelType.Elem()
40+
if modelType.Kind() != reflect.Struct {
41+
return nil, errors.New("T must be a struct")
3642
}
43+
44+
_, primaryKeys := q.FindPrimaryKeys(modelType)
45+
var k K
46+
kType := reflect.TypeOf(k)
47+
idMap := false
48+
if len(primaryKeys) > 1 {
49+
if kType.Kind() == reflect.Map {
50+
idMap = true
51+
} else if kType.Kind() != reflect.Struct {
52+
return nil, errors.New("For composite keys, K must be a struct or a map")
53+
}
54+
}
55+
3756
fieldsIndex, err := q.GetColumnIndexes(modelType)
3857
if err != nil {
3958
return nil, err
4059
}
4160
fields := q.BuildFieldsBySchema(adapter.Schema)
42-
return &GenericAdapter[T, K]{adapter, fieldsIndex, fields}, nil
61+
return &GenericAdapter[T, K]{adapter, fieldsIndex, fields, primaryKeys, idMap}, nil
4362
}
4463
func (a *GenericAdapter[T, K]) All(ctx context.Context) ([]T, error) {
4564
var objs []T
@@ -48,10 +67,31 @@ func (a *GenericAdapter[T, K]) All(ctx context.Context) ([]T, error) {
4867
err := q.Query(ctx, tx, a.Map, &objs, query)
4968
return objs, err
5069
}
70+
func toMap(obj interface{}) (map[string]interface{}, error) {
71+
b, err := json.Marshal(obj)
72+
if err != nil {
73+
return nil, err
74+
}
75+
im := make(map[string]interface{})
76+
er2 := json.Unmarshal(b, &im)
77+
return im, er2
78+
}
79+
func (a *GenericAdapter[T, K]) getId(k K) (interface{}, error) {
80+
if len(a.Keys) >= 2 && !a.IdMap {
81+
ri, err := toMap(k)
82+
return ri, err
83+
} else {
84+
return k, nil
85+
}
86+
}
5187
func (a *GenericAdapter[T, K]) Load(ctx context.Context, id K) (*T, error) {
88+
ip, er0 := a.getId(id)
89+
if er0 != nil {
90+
return nil, er0
91+
}
5292
var objs []T
5393
query := fmt.Sprintf("select %s from %s ", a.Fields, a.Table)
54-
query1, args := q.BuildFindByIdWithDB(a.DB, query, id, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
94+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
5595
tx := q.GetExec(ctx, a.DB, a.TxKey)
5696
err := q.Query(ctx, tx, a.Map, &objs, query1, args...)
5797
if err != nil {
@@ -63,8 +103,12 @@ func (a *GenericAdapter[T, K]) Load(ctx context.Context, id K) (*T, error) {
63103
return nil, nil
64104
}
65105
func (a *GenericAdapter[T, K]) Exist(ctx context.Context, id K) (bool, error) {
106+
ip, er0 := a.getId(id)
107+
if er0 != nil {
108+
return false, er0
109+
}
66110
query := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
67-
query1, args := q.BuildFindByIdWithDB(a.DB, query, id, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
111+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
68112
tx := q.GetExec(ctx, a.DB, a.TxKey)
69113
rows, err := tx.QueryContext(ctx, query1, args...)
70114
if err != nil {
@@ -77,8 +121,12 @@ func (a *GenericAdapter[T, K]) Exist(ctx context.Context, id K) (bool, error) {
77121
return false, nil
78122
}
79123
func (a *GenericAdapter[T, K]) Delete(ctx context.Context, id K) (int64, error) {
124+
ip, er0 := a.getId(id)
125+
if er0 != nil {
126+
return -1, er0
127+
}
80128
query := fmt.Sprintf("delete from %s ", a.Table)
81-
query1, args := q.BuildFindByIdWithDB(a.DB, query, id, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
129+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Schema.SKeys, a.BuildParam)
82130
tx := q.GetExec(ctx, a.DB, a.TxKey)
83131
res, err := tx.ExecContext(ctx, query1, args...)
84132
if err != nil {

adapter/search.go

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ func NewSearchAdapterWithArray[T any, K any, F any](db *sql.DB, table string, bu
3232
sql.Scanner
3333
}, versionField string, buildParam func(int) string, opts ...func(context.Context, interface{}) (interface{}, error)) (*SearchAdapter[T, K, F], error) {
3434
adapter, err := NewSqlGenericAdapterWithVersionAndArray[T, K](db, table, versionField, toArray, buildParam)
35+
if err != nil {
36+
return nil, err
37+
}
3538
var mp func(context.Context, interface{}) (interface{}, error)
3639
if len(opts) >= 1 {
3740
mp = opts[0]

batch/batch_inserter.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type BatchInserter struct {
2222
sql.Scanner
2323
}
2424
}
25+
2526
func NewBatchInserter(db *sql.DB, tableName string, modelType reflect.Type, options ...func(context.Context, interface{}) (interface{}, error)) *BatchInserter {
2627
var mp func(context.Context, interface{}) (interface{}, error)
2728
if len(options) > 0 && options[0] != nil {
@@ -65,7 +66,7 @@ func (w *BatchInserter) Write(ctx context.Context, models interface{}) ([]int, [
6566
if er0 != nil {
6667
s0 := reflect.ValueOf(models2)
6768
_, er0b := q.InterfaceSlice(models2)
68-
failIndices = q.ToArrayIndex(s0, failIndices)
69+
failIndices = ToArrayIndex(s0, failIndices)
6970
return successIndices, failIndices, er0b
7071
}
7172
} else {
@@ -76,11 +77,11 @@ func (w *BatchInserter) Write(ctx context.Context, models interface{}) ([]int, [
7677

7778
if er2 == nil {
7879
// Return full success
79-
successIndices = q.ToArrayIndex(s, successIndices)
80+
successIndices = ToArrayIndex(s, successIndices)
8081
return successIndices, failIndices, er2
8182
} else {
8283
// Return full fail
83-
failIndices = q.ToArrayIndex(s, failIndices)
84+
failIndices = ToArrayIndex(s, failIndices)
8485
}
8586
return successIndices, failIndices, er2
8687
}

batch/batch_updater.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type BatchUpdater struct {
2222
sql.Scanner
2323
}
2424
}
25+
2526
func NewBatchUpdater(db *sql.DB, tableName string, modelType reflect.Type, options ...func(context.Context, interface{}) (interface{}, error)) *BatchUpdater {
2627
var mp func(context.Context, interface{}) (interface{}, error)
2728
if len(options) > 0 && options[0] != nil {
@@ -74,7 +75,7 @@ func (w *BatchUpdater) Write(ctx context.Context, models interface{}) ([]int, []
7475
if er0 != nil {
7576
s0 := reflect.ValueOf(models2)
7677
_, er0b := q.InterfaceSlice(models2)
77-
failIndices = q.ToArrayIndex(s0, failIndices)
78+
failIndices = ToArrayIndex(s0, failIndices)
7879
return successIndices, failIndices, er0b
7980
}
8081
} else {
@@ -84,11 +85,11 @@ func (w *BatchUpdater) Write(ctx context.Context, models interface{}) ([]int, []
8485
s := reflect.ValueOf(models)
8586
if err == nil {
8687
// Return full success
87-
successIndices = q.ToArrayIndex(s, successIndices)
88+
successIndices = ToArrayIndex(s, successIndices)
8889
return successIndices, failIndices, err
8990
} else {
9091
// Return full fail
91-
failIndices = q.ToArrayIndex(s, failIndices)
92+
failIndices = ToArrayIndex(s, failIndices)
9293
}
9394
return successIndices, failIndices, err
9495
}

batch/batch_writer.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ import (
1010
)
1111

1212
type BatchWriter struct {
13-
db *sql.DB
14-
tableName string
15-
Map func(ctx context.Context, model interface{}) (interface{}, error)
16-
Schema *q.Schema
17-
ToArray func(interface{}) interface {
13+
db *sql.DB
14+
tableName string
15+
Map func(ctx context.Context, model interface{}) (interface{}, error)
16+
Schema *q.Schema
17+
ToArray func(interface{}) interface {
1818
driver.Valuer
1919
sql.Scanner
2020
}
2121
}
22+
2223
func NewBatchWriter(db *sql.DB, tableName string, modelType reflect.Type, options ...func(context.Context, interface{}) (interface{}, error)) *BatchWriter {
2324
var mp func(context.Context, interface{}) (interface{}, error)
2425
if len(options) > 0 && options[0] != nil {
@@ -48,7 +49,7 @@ func (w *BatchWriter) Write(ctx context.Context, models interface{}) ([]int, []i
4849
if er0 != nil {
4950
s0 := reflect.ValueOf(m)
5051
_, er0b := q.InterfaceSlice(m)
51-
failIndices = q.ToArrayIndex(s0, failIndices)
52+
failIndices = ToArrayIndex(s0, failIndices)
5253
return successIndices, failIndices, er0b
5354
}
5455
} else {
@@ -59,11 +60,11 @@ func (w *BatchWriter) Write(ctx context.Context, models interface{}) ([]int, []i
5960

6061
if er2 == nil {
6162
// Return full success
62-
successIndices = q.ToArrayIndex(s, successIndices)
63+
successIndices = ToArrayIndex(s, successIndices)
6364
return successIndices, failIndices, er2
6465
} else {
6566
// Return full fail
66-
failIndices = q.ToArrayIndex(s, failIndices)
67+
failIndices = ToArrayIndex(s, failIndices)
6768
}
6869
return successIndices, failIndices, er2
6970
}

database.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,21 @@ func ExecStmt(ctx context.Context, stmt *sql.Stmt, values ...interface{}) (int64
5454
return result.RowsAffected()
5555
}
5656

57-
func handleDuplicate(db *sql.DB, err error) (int64, error) {
57+
func HandleDuplicate(db *sql.DB, err error) (int64, error) {
5858
x := err.Error()
5959
driver := GetDriver(db)
6060
if driver == DriverPostgres && strings.Contains(x, "pq: duplicate key value violates unique constraint") {
61-
return 0, nil
61+
return 0, err
6262
} else if driver == DriverMysql && strings.Contains(x, "Error 1062: Duplicate entry") {
63-
return 0, nil //mysql Error 1062: Duplicate entry 'a-1' for key 'PRIMARY'
63+
return 0, err //mysql Error 1062: Duplicate entry 'a-1' for key 'PRIMARY'
6464
} else if driver == DriverOracle && strings.Contains(x, "ORA-00001: unique constraint") {
65-
return 0, nil //mysql Error 1062: Duplicate entry 'a-1' for key 'PRIMARY'
65+
return 0, err //mysql Error 1062: Duplicate entry 'a-1' for key 'PRIMARY'
6666
} else if driver == DriverMssql && strings.Contains(x, "Violation of PRIMARY KEY constraint") {
67-
return 0, nil //Violation of PRIMARY KEY constraint 'PK_aa'. Cannot insert duplicate key in object 'dbo.aa'. The duplicate key value is (b, 2).
67+
return 0, err //Violation of PRIMARY KEY constraint 'PK_aa'. Cannot insert duplicate key in object 'dbo.aa'. The duplicate key value is (b, 2).
6868
} else if driver == DriverSqlite3 && strings.Contains(x, "UNIQUE constraint failed") {
69-
return 0, nil
69+
return 0, err
7070
}
71-
return 0, err
71+
return -1, err
7272
}
7373
func Insert(ctx context.Context, db *sql.DB, table string, model interface{}, options ...*Schema) (int64, error) {
7474
var schema *Schema
@@ -103,7 +103,7 @@ func InsertWithVersion(ctx context.Context, db *sql.DB, table string, model inte
103103

104104
result, err := db.ExecContext(ctx, queryInsert, values...)
105105
if err != nil {
106-
return handleDuplicate(db, err)
106+
return HandleDuplicate(db, err)
107107
}
108108
return result.RowsAffected()
109109
}

loader.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func InitFields(modelType reflect.Type, db *sql.DB) (map[string]int, string, fun
6161
buildParam := GetBuild(db)
6262
return fieldsIndex, fields, buildParam, driver, nil
6363
}
64+
6465
type Loader struct {
6566
Database *sql.DB
6667
BuildParam func(i int) string
@@ -379,7 +380,7 @@ func BuildFindById(selectAll string, buildParam func(i int) string, id interface
379380
} else {
380381
conditions := make([]string, 0)
381382
if ids, ok := id.(map[string]interface{}); ok {
382-
j := 0
383+
j := 1
383384
for _, keyJson := range keys {
384385
columnName := mapJsonColumnKeys[keyJson]
385386
if idk, ok1 := ids[keyJson]; ok1 {

writer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ func (s *Writer) Insert(ctx context.Context, model interface{}) (int64, error) {
181181
if tx == nil {
182182
result, err := s.Database.ExecContext(ctx, queryInsert, values...)
183183
if err != nil {
184-
return handleDuplicate(s.Database, err)
184+
return HandleDuplicate(s.Database, err)
185185
}
186186
return result.RowsAffected()
187187
} else {

0 commit comments

Comments
 (0)