From 14d4115260d3b6f4bcfe7aea46a31a5f667be4bb Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 18 Feb 2025 11:41:06 +0100 Subject: [PATCH] refactor: clean up Signed-off-by: Andres Taylor --- go/mysql/conn.go | 44 ++++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 0ade311b835..4dc27a5902d 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -1219,6 +1219,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) { StatementID: c.StatementID, PrepareStmt: query, } + c.PrepareData[c.StatementID] = prepare statement, err := handler.Env().Parser().ParseStrictDDL(query) if err != nil { @@ -1228,16 +1229,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) { } } - paramsCount := uint16(0) - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - switch node := node.(type) { - case *sqlparser.Argument: - if strings.HasPrefix(node.Name, "v") { - paramsCount++ - } - } - return true, nil - }, statement) + paramsCount := countArguments(statement) if paramsCount > 0 { prepare.ParamsCount = paramsCount @@ -1245,26 +1237,42 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) { prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) } - bindVars := make(map[string]*querypb.BindVariable, paramsCount) - for i := range paramsCount { - parameterID := fmt.Sprintf("v%d", i+1) - bindVars[parameterID] = &querypb.BindVariable{} - } - - c.PrepareData[c.StatementID] = prepare + bindVars := prepareBindVars(paramsCount) fld, err := handler.ComPrepare(c, query, bindVars) if err != nil { return c.writeErrorPacketFromErrorAndLog(err) } - if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { + if err := c.writePrepare(fld, prepare); err != nil { log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) return false } return true } +func prepareBindVars(paramsCount uint16) map[string]*querypb.BindVariable { + bindVars := make(map[string]*querypb.BindVariable, paramsCount) + for i := range paramsCount { + parameterID := fmt.Sprintf("v%d", i+1) + bindVars[parameterID] = &querypb.BindVariable{} + } + return bindVars +} + +func countArguments(statement sqlparser.Statement) (paramsCount uint16) { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { + switch node := node.(type) { + case *sqlparser.Argument: + if strings.HasPrefix(node.Name, "v") { + paramsCount++ + } + } + return true, nil + }, statement) + return +} + func (c *Conn) handleComSetOption(data []byte) bool { operation, ok := c.parseComSetOption(data) c.recycleReadPacket()