From fb3ab471d9c270748b4f8a1bc8f6e0a75302e7eb Mon Sep 17 00:00:00 2001 From: Chris Wolf <89790028+Christopher-Wolf-ibm@users.noreply.github.com> Date: Mon, 26 Sep 2022 14:44:55 -0400 Subject: [PATCH] Fixed bug discussed in #52 (#53) Signed-off-by: Jeff Chauvin Signed-off-by: Jeff Chauvin --- fluent/client/ws_client.go | 20 ++++++++++++++--- fluent/client/ws_client_test.go | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/fluent/client/ws_client.go b/fluent/client/ws_client.go index e14c580..5c33891 100644 --- a/fluent/client/ws_client.go +++ b/fluent/client/ws_client.go @@ -25,6 +25,7 @@ SOFTWARE. package client import ( + "bytes" "crypto/tls" "errors" "net/http" @@ -265,10 +266,14 @@ func (c *WSClient) Reconnect() (err error) { // Send sends a single msgp.Encodable across the wire. func (c *WSClient) Send(e protocol.ChunkEncoder) error { + var ( + err error + rawMessageData bytes.Buffer + ) // Check for an async connection error and return it here. // In most cases, the client will not care about reading from // the connection, so checking for the error here is sufficient. - if err := c.getErr(); err != nil { + if err = c.getErr(); err != nil { return err // TODO: wrap this } @@ -278,8 +283,17 @@ func (c *WSClient) Send(e protocol.ChunkEncoder) error { return errors.New("no active session") } - // msgp.Encode makes use of object pool to decrease allocations - return msgp.Encode(session.Connection, e) + err = msgp.Encode(&rawMessageData, e) + if err != nil { + return err + } + + bytesData := rawMessageData.Bytes() + // Write function does not accurately return the number of bytes written + // so it would be ineffective to compare + _, err = c.session.Connection.Write(bytesData) + + return err } // SendRaw sends an array of bytes across the wire. diff --git a/fluent/client/ws_client_test.go b/fluent/client/ws_client_test.go index 98107ff..4c04430 100644 --- a/fluent/client/ws_client_test.go +++ b/fluent/client/ws_client_test.go @@ -28,6 +28,7 @@ import ( "bytes" "crypto/tls" "errors" + "math/rand" "net/http" "net/http/httptest" "strings" @@ -46,6 +47,7 @@ import ( "github.com/gorilla/websocket" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/tinylib/msgp/msgp" ) var _ = Describe("IAMAuthInfo", func() { @@ -303,6 +305,44 @@ var _ = Describe("WSClient", func() { Expect(bytes.Equal(bits, writtenbits)).To(BeTrue()) }) + It("Sends the message", func() { + msgBytes, _ := msg.MarshalMsg(nil) + Expect(client.Send(&msg)).ToNot(HaveOccurred()) + + writtenBytes := conn.WriteArgsForCall(0) + Expect(bytes.Equal(msgBytes, writtenBytes)).To(BeTrue()) + }) + + When("The message is large", func() { + const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + var ( + expectedBytes int + messageSize = 65536 + ) + + JustBeforeEach(func() { + seededRand := rand.New( + rand.NewSource(time.Now().UnixNano())) + m := make([]byte, messageSize) + for i := range m { + m[i] = charset[seededRand.Intn(len(charset))] + } + msg.Record = m + + var b bytes.Buffer + Expect(msgp.Encode(&b, &msg)).ToNot(HaveOccurred()) + expectedBytes = len(b.Bytes()) + }) + + It("Sends the correct number of bits", func() { + Expect(client.Send(&msg)).ToNot(HaveOccurred()) + Expect(conn.WriteCallCount()).To(Equal(1)) + writtenBytes := len(conn.WriteArgsForCall(0)) + Expect(writtenBytes).To(Equal(expectedBytes)) + }) + }) + When("the connection is disconnected", func() { JustBeforeEach(func() { err := client.Disconnect()