Skip to content

Commit

Permalink
优化缓存, 增加 float, blob 类型支持
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Sep 21, 2022
1 parent d150e34 commit d171792
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 103 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@

# Dependency directories (remove the comment below to include it)
# vendor/

*.db
204 changes: 101 additions & 103 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
"github.com/FloatTech/ttl"
)

var (
ErrNilDB = errors.New("sqlite: db is not initialized")
ErrNullResult = errors.New("sqlite: null result")
)

// Sqlite 数据库对象
type Sqlite struct {
DB *sql.DB
Expand Down Expand Up @@ -52,19 +57,46 @@ func (db *Sqlite) Close() (err error) {
}

func wraptable(table string) string {
if unicode.IsDigit([]rune(table)[0]) {
first := []rune(table)[0]
if first < unicode.MaxLatin1 && unicode.IsDigit(first) {
return "[" + table + "]"
} else {
return "'" + table + "'"
}
}

func (db *Sqlite) compile(q string) (*sql.Stmt, error) {
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return nil, err
}
db.stmtcache.Set(q, stmt)
}
return stmt, nil
}

func (db *Sqlite) mustcompile(q string) *sql.Stmt {
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
panic(err)
}
db.stmtcache.Set(q, stmt)
}
return stmt
}

// Create 生成数据库
// 默认结构体的第一个元素为主键
// 返回错误
func (db *Sqlite) Create(table string, objptr interface{}) (err error) {
if db.DB == nil {
err = errors.New("db is nil")
err = ErrNilDB
return
}
var (
Expand All @@ -89,7 +121,11 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) {
}
}
}
_, err = db.DB.Exec(strings.Join(cmd, " ") + ";")
stmt, err := db.compile(strings.Join(cmd, " ") + ";")
if err != nil {
return err
}
_, err = stmt.Exec()
return
}

Expand All @@ -99,10 +135,14 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) {
// 返回错误
func (db *Sqlite) Insert(table string, objptr interface{}) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
table = wraptable(table)
rows, err := db.DB.Query("SELECT * FROM " + table + " limit 1;")
stmt, err := db.compile("SELECT * FROM " + table + " limit 1;")
if err != nil {
return err
}
rows, err := stmt.Query()
if err != nil {
return err
}
Expand Down Expand Up @@ -148,15 +188,9 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error {
}
}
}
q := strings.Join(cmd, " ") + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err = db.compile(strings.Join(cmd, " ") + ";")
if err != nil {
return err
}
_, err = stmt.Exec(vals...)
return err
Expand All @@ -168,10 +202,14 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error {
// 返回错误
func (db *Sqlite) InsertUnique(table string, objptr interface{}) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
table = wraptable(table)
rows, err := db.DB.Query("SELECT * FROM '" + table + "' limit 1;")
stmt, err := db.compile("SELECT * FROM " + table + " limit 1;")
if err != nil {
return err
}
rows, err := stmt.Query()
if err != nil {
return err
}
Expand Down Expand Up @@ -217,17 +255,10 @@ func (db *Sqlite) InsertUnique(table string, objptr interface{}) error {
}
}
}
q := strings.Join(cmd, " ") + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err = db.compile(strings.Join(cmd, " ") + ";")
if err != nil {
return err
}

