@@ -4,15 +4,20 @@ import (
4
4
"context"
5
5
"database/sql"
6
6
"database/sql/driver"
7
+ "encoding/json"
8
+ "errors"
7
9
"fmt"
8
- q "github.com/core-go/sql"
9
10
"reflect"
11
+
12
+ q "github.com/core-go/sql"
10
13
)
11
14
12
15
type GenericAdapter [T any , K any ] struct {
13
16
* Adapter [* T ]
14
17
Map map [string ]int
15
18
Fields string
19
+ Keys []string
20
+ IdMap bool
16
21
}
17
22
18
23
func NewGenericAdapter [T any , K any ](db * sql.DB , tableName string , opts ... func (int ) string ) (* GenericAdapter [T , K ], error ) {
@@ -29,17 +34,31 @@ func NewSqlGenericAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName
29
34
if err != nil {
30
35
return nil , err
31
36
}
37
+
32
38
var t T
33
39
modelType := reflect .TypeOf (t )
34
- if modelType .Kind () == reflect .Ptr {
35
- modelType = modelType . Elem ( )
40
+ if modelType .Kind () != reflect .Struct {
41
+ return nil , errors . New ( "T must be a struct" )
36
42
}
43
+
44
+ _ , primaryKeys := q .FindPrimaryKeys (modelType )
45
+ var k K
46
+ kType := reflect .TypeOf (k )
47
+ idMap := false
48
+ if len (primaryKeys ) > 1 {
49
+ if kType .Kind () == reflect .Map {
50
+ idMap = true
51
+ } else if kType .Kind () != reflect .Struct {
52
+ return nil , errors .New ("For composite keys, K must be a struct or a map" )
53
+ }
54
+ }
55
+
37
56
fieldsIndex , err := q .GetColumnIndexes (modelType )
38
57
if err != nil {
39
58
return nil , err
40
59
}
41
60
fields := q .BuildFieldsBySchema (adapter .Schema )
42
- return & GenericAdapter [T , K ]{adapter , fieldsIndex , fields }, nil
61
+ return & GenericAdapter [T , K ]{adapter , fieldsIndex , fields , primaryKeys , idMap }, nil
43
62
}
44
63
func (a * GenericAdapter [T , K ]) All (ctx context.Context ) ([]T , error ) {
45
64
var objs []T
@@ -48,10 +67,31 @@ func (a *GenericAdapter[T, K]) All(ctx context.Context) ([]T, error) {
48
67
err := q .Query (ctx , tx , a .Map , & objs , query )
49
68
return objs , err
50
69
}
70
+ func toMap (obj interface {}) (map [string ]interface {}, error ) {
71
+ b , err := json .Marshal (obj )
72
+ if err != nil {
73
+ return nil , err
74
+ }
75
+ im := make (map [string ]interface {})
76
+ er2 := json .Unmarshal (b , & im )
77
+ return im , er2
78
+ }
79
+ func (a * GenericAdapter [T , K ]) getId (k K ) (interface {}, error ) {
80
+ if len (a .Keys ) >= 2 && ! a .IdMap {
81
+ ri , err := toMap (k )
82
+ return ri , err
83
+ } else {
84
+ return k , nil
85
+ }
86
+ }
51
87
func (a * GenericAdapter [T , K ]) Load (ctx context.Context , id K ) (* T , error ) {
88
+ ip , er0 := a .getId (id )
89
+ if er0 != nil {
90
+ return nil , er0
91
+ }
52
92
var objs []T
53
93
query := fmt .Sprintf ("select %s from %s " , a .Fields , a .Table )
54
- query1 , args := q .BuildFindByIdWithDB (a .DB , query , id , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
94
+ query1 , args := q .BuildFindByIdWithDB (a .DB , query , ip , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
55
95
tx := q .GetExec (ctx , a .DB , a .TxKey )
56
96
err := q .Query (ctx , tx , a .Map , & objs , query1 , args ... )
57
97
if err != nil {
@@ -63,8 +103,12 @@ func (a *GenericAdapter[T, K]) Load(ctx context.Context, id K) (*T, error) {
63
103
return nil , nil
64
104
}
65
105
func (a * GenericAdapter [T , K ]) Exist (ctx context.Context , id K ) (bool , error ) {
106
+ ip , er0 := a .getId (id )
107
+ if er0 != nil {
108
+ return false , er0
109
+ }
66
110
query := fmt .Sprintf ("select %s from %s " , a .Schema .SColumns [0 ], a .Table )
67
- query1 , args := q .BuildFindByIdWithDB (a .DB , query , id , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
111
+ query1 , args := q .BuildFindByIdWithDB (a .DB , query , ip , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
68
112
tx := q .GetExec (ctx , a .DB , a .TxKey )
69
113
rows , err := tx .QueryContext (ctx , query1 , args ... )
70
114
if err != nil {
@@ -77,8 +121,12 @@ func (a *GenericAdapter[T, K]) Exist(ctx context.Context, id K) (bool, error) {
77
121
return false , nil
78
122
}
79
123
func (a * GenericAdapter [T , K ]) Delete (ctx context.Context , id K ) (int64 , error ) {
124
+ ip , er0 := a .getId (id )
125
+ if er0 != nil {
126
+ return - 1 , er0
127
+ }
80
128
query := fmt .Sprintf ("delete from %s " , a .Table )
81
- query1 , args := q .BuildFindByIdWithDB (a .DB , query , id , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
129
+ query1 , args := q .BuildFindByIdWithDB (a .DB , query , ip , a .JsonColumnMap , a .Schema .SKeys , a .BuildParam )
82
130
tx := q .GetExec (ctx , a .DB , a .TxKey )
83
131
res , err := tx .ExecContext (ctx , query1 , args ... )
84
132
if err != nil {
0 commit comments