Skip to content

Commit

Permalink
refactor: clean up
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Feb 18, 2025
1 parent 550b054 commit 14d4115
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
StatementID: c.StatementID,
PrepareStmt: query,
}
c.PrepareData[c.StatementID] = prepare

statement, err := handler.Env().Parser().ParseStrictDDL(query)
if err != nil {
Expand All @@ -1228,43 +1229,50 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
}
}

paramsCount := uint16(0)
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
case *sqlparser.Argument:
if strings.HasPrefix(node.Name, "v") {
paramsCount++
}
}
return true, nil
}, statement)
paramsCount := countArguments(statement)

if paramsCount > 0 {
prepare.ParamsCount = paramsCount
prepare.ParamsType = make([]int32, paramsCount)
prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount)
}

bindVars := make(map[string]*querypb.BindVariable, paramsCount)
for i := range paramsCount {
parameterID := fmt.Sprintf("v%d", i+1)
bindVars[parameterID] = &querypb.BindVariable{}
}

c.PrepareData[c.StatementID] = prepare
bindVars := prepareBindVars(paramsCount)

fld, err := handler.ComPrepare(c, query, bindVars)
if err != nil {
return c.writeErrorPacketFromErrorAndLog(err)
}

if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil {
if err := c.writePrepare(fld, prepare); err != nil {
log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err)
return false
}
return true
}

func prepareBindVars(paramsCount uint16) map[string]*querypb.BindVariable {
bindVars := make(map[string]*querypb.BindVariable, paramsCount)
for i := range paramsCount {
parameterID := fmt.Sprintf("v%d", i+1)
bindVars[parameterID] = &querypb.BindVariable{}
}
return bindVars
}

func countArguments(statement sqlparser.Statement) (paramsCount uint16) {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
case *sqlparser.Argument:
if strings.HasPrefix(node.Name, "v") {
paramsCount++
}
}
return true, nil
}, statement)
return
}

func (c *Conn) handleComSetOption(data []byte) bool {
operation, ok := c.parseComSetOption(data)
c.recycleReadPacket()
Expand Down

0 comments on commit 14d4115

Please sign in to comment.