Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go/mysql: improve GTID encoding for OK packet #16361

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 10 additions & 44 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,15 +787,15 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro
// assuming CapabilityClientProtocol41
length += 4 // status_flags + warnings

hasSessionTrack := c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack
Copy link
Contributor

@mattlord mattlord Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, but using != 0 is more obvious and standard — at least within Vitess — although there's no practical change.

hasGtidData := hasSessionTrack && packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged

var gtidData []byte
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {

if hasSessionTrack {
length += lenEncStringSize(packetOk.info) // info
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
gtidData = getLenEncString([]byte(packetOk.sessionStateData))
gtidData = append([]byte{0x00}, gtidData...)
gtidData = getLenEncString(gtidData)
gtidData = append([]byte{0x03}, gtidData...)
gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...)
if hasGtidData {
gtidData = encGtidData(packetOk.sessionStateData)
length += len(gtidData)
}
} else {
Expand All @@ -809,50 +809,17 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro
data.writeLenEncInt(packetOk.lastInsertID)
data.writeUint16(packetOk.statusFlags)
data.writeUint16(packetOk.warnings)
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
if hasSessionTrack {
data.writeLenEncString(packetOk.info)
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
data.writeEOFString(string(gtidData))
if hasGtidData {
data.writeEOFBytes(gtidData)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I can further specialize this explicitly into a data.writeGtidData(packetOk.sessionStateData) and directly write it into the buffer with the extra allocation, then change above to a like, length += encGtidDataSize(packetOk.sessionStateData)

Then we avoid the intermediary gtidData []byte being allocated.

}
} else {
data.writeEOFString(packetOk.info)
}
return c.writeEphemeralPacket()
}

func getLenEncString(value []byte) []byte {
data := getLenEncInt(uint64(len(value)))
return append(data, value...)
}

func getLenEncInt(i uint64) []byte {
var data []byte
switch {
case i < 251:
data = append(data, byte(i))
case i < 1<<16:
data = append(data, 0xfc)
data = append(data, byte(i))
data = append(data, byte(i>>8))
case i < 1<<24:
data = append(data, 0xfd)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
default:
data = append(data, 0xfe)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
data = append(data, byte(i>>24))
data = append(data, byte(i>>32))
data = append(data, byte(i>>40))
data = append(data, byte(i>>48))
data = append(data, byte(i>>56))
}
return data
}

func (c *Conn) WriteErrorAndLog(format string, args ...interface{}) bool {
return c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, format, args...)
}
Expand Down Expand Up @@ -1290,7 +1257,6 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
c.PrepareData[c.StatementID] = prepare

fld, err := handler.ComPrepare(c, queries[0], bindVars)

if err != nil {
return c.writeErrorPacketFromErrorAndLog(err)
}
Expand Down
51 changes: 51 additions & 0 deletions go/mysql/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,53 @@ func readLenEncStringAsBytesCopy(data []byte, pos int) ([]byte, int, bool) {
return result, pos + s, true
}

// > encGtidData("xxx")
//
// [07 03 05 00 03 78 78 78]
// | | | | | |------|
// | | | | | ^-------- "xxx"
// | | | | ^------------ length of rest of bytes, 3
// | | | ^--------------- fixed 0x00
// | | ^------------------ length of rest of bytes, 5
// | ^--------------------- fixed 0x03 (SESSION_TRACK_GTIDS)
// ^------------------------ length of rest of bytes, 7
//
// This is ultimately lenencoded strings of length encoded strings, or:
// > lenenc(0x03 + lenenc(0x00 + lenenc(data)))
func encGtidData(data string) []byte {
const SessionTrackGtids = 0x03

// calculate total size up front to do 1 allocation
// encoded layout is:
// lenenc(0x03 + lenenc(0x00 + lenenc(data)))
dataSize := uint64(len(data))
dataLenEncSize := uint64(lenEncIntSize(dataSize))

wrapSize := uint64(dataSize + dataLenEncSize + 1)
wrapLenEncSize := uint64(lenEncIntSize(wrapSize))

totalSize := uint64(wrapSize + wrapLenEncSize + 1)
totalLenEncSize := uint64(lenEncIntSize(totalSize))

gtidData := make([]byte, int(totalSize+totalLenEncSize))

pos := 0
pos = writeLenEncInt(gtidData, pos, totalSize)

gtidData[pos] = SessionTrackGtids
pos++

pos = writeLenEncInt(gtidData, pos, wrapSize)

gtidData[pos] = 0x00
pos++

pos = writeLenEncInt(gtidData, pos, dataSize)
writeEOFString(gtidData, pos, data)

return gtidData
}

type coder struct {
data []byte
pos int
Expand Down Expand Up @@ -397,3 +444,7 @@ func (d *coder) writeLenEncString(value string) {
func (d *coder) writeEOFString(value string) {
d.pos += copy(d.data[d.pos:], value)
}

func (d *coder) writeEOFBytes(value []byte) {
d.pos += copy(d.data[d.pos:], value)
}
30 changes: 29 additions & 1 deletion go/mysql/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mysql

import (
"bytes"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -72,7 +73,6 @@ func TestEncLenInt(t *testing.T) {
// Check failed decoding.
_, _, ok = readLenEncInt(test.encoded[:len(test.encoded)-1], 0)
assert.False(t, ok, "readLenEncInt returned ok=true for shorter value %x", test.value)

}
}

Expand Down Expand Up @@ -355,6 +355,27 @@ func TestWriteZeroes(t *testing.T) {
})
}

func TestEncGtidData(t *testing.T) {
tests := []struct {
data string
header []byte
}{
{"", []byte{0x04, 0x03, 0x02, 0x00, 0x00}},
{"xxx", []byte{0x07, 0x03, 0x05, 0x00, 0x03}},
{strings.Repeat("x", 256), []byte{
/* 264 */ 0xfc, 0x08, 0x01,
/* constant */ 0x03,
/* 260 */ 0xfc, 0x04, 0x01,
/* constant */ 0x00,
/* 256 */ 0xfc, 0x00, 0x01,
}},
}
for _, test := range tests {
got := encGtidData(test.data)
assert.Equal(t, append(test.header, test.data...), got)
}
}

func BenchmarkEncWriteInt(b *testing.B) {
buf := make([]byte, 16)

Expand Down Expand Up @@ -451,3 +472,10 @@ func BenchmarkEncReadInt(b *testing.B) {
}
})
}

func BenchmarkEncGtidData(b *testing.B) {
b.ReportAllocs()
for range b.N {
_ = encGtidData("xxx")
}
}
Loading