_, err = stmt.Exec(vals...)
return err
}
Expand All @@ -238,17 +269,12 @@ func (db *Sqlite) InsertUnique(table string, objptr interface{}) error {
// 返回错误
func (db *Sqlite) Find(table string, objptr interface{}, condition string) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
q := "SELECT * FROM " + wraptable(table) + " " + condition + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile(q)
if err != nil {
return err
}
rows, err := stmt.Query()
if err != nil {
Expand All @@ -260,7 +286,7 @@ func (db *Sqlite) Find(table string, objptr interface{}, condition string) error
defer rows.Close()

if !rows.Next() {
return errors.New("sql.Find: null result")
return ErrNullResult
}
err = rows.Scan(addrs(objptr)...)
for rows.Next() {
Expand All @@ -281,14 +307,9 @@ func (db *Sqlite) CanFind(table string, condition string) bool {
return false
}
q := "SELECT * FROM " + wraptable(table) + " " + condition + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return false
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile(q)
if err != nil {
return false
}
rows, err := stmt.Query()
if err != nil {
Expand All @@ -312,17 +333,12 @@ func (db *Sqlite) CanFind(table string, condition string) bool {
// 返回错误
func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f func() error) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
q := "SELECT * FROM " + wraptable(table) + " " + condition + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile(q)
if err != nil {
return err
}
rows, err := stmt.Query()
if err != nil {
Expand All @@ -334,7 +350,7 @@ func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f
defer rows.Close()

if !rows.Next() {
return errors.New("sql.FindFor: null result")
return ErrNullResult
}
err = rows.Scan(addrs(objptr)...)
if err == nil {
Expand All @@ -355,7 +371,7 @@ func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f
// Pick 从 table 随机一行
func (db *Sqlite) Pick(table string, objptr interface{}) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
return db.Find(table, objptr, "ORDER BY RANDOM() limit 1")
}
Expand All @@ -364,19 +380,9 @@ func (db *Sqlite) Pick(table string, objptr interface{}) error {
// 返回所有表名+错误
func (db *Sqlite) ListTables() (s []string, err error) {
if db.DB == nil {
return nil, errors.New("db is nil")
return nil, ErrNilDB
}
q := "SELECT name FROM sqlite_master where type='table' order by name;"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return nil, err
}
db.stmtcache.Set(q, stmt)
}
rows, err := stmt.Query()
rows, err := db.mustcompile("SELECT name FROM sqlite_master where type='table' order by name;").Query()
if err != nil {
return
}
Expand All @@ -403,63 +409,47 @@ func (db *Sqlite) ListTables() (s []string, err error) {
// 返回错误
func (db *Sqlite) Del(table string, condition string) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
q := "DELETE FROM " + wraptable(table) + " " + condition + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile(q)
if err != nil {
return err
}
_, err := stmt.Exec()
_, err = stmt.Exec()
return err
}

// Drop 删除数据库表
func (db *Sqlite) Drop(table string) error {
if db.DB == nil {
return errors.New("db is nil")
return ErrNilDB
}
q := "DROP TABLE " + wraptable(table) + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return err
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile(q)
if err != nil {
return err
}
_, err := stmt.Exec()
_, err = stmt.Exec()
return err
}

// Count 查询数据库行数
// 返回行数以及错误
func (db *Sqlite) Count(table string) (num int, err error) {
if db.DB == nil {
return 0, errors.New("db is nil")
return 0, ErrNilDB
}
q := "SELECT COUNT(1) FROM " + wraptable(table) + ";"
stmt := db.stmtcache.Get(q)
if stmt == nil {
var err error
stmt, err = db.DB.Prepare(q)
if err != nil {
return 0, err
}
db.stmtcache.Set(q, stmt)
stmt, err := db.compile("SELECT COUNT(1) FROM " + wraptable(table) + ";")
if err != nil {
return 0, err
}
rows, err := stmt.Query()
if err != nil {
return num, err
return 0, err
}
if rows.Err() != nil {
return num, rows.Err()
return 0, rows.Err()
}
if rows.Next() {
err = rows.Scan(&num)
Expand Down Expand Up @@ -507,28 +497,36 @@ func kinds(objptr interface{}) (kinds []string) {
kinds[i] = "SMALLINT"
case "uint16":
kinds[i] = "UNSIGNED SMALLINT"
case "int32":
case "int32", "rune":
kinds[i] = "INT"
case "uint32":
kinds[i] = "UNSIGNED INT"
case "int64":
kinds[i] = "BIGINT"
case "uint64":
kinds[i] = "UNSIGNED BIGINT"
default:
case "float32":
kinds[i] = "FLOAT"
case "float64":
kinds[i] = "DOUBLE"
case "string", "[]string":
kinds[i] = "TEXT"
default:
kinds[i] = "BLOB"
}
}
return
}

var typstrarr = reflect.SliceOf(reflect.TypeOf(""))

// values 反射 返回结构体对象的 values 数组
func values(objptr interface{}) (values []interface{}) {
elem := reflect.ValueOf(objptr).Elem()
flen := elem.Type().NumField()
values = make([]interface{}, flen)
for i := 0; i < flen; i++ {
if elem.Field(i).Type() == reflect.SliceOf(reflect.TypeOf("")) { // []string
if elem.Field(i).Type() == typstrarr { // []string
values[i] = elem.Field(i).Index(0).Interface() // string
continue
}
Expand All @@ -543,7 +541,7 @@ func addrs(objptr interface{}) (addrs []interface{}) {
flen := elem.Type().NumField()
addrs = make([]interface{}, flen)
for i := 0; i < flen; i++ {
if elem.Field(i).Type() == reflect.SliceOf(reflect.TypeOf("")) { // []string
if elem.Field(i).Type() == typstrarr { // []string
s := reflect.ValueOf(make([]string, 1))
elem.Field(i).Set(s)
addrs[i] = s.Index(0).Addr().Interface() // string
Expand Down
Loading

0 comments on commit d171792

Please sign in to comment.