Skip to content

Commit

Permalink
feat: add api QuerySet
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Feb 14, 2025
1 parent b2c54e5 commit 1f2c947
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
44 changes: 22 additions & 22 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (db *Sqlite) mustcompile(q string) *sql.Stmt {
// Create 生成数据库.
// 默认结构体的第一个元素为主键.
// 返回错误.
func (db *Sqlite) Create(table string, objptr interface{}, additional ...string) (err error) {
func (db *Sqlite) Create(table string, objptr any, additional ...string) (err error) {
if db.db == nil {
err = ErrNilDB
return
Expand Down Expand Up @@ -153,7 +153,7 @@ func (db *Sqlite) Create(table string, objptr interface{}, additional ...string)
// 如果 PK 存在会覆盖.
// 默认结构体的第一个元素为主键.
// 返回错误.
func (db *Sqlite) Insert(table string, objptr interface{}) error {
func (db *Sqlite) Insert(table string, objptr any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error {
// 如果 PK 存在会报错.
// 默认结构体的第一个元素为主键.
// 返回错误.
func (db *Sqlite) InsertUnique(table string, objptr interface{}) error {
func (db *Sqlite) InsertUnique(table string, objptr any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func (db *Sqlite) InsertUnique(table string, objptr interface{}) error {
// condition 可为"WHERE id = 0".
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) Find(table string, objptr interface{}, condition string, questions ...interface{}) error {
func (db *Sqlite) Find(table string, objptr any, condition string, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -322,7 +322,7 @@ func (db *Sqlite) Find(table string, objptr interface{}, condition string, quest
// condition 可为"WHERE id = 0".
// 默认字段与结构体元素顺序一致.
// 返回错误.
func Find[T any](db *Sqlite, table string, condition string, questions ...interface{}) (obj T, err error) {
func Find[T any](db *Sqlite, table string, condition string, questions ...any) (obj T, err error) {
if db.db == nil {
err = ErrNilDB
return
Expand Down Expand Up @@ -360,7 +360,7 @@ func Find[T any](db *Sqlite, table string, condition string, questions ...interf
// q 为一整条查询语句, 慎用.
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) Query(q string, objptr interface{}, args ...interface{}) error {
func (db *Sqlite) Query(q string, objptr any, args ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -394,7 +394,7 @@ func (db *Sqlite) Query(q string, objptr interface{}, args ...interface{}) error
// q 为一整条查询语句, 慎用.
// 默认字段与结构体元素顺序一致.
// 返回错误.
func Query[T any](db *Sqlite, q string, args ...interface{}) (obj T, err error) {
func Query[T any](db *Sqlite, q string, args ...any) (obj T, err error) {
if db.db == nil {
err = ErrNilDB
return
Expand Down Expand Up @@ -431,7 +431,7 @@ func Query[T any](db *Sqlite, q string, args ...interface{}) (obj T, err error)
// condition 可为"WHERE id = 0".
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) CanFind(table string, condition string, questions ...interface{}) bool {
func (db *Sqlite) CanFind(table string, condition string, questions ...any) bool {
if db.db == nil {
return false
}
Expand Down Expand Up @@ -460,7 +460,7 @@ func (db *Sqlite) CanFind(table string, condition string, questions ...interface
// q 为一整条查询语句, 慎用.
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) CanQuery(q string, questions ...interface{}) bool {
func (db *Sqlite) CanQuery(q string, questions ...any) bool {
if db.db == nil {
return false
}
Expand Down Expand Up @@ -488,7 +488,7 @@ func (db *Sqlite) CanQuery(q string, questions ...interface{}) bool {
// condition 可为"WHERE id = 0".
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f func() error, questions ...interface{}) error {
func (db *Sqlite) FindFor(table string, objptr any, condition string, f func() error, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -529,7 +529,7 @@ func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f
// condition 可为"WHERE id = 0".
// 默认字段与结构体元素顺序一致.
// 返回错误.
func FindAll[T any](db *Sqlite, table string, condition string, questions ...interface{}) ([]*T, error) {
func FindAll[T any](db *Sqlite, table string, condition string, questions ...any) ([]*T, error) {
if db.db == nil {
return nil, ErrNilDB
}
Expand Down Expand Up @@ -574,7 +574,7 @@ func FindAll[T any](db *Sqlite, table string, condition string, questions ...int
// q 为一整条查询语句, 慎用.
// 默认字段与结构体元素顺序一致.
// 返回错误.
func (db *Sqlite) QueryFor(q string, objptr interface{}, f func() error, questions ...interface{}) error {
func (db *Sqlite) QueryFor(q string, objptr any, f func() error, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -614,7 +614,7 @@ func (db *Sqlite) QueryFor(q string, objptr interface{}, f func() error, questio
// q 为一整条查询语句, 慎用.
// 默认字段与结构体元素顺序一致.
// 返回错误.
func QueryAll[T any](db *Sqlite, q string, questions ...interface{}) ([]*T, error) {
func QueryAll[T any](db *Sqlite, q string, questions ...any) ([]*T, error) {
if db.db == nil {
return nil, ErrNilDB
}
Expand Down Expand Up @@ -655,15 +655,15 @@ func QueryAll[T any](db *Sqlite, q string, questions ...interface{}) ([]*T, erro
}

// Pick 从 table 随机一行
func (db *Sqlite) Pick(table string, objptr interface{}, questions ...interface{}) error {
func (db *Sqlite) Pick(table string, objptr any, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
return db.Find(table, objptr, "ORDER BY RANDOM() limit 1", questions...)
}

// PickFor 从 table 随机多行
func (db *Sqlite) PickFor(table string, n uint, objptr interface{}, f func() error, questions ...interface{}) error {
func (db *Sqlite) PickFor(table string, n uint, objptr any, f func() error, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -701,7 +701,7 @@ func (db *Sqlite) ListTables() (s []string, err error) {
// Del 删除数据库表项.
// condition 可为"WHERE id = 0".
// 返回错误.
func (db *Sqlite) Del(table string, condition string, questions ...interface{}) error {
func (db *Sqlite) Del(table string, condition string, questions ...any) error {
if db.db == nil {
return ErrNilDB
}
Expand Down Expand Up @@ -753,7 +753,7 @@ func (db *Sqlite) Count(table string) (num int, err error) {
}

// tags 反射 返回结构体对象的 tag 数组
func tags(objptr interface{}) (tags []string) {
func tags(objptr any) (tags []string) {
elem := reflect.ValueOf(objptr).Elem()
flen := elem.Type().NumField()
tags = make([]string, flen)
Expand All @@ -771,7 +771,7 @@ func tags(objptr interface{}) (tags []string) {
}

// kinds 反射 返回结构体对象的 kinds 数组
func kinds(objptr interface{}) (kinds []string) {
func kinds(objptr any) (kinds []string) {
elem := reflect.ValueOf(objptr).Elem()
// 判断第一个元素是否为匿名字段
if elem.Type().Field(0).Anonymous {
Expand Down Expand Up @@ -863,10 +863,10 @@ func kinds(objptr interface{}) (kinds []string) {
var typstrarr = reflect.SliceOf(reflect.TypeOf(""))

// values 反射 返回结构体对象的 values 数组
func values(objptr interface{}) (values []interface{}) {
func values(objptr any) (values []any) {
elem := reflect.ValueOf(objptr).Elem()
flen := elem.Type().NumField()
values = make([]interface{}, flen)
values = make([]any, flen)
for i := 0; i < flen; i++ {
if elem.Field(i).Type() == typstrarr { // []string
values[i] = elem.Field(i).Index(0).Interface() // string
Expand All @@ -878,10 +878,10 @@ func values(objptr interface{}) (values []interface{}) {
}

// addrs 反射 返回结构体对象的 addrs 数组
func addrs(objptr interface{}) (addrs []interface{}) {
func addrs(objptr any) (addrs []any) {
elem := reflect.ValueOf(objptr).Elem()
flen := elem.Type().NumField()
addrs = make([]interface{}, flen)
addrs = make([]any, flen)
for i := 0; i < flen; i++ {
if elem.Field(i).Type() == typstrarr { // []string
s := reflect.ValueOf(make([]string, 1))
Expand Down
3 changes: 2 additions & 1 deletion sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ func TestPackUnpack(t *testing.T) {
t.Fatal(err)
}
tmp = teststruct{O: &o}
err = db.Find("test", &tmp, "WHERE A = ?", 3)
q, s := QuerySet("WHERE A", "IN", []int{3})
err = db.Find("test", &tmp, q, s...)
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 24 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package sql

import "strings"

// QuerySet returns "q op (?,?,...,?)", []T
func QuerySet[T any](q, op string, s []T) (string, []any) {
sz := len(s)
if sz == 0 {
panic("len(s) must > 0")
}
sb := strings.Builder{}
qs := make([]any, sz)
sb.WriteString(q)
sb.WriteByte(' ')
sb.WriteString(op)
sb.WriteString(" (?")
qs[0] = s[0]
for i := 1; i < sz; i++ {
sb.WriteString(",?")
qs[i] = s[i]
}
sb.WriteByte(')')
return sb.String(), qs
}

0 comments on commit 1f2c947

Please sign in to comment.