Skip to content

Commit

Permalink
Merge pull request volatiletech#1083 from fdegiuli/null
Browse files Browse the repository at this point in the history
Generate IN/NIN whereHelpers for nullable types
  • Loading branch information
stephenafamo authored Aug 19, 2022
2 parents 63efd17 + ee0ed02 commit 78272e8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
8 changes: 5 additions & 3 deletions boilingcore/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,11 @@ var templateFunctions = template.FuncMap{
"whereClause": strmangle.WhereClause,

// Alias and text helping
"aliasCols": func(ta TableAlias) func(string) string { return ta.Column },
"usesPrimitives": usesPrimitives,
"isPrimitive": isPrimitive,
"aliasCols": func(ta TableAlias) func(string) string { return ta.Column },
"usesPrimitives": usesPrimitives,
"isPrimitive": isPrimitive,
"isNullPrimitive": isNullPrimitive,
"convertNullToPrimitive": convertNullToPrimitive,
"splitLines": func(a string) []string {
if a == "" {
return nil
Expand Down
25 changes: 25 additions & 0 deletions boilingcore/text_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,28 @@ func isPrimitive(typ string) bool {

return false
}

func isNullPrimitive(typ string) bool {
switch typ {
// Numeric
case "null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64":
return true
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64":
return true
case "null.Float32", "null.Float64":
return true
case "null.Byte", "null.String":
return true
}

return false
}

// convertNullToPrimitive takes a type name and returns the underlying primitive type name X if it is a `null.X`,
// otherwise it returns the input value unchanged
func convertNullToPrimitive(typ string) string {
if isNullPrimitive(typ) {
return strings.ToLower(strings.Split(typ, ".")[1])
}
return typ
}
6 changes: 3 additions & 3 deletions templates/main/00_struct.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ func (w {{$name}}) LT(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field,
func (w {{$name}}) LTE(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.LTE, x) }
func (w {{$name}}) GT(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GT, x) }
func (w {{$name}}) GTE(x {{.Type}}) qm.QueryMod { return qmhelper.Where(w.field, qmhelper.GTE, x) }
{{if or (isPrimitive .Type) (isEnumDBType .DBType) -}}
func (w {{$name}}) IN(slice []{{.Type}}) qm.QueryMod {
{{if or (isPrimitive .Type) (isNullPrimitive .Type) (isEnumDBType .DBType) -}}
func (w {{$name}}) IN(slice []{{convertNullToPrimitive .Type}}) qm.QueryMod {
values := make([]interface{}, 0, len(slice))
for _, value := range slice {
values = append(values, value)
}
return qm.WhereIn(fmt.Sprintf("%s IN ?", w.field), values...)
}
func (w {{$name}}) NIN(slice []{{.Type}}) qm.QueryMod {
func (w {{$name}}) NIN(slice []{{convertNullToPrimitive .Type}}) qm.QueryMod {
values := make([]interface{}, 0, len(slice))
for _, value := range slice {
values = append(values, value)
Expand Down

0 comments on commit 78272e8

Please sign in to comment.