From d1717923e08a73f163da814a95fd50be089bfe8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 21 Sep 2022 09:07:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=BC=93=E5=AD=98,=20?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20float,=20blob=20=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + sqlite.go | 204 ++++++++++++++++++++++++------------------------- sqlite_test.go | 86 +++++++++++++++++++++ 3 files changed, 189 insertions(+), 103 deletions(-) create mode 100644 sqlite_test.go diff --git a/.gitignore b/.gitignore index 66fd13c..5039ef7 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +*.db diff --git a/sqlite.go b/sqlite.go index 2ed9d0b..95066de 100644 --- a/sqlite.go +++ b/sqlite.go @@ -14,6 +14,11 @@ import ( "github.com/FloatTech/ttl" ) +var ( + ErrNilDB = errors.New("sqlite: db is not initialized") + ErrNullResult = errors.New("sqlite: null result") +) + // Sqlite 数据库对象 type Sqlite struct { DB *sql.DB @@ -52,19 +57,46 @@ func (db *Sqlite) Close() (err error) { } func wraptable(table string) string { - if unicode.IsDigit([]rune(table)[0]) { + first := []rune(table)[0] + if first < unicode.MaxLatin1 && unicode.IsDigit(first) { return "[" + table + "]" } else { return "'" + table + "'" } } +func (db *Sqlite) compile(q string) (*sql.Stmt, error) { + stmt := db.stmtcache.Get(q) + if stmt == nil { + var err error + stmt, err = db.DB.Prepare(q) + if err != nil { + return nil, err + } + db.stmtcache.Set(q, stmt) + } + return stmt, nil +} + +func (db *Sqlite) mustcompile(q string) *sql.Stmt { + stmt := db.stmtcache.Get(q) + if stmt == nil { + var err error + stmt, err = db.DB.Prepare(q) + if err != nil { + panic(err) + } + db.stmtcache.Set(q, stmt) + } + return stmt +} + // Create 生成数据库 // 默认结构体的第一个元素为主键 // 返回错误 func (db *Sqlite) Create(table string, objptr interface{}) (err error) { if db.DB == nil { - err = errors.New("db is nil") + err = ErrNilDB return } var ( @@ -89,7 +121,11 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) { } } } - _, err = db.DB.Exec(strings.Join(cmd, " ") + ";") + stmt, err := db.compile(strings.Join(cmd, " ") + ";") + if err != nil { + return err + } + _, err = stmt.Exec() return } @@ -99,10 +135,14 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) { // 返回错误 func (db *Sqlite) Insert(table string, objptr interface{}) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } table = wraptable(table) - rows, err := db.DB.Query("SELECT * FROM " + table + " limit 1;") + stmt, err := db.compile("SELECT * FROM " + table + " limit 1;") + if err != nil { + return err + } + rows, err := stmt.Query() if err != nil { return err } @@ -148,15 +188,9 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error { } } } - q := strings.Join(cmd, " ") + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err = db.compile(strings.Join(cmd, " ") + ";") + if err != nil { + return err } _, err = stmt.Exec(vals...) return err @@ -168,10 +202,14 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error { // 返回错误 func (db *Sqlite) InsertUnique(table string, objptr interface{}) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } table = wraptable(table) - rows, err := db.DB.Query("SELECT * FROM '" + table + "' limit 1;") + stmt, err := db.compile("SELECT * FROM " + table + " limit 1;") + if err != nil { + return err + } + rows, err := stmt.Query() if err != nil { return err } @@ -217,17 +255,10 @@ func (db *Sqlite) InsertUnique(table string, objptr interface{}) error { } } } - q := strings.Join(cmd, " ") + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err = db.compile(strings.Join(cmd, " ") + ";") + if err != nil { + return err } - _, err = stmt.Exec(vals...) return err } @@ -238,17 +269,12 @@ func (db *Sqlite) InsertUnique(table string, objptr interface{}) error { // 返回错误 func (db *Sqlite) Find(table string, objptr interface{}, condition string) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } q := "SELECT * FROM " + wraptable(table) + " " + condition + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile(q) + if err != nil { + return err } rows, err := stmt.Query() if err != nil { @@ -260,7 +286,7 @@ func (db *Sqlite) Find(table string, objptr interface{}, condition string) error defer rows.Close() if !rows.Next() { - return errors.New("sql.Find: null result") + return ErrNullResult } err = rows.Scan(addrs(objptr)...) for rows.Next() { @@ -281,14 +307,9 @@ func (db *Sqlite) CanFind(table string, condition string) bool { return false } q := "SELECT * FROM " + wraptable(table) + " " + condition + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return false - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile(q) + if err != nil { + return false } rows, err := stmt.Query() if err != nil { @@ -312,17 +333,12 @@ func (db *Sqlite) CanFind(table string, condition string) bool { // 返回错误 func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f func() error) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } q := "SELECT * FROM " + wraptable(table) + " " + condition + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile(q) + if err != nil { + return err } rows, err := stmt.Query() if err != nil { @@ -334,7 +350,7 @@ func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f defer rows.Close() if !rows.Next() { - return errors.New("sql.FindFor: null result") + return ErrNullResult } err = rows.Scan(addrs(objptr)...) if err == nil { @@ -355,7 +371,7 @@ func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f // Pick 从 table 随机一行 func (db *Sqlite) Pick(table string, objptr interface{}) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } return db.Find(table, objptr, "ORDER BY RANDOM() limit 1") } @@ -364,19 +380,9 @@ func (db *Sqlite) Pick(table string, objptr interface{}) error { // 返回所有表名+错误 func (db *Sqlite) ListTables() (s []string, err error) { if db.DB == nil { - return nil, errors.New("db is nil") + return nil, ErrNilDB } - q := "SELECT name FROM sqlite_master where type='table' order by name;" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return nil, err - } - db.stmtcache.Set(q, stmt) - } - rows, err := stmt.Query() + rows, err := db.mustcompile("SELECT name FROM sqlite_master where type='table' order by name;").Query() if err != nil { return } @@ -403,38 +409,28 @@ func (db *Sqlite) ListTables() (s []string, err error) { // 返回错误 func (db *Sqlite) Del(table string, condition string) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } q := "DELETE FROM " + wraptable(table) + " " + condition + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile(q) + if err != nil { + return err } - _, err := stmt.Exec() + _, err = stmt.Exec() return err } // Drop 删除数据库表 func (db *Sqlite) Drop(table string) error { if db.DB == nil { - return errors.New("db is nil") + return ErrNilDB } q := "DROP TABLE " + wraptable(table) + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return err - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile(q) + if err != nil { + return err } - _, err := stmt.Exec() + _, err = stmt.Exec() return err } @@ -442,24 +438,18 @@ func (db *Sqlite) Drop(table string) error { // 返回行数以及错误 func (db *Sqlite) Count(table string) (num int, err error) { if db.DB == nil { - return 0, errors.New("db is nil") + return 0, ErrNilDB } - q := "SELECT COUNT(1) FROM " + wraptable(table) + ";" - stmt := db.stmtcache.Get(q) - if stmt == nil { - var err error - stmt, err = db.DB.Prepare(q) - if err != nil { - return 0, err - } - db.stmtcache.Set(q, stmt) + stmt, err := db.compile("SELECT COUNT(1) FROM " + wraptable(table) + ";") + if err != nil { + return 0, err } rows, err := stmt.Query() if err != nil { - return num, err + return 0, err } if rows.Err() != nil { - return num, rows.Err() + return 0, rows.Err() } if rows.Next() { err = rows.Scan(&num) @@ -507,7 +497,7 @@ func kinds(objptr interface{}) (kinds []string) { kinds[i] = "SMALLINT" case "uint16": kinds[i] = "UNSIGNED SMALLINT" - case "int32": + case "int32", "rune": kinds[i] = "INT" case "uint32": kinds[i] = "UNSIGNED INT" @@ -515,20 +505,28 @@ func kinds(objptr interface{}) (kinds []string) { kinds[i] = "BIGINT" case "uint64": kinds[i] = "UNSIGNED BIGINT" - default: + case "float32": + kinds[i] = "FLOAT" + case "float64": + kinds[i] = "DOUBLE" + case "string", "[]string": kinds[i] = "TEXT" + default: + kinds[i] = "BLOB" } } return } +var typstrarr = reflect.SliceOf(reflect.TypeOf("")) + // values 反射 返回结构体对象的 values 数组 func values(objptr interface{}) (values []interface{}) { elem := reflect.ValueOf(objptr).Elem() flen := elem.Type().NumField() values = make([]interface{}, flen) for i := 0; i < flen; i++ { - if elem.Field(i).Type() == reflect.SliceOf(reflect.TypeOf("")) { // []string + if elem.Field(i).Type() == typstrarr { // []string values[i] = elem.Field(i).Index(0).Interface() // string continue } @@ -543,7 +541,7 @@ func addrs(objptr interface{}) (addrs []interface{}) { flen := elem.Type().NumField() addrs = make([]interface{}, flen) for i := 0; i < flen; i++ { - if elem.Field(i).Type() == reflect.SliceOf(reflect.TypeOf("")) { // []string + if elem.Field(i).Type() == typstrarr { // []string s := reflect.ValueOf(make([]string, 1)) elem.Field(i).Set(s) addrs[i] = s.Index(0).Addr().Interface() // string diff --git a/sqlite_test.go b/sqlite_test.go new file mode 100644 index 0000000..150c7d9 --- /dev/null +++ b/sqlite_test.go @@ -0,0 +1,86 @@ +package sql + +import ( + "bytes" + "testing" + "time" +) + +func TestPackUnpack(t *testing.T) { + type teststruct struct { + A bool + B int8 + C uint8 + D uint16 + E int32 + F uint32 + G int64 + H uint64 + I float32 + J float64 + K []byte + L string + M []string + } + db := Sqlite{DBPath: "test.db"} + err := db.Open(time.Hour) + if err != nil { + t.Fatal(err) + } + err = db.Create("test", &teststruct{}) + if err != nil { + t.Fatal(err) + } + inst := teststruct{true, 2, 3, 4, 5, 6, 7, 8, 9.0, 10.0, []byte{1, 2, 3}, "123", []string{"123", "456"}} + err = db.Insert("test", &inst) + if err != nil { + t.Fatal(err) + } + tmp := teststruct{} + err = db.Find("test", &tmp, "") + if err != nil { + t.Fatal(err) + } + if tmp.A != inst.A { + t.Fail() + } + if tmp.B != inst.B { + t.Fail() + } + if tmp.C != inst.C { + t.Fail() + } + if tmp.D != inst.D { + t.Fail() + } + if tmp.E != inst.E { + t.Fail() + } + if tmp.F != inst.F { + t.Fail() + } + if tmp.F != inst.F { + t.Fail() + } + if tmp.G != inst.G { + t.Fail() + } + if tmp.H != inst.H { + t.Fail() + } + if tmp.I != inst.I { + t.Fail() + } + if tmp.J != inst.J { + t.Fail() + } + if !bytes.Equal(tmp.K, inst.K) { + t.Fail() + } + if tmp.L != inst.L { + t.Fail() + } + if tmp.M[0] != inst.M[0] { + t.Fail() + } +}