forked from apache/cassandra-gocql-driver
-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathintegration_serialization_scylla_test.go
218 lines (191 loc) · 7.04 KB
/
integration_serialization_scylla_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
//go:build integration && scylla
// +build integration,scylla
package gocql
import (
"bytes"
"fmt"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"testing"
"unsafe"
"github.com/gocql/gocql/internal/tests/serialization/valcases"
)
// TestSerializationSimpleTypesCassandra checks every simple-type case returned by
// valcases.GetSimple in three ways: in-memory Marshal, in-memory Unmarshal, and an
// INSERT/SELECT round trip through a live session (one table per CQL type).
func TestSerializationSimpleTypesCassandra(t *testing.T) {
	const (
		pkColumn   = "test_id"
		testColumn = "test_col"
	)

	typeCases := valcases.GetSimple()

	session := createSession(t)
	defer session.Close()

	// Check conversion between Go values and serialized data, without the cluster.
	t.Run("Marshal", func(t *testing.T) {
		for _, tc := range typeCases {
			checkTypeMarshal(t, tc)
		}
	})
	t.Run("Unmarshal", func(t *testing.T) {
		for _, tc := range typeCases {
			checkTypeUnmarshal(t, tc)
		}
	})

	// Create one table per CQL type; tables[i] belongs to typeCases[i].
	tables := make([]string, len(typeCases))
	for i, tc := range typeCases {
		table := "test_" + tc.CQLName
		// The primary key is spelled through pkColumn so the DDL cannot drift
		// from the INSERT/SELECT statements built below.
		stmt := fmt.Sprintf(`CREATE TABLE %s (%s text, %s %s, PRIMARY KEY (%s))`, table, pkColumn, testColumn, tc.CQLName, pkColumn)
		if err := createTable(session, stmt); err != nil {
			t.Fatalf("failed to create table for cqltype (%s) with error '%v'", tc.CQLName, err)
		}
		tables[i] = table
	}

	// Check that values survive an INSERT followed by a SELECT.
	t.Run("InsertSelect", func(t *testing.T) {
		for i, tc := range typeCases {
			insertStmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES(?, ?)", tables[i], pkColumn, testColumn)
			selectStmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", testColumn, tables[i], pkColumn)
			checkTypeInsertSelect(t, session, insertStmt, selectStmt, tc)
		}
	})
}
// checkTypeMarshal runs a subtest named after tc.CQLName that marshals every Go
// value in tc against a protocol-v4 native type built from tc.CQLType.
//
// A case flagged ErrInsert must fail to marshal, and any other case must succeed.
// If the error expectation is met, the marshaled bytes are compared with
// valCase.Data. This comparison also runs for expected-error cases.
func checkTypeMarshal(t *testing.T, tc valcases.SimpleTypeCases) {
	t.Run(tc.CQLName, func(t *testing.T) {
		cqlType := NewNativeType(4, Type(tc.CQLType), "")
		for _, valCase := range tc.Cases {
			for _, langCase := range valCase.LangCases {
				receivedData, err := Marshal(cqlType, langCase.Value)
				switch {
				case !langCase.ErrInsert && err != nil:
					t.Errorf("failed to marshal case (%s)(%s) value (%T) with error '%v'", valCase.Name, langCase.LangType, langCase.Value, err)
				case langCase.ErrInsert && err == nil:
					// %[3]v re-uses argument 3 (langCase.Value) to print the value after its type.
					t.Errorf("expected an error on marshal case (%s)(%s) value (%T)(%[3]v), but have no error", valCase.Name, langCase.LangType, langCase.Value)
				case !bytes.Equal(valCase.Data, receivedData):
					t.Errorf("failed to equal case (%s)(%s) data: expected %d, got %d", valCase.Name, langCase.LangType, valCase.Data, receivedData)
				}
			}
		}
	})
}
// checkTypeUnmarshal runs a subtest named after tc.CQLName that unmarshals
// valCase.Data into a fresh pointer to each Go value's type. It uses a protocol-v4
// native type built from tc.CQLType.
//
// A case flagged ErrSelect must fail to unmarshal, and any other case must succeed.
// In both situations, the dereferenced result is then compared with langCase.Value
// using equalVals.
func checkTypeUnmarshal(t *testing.T, tc valcases.SimpleTypeCases) {
	t.Run(tc.CQLName, func(t *testing.T) {
		cqlType := NewNativeType(4, Type(tc.CQLType), "")
		for _, valCase := range tc.Cases {
			for _, langCase := range valCase.LangCases {
				received := newRef(langCase.Value)
				err := Unmarshal(cqlType, valCase.Data, received)
				if !langCase.ErrSelect && err != nil {
					t.Errorf("failed to unmarshal case (%s)(%s) value (%T) with error '%v'", valCase.Name, langCase.LangType, langCase.Value, err)
				}
				if langCase.ErrSelect && err == nil {
					// %[3]v re-uses argument 3 (langCase.Value) to print the value after its type.
					t.Errorf("expected an error on unmarshal case (%s)(%s) value (%T)(%[3]v), but have no error", valCase.Name, langCase.LangType, langCase.Value)
				}
				// The value is compared even when unmarshalling failed.
				received = deReference(received)
				if !equalVals(langCase.Value, received) {
					// %v rather than %d: values may be strings, decimals, etc., not only integers.
					t.Errorf("failed to equal case (%s)(%s) value: expected %v, got %v", valCase.Name, langCase.LangType, langCase.Value, received)
				}
			}
		}
	})
}
// checkTypeInsertSelect runs a subtest named after tc.CQLName that round-trips
// every Go value in tc through the cluster. For each value it:
//   - executes insertStmt with (valCase.Name, value), expecting an error iff
//     langCase.ErrInsert;
//   - executes selectStmt with valCase.Name, scanning into a fresh pointer to the
//     value's type and expecting an error iff langCase.ErrSelect. It then compares
//     the scanned value with langCase.Value;
//   - executes selectStmt again, scanning into DirectUnmarshal, and compares the
//     raw column bytes with valCase.Data.
//
// insertStmt must take two bind markers (key, value) and selectStmt one (key).
func checkTypeInsertSelect(t *testing.T, session *Session, insertStmt, selectStmt string, tc valcases.SimpleTypeCases) {
	t.Run(tc.CQLName, func(t *testing.T) {
		// cqlType is only used to label failure messages.
		cqlType := NewNativeType(4, Type(tc.CQLType), "")
		for _, valCase := range tc.Cases {
			valCaseName := valCase.Name
			for _, langCase := range valCase.LangCases {
				// Insert the Go value.
				var insertedValue interface{} = langCase.Value
				err := session.Query(insertStmt, valCaseName, insertedValue).Exec()
				if !langCase.ErrInsert && err != nil {
					t.Errorf("failed to insert case (%s) value (%T)(%[2]v) with error '%v'", valCaseName, insertedValue, err)
				} else if langCase.ErrInsert && err == nil {
					t.Errorf("expected an error on insert case (%s) value (%T)(%[2]v), but have no error", valCaseName, insertedValue)
				}

				// Select the column back as a value of the same Go type.
				selectedValue := newRef(langCase.Value)
				err = session.Query(selectStmt, valCaseName).Scan(selectedValue)
				if !langCase.ErrSelect && err != nil {
					t.Errorf("failed to select case (%s) value (%T) with error '%v'", valCaseName, selectedValue, err)
				} else if langCase.ErrSelect && err == nil {
					t.Errorf("expected an error on select case (%s) value (%T)(%[2]v), but have no error", valCaseName, selectedValue)
				}
				selectedValue = deReference(selectedValue)
				if !equalVals(langCase.Value, selectedValue) {
					// %v rather than %d: values may be strings, decimals, etc., not only integers.
					t.Errorf("failed to equal case (%s) value: expected: %v, got: %v", valCaseName, langCase.Value, selectedValue)
				}

				// Select the column back as its raw serialized bytes.
				selectedValue = &DirectUnmarshal{}
				err = session.Query(selectStmt, valCaseName).Scan(selectedValue)
				if err != nil {
					t.Errorf("failed to select case (%s) value (%T) for cqltype (%s) with error '%v'", valCaseName, selectedValue, cqlType, err)
				}
				selectedValue = *(*[]byte)(selectedValue.(*DirectUnmarshal))
				if !equalVals(valCase.Data, selectedValue) {
					t.Errorf("failed to equal case (%s) value for cqltype (%s): expected: %d, got: %d", valCaseName, cqlType, valCase.Data, selectedValue)
				}
			}
		}
	})
}
// newRef allocates a new zero value of in's dynamic type T and returns a non-nil
// *T pointing to it, for use as an Unmarshal/Scan destination.
// in must not be a nil interface: reflect.New panics when given a nil type.
func newRef(in interface{}) interface{} {
	return reflect.New(reflect.TypeOf(in)).Interface()
}
// deReference returns the value that in points to when in holds a pointer,
// and in's own value otherwise. It panics if in is a nil interface or a nil
// pointer, since there is then no value to return.
func deReference(in interface{}) interface{} {
	val := reflect.ValueOf(in)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	return val.Interface()
}
// equalVals reports whether in1 and in2 hold equal values of the same dynamic type.
//
// Comparison rules:
//   - values of different dynamic types are never equal;
//   - two nil pointers are equal, and a nil and a non-nil pointer are not;
//   - float32/float64 (and pointers to them) are compared bit-for-bit, so the sign
//     of zero and the exact NaN bit pattern matter;
//   - big.Int values are compared numerically;
//   - inf.Dec values need the same scale and an equal unscaled value;
//   - other fmt.Stringer values are compared by their String output;
//   - everything else uses reflect.DeepEqual.
func equalVals(in1, in2 interface{}) bool {
	// Requiring identical dynamic types (not just equal Kinds) means every in2
	// type assertion below is guaranteed to succeed instead of panicking.
	if reflect.TypeOf(in1) != reflect.TypeOf(in2) {
		return false
	}
	rin1 := reflect.ValueOf(in1)
	rin2 := reflect.ValueOf(in2)
	if rin1.Kind() == reflect.Ptr && (rin1.IsNil() || rin2.IsNil()) {
		return rin1.IsNil() && rin2.IsNil()
	}
	switch vin1 := in1.(type) {
	case float32:
		vin2 := in2.(float32)
		return *(*[4]byte)(unsafe.Pointer(&vin1)) == *(*[4]byte)(unsafe.Pointer(&vin2))
	case *float32:
		vin2 := in2.(*float32)
		return *(*[4]byte)(unsafe.Pointer(vin1)) == *(*[4]byte)(unsafe.Pointer(vin2))
	case float64:
		vin2 := in2.(float64)
		return *(*[8]byte)(unsafe.Pointer(&vin1)) == *(*[8]byte)(unsafe.Pointer(&vin2))
	case *float64:
		vin2 := in2.(*float64)
		return *(*[8]byte)(unsafe.Pointer(vin1)) == *(*[8]byte)(unsafe.Pointer(vin2))
	case big.Int:
		vin2 := in2.(big.Int)
		return vin1.Cmp(&vin2) == 0
	case *big.Int:
		vin2 := in2.(*big.Int)
		return vin1.Cmp(vin2) == 0
	case inf.Dec:
		vin2 := in2.(inf.Dec)
		if vin1.Scale() != vin2.Scale() {
			return false
		}
		return vin1.UnscaledBig().Cmp(vin2.UnscaledBig()) == 0
	case *inf.Dec:
		vin2 := in2.(*inf.Dec)
		if vin1.Scale() != vin2.Scale() {
			return false
		}
		return vin1.UnscaledBig().Cmp(vin2.UnscaledBig()) == 0
	case fmt.Stringer:
		vin2 := in2.(fmt.Stringer)
		return vin1.String() == vin2.String()
	default:
		return reflect.DeepEqual(in1, in2)
	}
}