From 879b7d8712e497925b051ff4c166d5a37d3f155c Mon Sep 17 00:00:00 2001 From: zhouhui8915 Date: Mon, 7 Sep 2015 15:10:09 +0800 Subject: [PATCH] websocket ok, long polling TODO --- TODO | 3 + attachment.go | 170 +++++++++++++++++++++++ client.go | 182 +++++++++++++++++++++---- client_conn.go | 131 +++++++++++++++--- example/main.go | 52 ++++++- ioutil.go | 32 +++++ message_reader.go | 60 +++++++++ parser.go | 336 ++++++++++++++++++++++++++++++++++++++++++++++ trim_writer.go | 45 +++++++ 9 files changed, 969 insertions(+), 42 deletions(-) create mode 100644 TODO create mode 100644 attachment.go create mode 100644 message_reader.go create mode 100644 parser.go create mode 100644 trim_writer.go diff --git a/TODO b/TODO new file mode 100644 index 0000000..1a3830f --- /dev/null +++ b/TODO @@ -0,0 +1,3 @@ +1.long polling +2.connect timeout +3. \ No newline at end of file diff --git a/attachment.go b/attachment.go new file mode 100644 index 0000000..22ae7f4 --- /dev/null +++ b/attachment.go @@ -0,0 +1,170 @@ +package socketio_client + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "reflect" +) + +// Attachment is an attachment handler used in emit args. All attachments will send as binary in transport layer. When use attachment, make sure use as pointer. +// +// For example: +// +// type Arg struct { +// Title string `json:"title"` +// File *Attachment `json:"file"` +// } +// +// f, _ := os.Open("./some_file") +// arg := Arg{ +// Title: "some_file", +// File: &Attachment{ +// Data: f, +// } +// } +// +// socket.Emit("send file", arg) +// socket.On("get file", func(so Socket, arg Arg) { +// b, _ := ioutil.ReadAll(arg.File.Data) +// }) +type Attachment struct { + + // Data is the ReadWriter of the attachment data. + Data io.ReadWriter + num int +} + +func encodeAttachments(v interface{}) []io.Reader { + index := 0 + return encodeAttachmentValue(reflect.ValueOf(v), &index) +} + +func encodeAttachmentValue(v reflect.Value, index *int) []io.Reader { + v = reflect.Indirect(v) + ret := []io.Reader{} + if !v.IsValid() { + return ret + } + switch v.Kind() { + case reflect.Struct: + if v.Type().Name() == "Attachment" { + a, ok := v.Addr().Interface().(*Attachment) + if !ok { + panic("can't convert") + } + a.num = *index + ret = append(ret, a.Data) + (*index)++ + return ret + } + for i, n := 0, v.NumField(); i < n; i++ { + var r []io.Reader + r = encodeAttachmentValue(v.Field(i), index) + ret = append(ret, r...) + } + case reflect.Map: + if v.IsNil() { + return ret + } + for _, key := range v.MapKeys() { + var r []io.Reader + r = encodeAttachmentValue(v.MapIndex(key), index) + ret = append(ret, r...) + } + case reflect.Slice: + if v.IsNil() { + return ret + } + fallthrough + case reflect.Array: + for i, n := 0, v.Len(); i < n; i++ { + var r []io.Reader + r = encodeAttachmentValue(v.Index(i), index) + ret = append(ret, r...) + } + case reflect.Interface: + ret = encodeAttachmentValue(reflect.ValueOf(v.Interface()), index) + } + return ret +} + +func decodeAttachments(v interface{}, binary [][]byte) error { + return decodeAttachmentValue(reflect.ValueOf(v), binary) +} + +func decodeAttachmentValue(v reflect.Value, binary [][]byte) error { + v = reflect.Indirect(v) + if !v.IsValid() { + return fmt.Errorf("invalid value") + } + switch v.Kind() { + case reflect.Struct: + if v.Type().Name() == "Attachment" { + a, ok := v.Addr().Interface().(*Attachment) + if !ok { + panic("can't convert") + } + if a.num >= len(binary) || a.num < 0 { + return fmt.Errorf("out of range") + } + if a.Data == nil { + a.Data = bytes.NewBuffer(nil) + } + for b := binary[a.num]; len(b) > 0; { + n, err := a.Data.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil + } + for i, n := 0, v.NumField(); i < n; i++ { + if err := decodeAttachmentValue(v.Field(i), binary); err != nil { + return err + } + } + case reflect.Map: + if v.IsNil() { + return nil + } + for _, key := range v.MapKeys() { + if err := decodeAttachmentValue(v.MapIndex(key), binary); err != nil { + return err + } + } + case reflect.Slice: + if v.IsNil() { + return nil + } + fallthrough + case reflect.Array: + for i, n := 0, v.Len(); i < n; i++ { + if err := decodeAttachmentValue(v.Index(i), binary); err != nil { + return err + } + } + case reflect.Interface: + if err := decodeAttachmentValue(reflect.ValueOf(v.Interface()), binary); err != nil { + return err + } + } + return nil +} + +func (a Attachment) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("{\"_placeholder\":true,\"num\":%d}", a.num)), nil +} + +func (a *Attachment) UnmarshalJSON(b []byte) error { + var v struct { + Num int `json:"num"` + } + if err := json.Unmarshal(b, &v); err != nil { + return err + } + a.num = v.Num + return nil +} diff --git a/client.go b/client.go index 37dd35c..df3aa2d 100644 --- a/client.go +++ b/client.go @@ -1,27 +1,28 @@ package socketio_client import ( - "reflect" "net/http" "net/url" - "fmt" + "reflect" ) var defaultTransport = "websocket" type Options struct { - Transport string //protocol name string,websocket polling... - Query map[string]string //url的附加的参数 + Transport string //protocol name string,websocket polling... + Query map[string]string //url的附加的参数 } type Client struct { opts *Options - socket *clientConn + conn *clientConn events map[string]*caller acks map[int]*caller + id int + namespace string } func NewClient(uri string, opts *Options) (client *Client, err error) { @@ -31,35 +32,36 @@ func NewClient(uri string, opts *Options) (client *Client, err error) { exist = true } } - if !exist{ + if !exist { opts.Transport = defaultTransport } request := &http.Request{} - request.URL,err = url.Parse(uri) + request.URL, err = url.Parse(uri) if err != nil { return } - q:= request.URL.Query() - for k,v := range opts.Query{ - q.Set(k,v) + q := request.URL.Query() + for k, v := range opts.Query { + q.Set(k, v) } request.URL.RawQuery = q.Encode() - fmt.Println(request.URL.String()) - - socket,err := newClientConn(opts.Transport,request) + socket, err := newClientConn(opts.Transport, request) if err != nil { return } client = &Client{ - opts:opts, - socket:socket, + opts: opts, + conn: socket, - events:make(map[string]*caller), - acks:make(map[int]*caller), + events: make(map[string]*caller), + acks: make(map[int]*caller), } + + go client.readLoop() + return } @@ -97,20 +99,154 @@ func (client *Client) Emit(message string, args ...interface{}) (err error) { return client.send(args) } +func (client *Client) sendConnect() error { + packet := packet{ + Type: _CONNECT, + Id: -1, + NSP: client.namespace, + } + encoder := newEncoder(client.conn) + return encoder.Encode(packet) +} + func (client *Client) sendId(args []interface{}) (int, error) { - return 0,nil + packet := packet{ + Type: _EVENT, + Id: client.id, + NSP: client.namespace, + Data: args, + } + client.id++ + if client.id < 0 { + client.id = 0 + } + encoder := newEncoder(client.conn) + err := encoder.Encode(packet) + if err != nil { + return -1, nil + } + return packet.Id, nil } func (client *Client) send(args []interface{}) error { - return nil + packet := packet{ + Type: _EVENT, + Id: -1, + NSP: client.namespace, + Data: args, + } + encoder := newEncoder(client.conn) + return encoder.Encode(packet) } -func (client *Client) handshake(uri string)(err error){ +func (client *Client) onPacket(decoder *decoder, packet *packet) ([]interface{}, error) { + var message string + switch packet.Type { + case _CONNECT: + message = "connection" + case _DISCONNECT: + message = "disconnection" + case _ERROR: + message = "error" + case _ACK: + case _BINARY_ACK: + return nil, client.onAck(packet.Id, decoder, packet) + default: + message = decoder.Message() + } + c, ok := client.events[message] + if !ok { + // If the message is not recognized by the server, the decoder.currentCloser + // needs to be closed otherwise the server will be stuck until the e + decoder.Close() + return nil, nil + } + args := c.GetArgs() + olen := len(args) + if olen > 0 { + packet.Data = &args + if err := decoder.DecodeData(packet); err != nil { + return nil, err + } + } + for i := len(args); i < olen; i++ { + args = append(args, nil) + } - return + retV := c.Call(args) + if len(retV) == 0 { + return nil, nil + } + + var err error + if last, ok := retV[len(retV)-1].Interface().(error); ok { + err = last + retV = retV[0 : len(retV)-1] + } + ret := make([]interface{}, len(retV)) + for i, v := range retV { + ret[i] = v.Interface() + } + return ret, err } -func (client *Client) handleMessage(msg []byte)(err error){ +func (client *Client) onAck(id int, decoder *decoder, packet *packet) error { + c, ok := client.acks[id] + if !ok { + return nil + } + delete(client.acks, id) + + args := c.GetArgs() + packet.Data = &args + if err := decoder.DecodeData(packet); err != nil { + return err + } + c.Call(args) + return nil +} - return +func (client *Client) readLoop() error { + defer func() { + p := packet{ + Type: _DISCONNECT, + Id: -1, + } + client.onPacket(nil, &p) + }() + + for { + decoder := newDecoder(client.conn) + var p packet + if err := decoder.Decode(&p); err != nil { + return err + } + ret, err := client.onPacket(decoder, &p) + if err != nil { + return err + } + switch p.Type { + case _CONNECT: + client.namespace = p.NSP + // !!!下面这个不能有,否则会有死循环 + //client.sendConnect() + case _BINARY_EVENT: + fallthrough + case _EVENT: + if p.Id >= 0 { + p := packet{ + Type: _ACK, + Id: p.Id, + NSP: client.namespace, + Data: ret, + } + encoder := newEncoder(client.conn) + if err := encoder.Encode(p); err != nil { + return err + } + } + case _DISCONNECT: + return nil + } + } } diff --git a/client_conn.go b/client_conn.go index 758bd42..c0f156d 100644 --- a/client_conn.go +++ b/client_conn.go @@ -1,6 +1,7 @@ package socketio_client import ( + "encoding/json" "errors" "fmt" "github.com/googollee/go-engine.io/message" @@ -10,6 +11,7 @@ import ( "github.com/googollee/go-engine.io/websocket" "io" "net/http" + "strings" "sync" "time" ) @@ -89,6 +91,7 @@ func newClientConn(transportName string, r *http.Request) (client *clientConn, e } go client.pingLoop() + go client.readLoop() return } @@ -128,8 +131,14 @@ func (c *clientConn) NextWriter(t MessageType) (io.WriteCloser, error) { default: return nil, io.EOF } + c.writerLocker.Lock() ret, err := c.getCurrent().NextWriter(message.MessageType(t), parser.MESSAGE) - return ret, err + if err != nil { + c.writerLocker.Unlock() + return ret, err + } + writer := newConnWriter(ret, &c.writerLocker) + return writer, err } func (c *clientConn) Close() error { @@ -178,6 +187,30 @@ func (c *clientConn) OnPacket(r *parser.PacketDecoder) { fallthrough case parser.PONG: c.pingChan <- true + if c.getState() == stateUpgrading { + p := make([]byte, 64) + _, err := r.Read(p) + if err == nil && strings.Contains(string(p), "probe") { + c.writerLocker.Lock() + w, _ := c.getUpgrade().NextWriter(message.MessageText, parser.UPGRADE) + if w != nil { + io.Copy(w, r) + w.Close() + } + c.writerLocker.Unlock() + + c.upgraded() + //fmt.Println("probe") + + /* + w, _ = c.getCurrent().NextWriter(message.MessageText, parser.MESSAGE) + if w != nil { + w.Write([]byte("2[\"message\",\"testtesttesttesttesttest\"]")) + w.Close() + } + */ + } + } case parser.MESSAGE: closeChan := make(chan struct{}) c.readerChan <- newConnReader(r, closeChan) @@ -231,22 +264,66 @@ func (c *clientConn) onOpen() error { if err != nil { return err } - fmt.Println(pack) - //var p []byte - p := make([]byte, 1024) + p := make([]byte, 4096) l, err := pack.Read(p) - fmt.Println(l) - fmt.Println(err) - fmt.Println(string(p)) + if err != nil { + return err + } + //fmt.Println(string(p)) + + type connectionInfo struct { + Sid string `json:"sid"` + Upgrades []string `json:"upgrades"` + PingInterval time.Duration `json:"pingInterval"` + PingTimeout time.Duration `json:"pingTimeout"` + } + + var msg connectionInfo + err = json.Unmarshal(p[:l], &msg) + if err != nil { + return err + } + msg.PingInterval *= 1000 * 1000 + msg.PingTimeout *= 1000 * 1000 + + //fmt.Println(msg) + c.pingInterval = msg.PingInterval + c.pingTimeout = msg.PingTimeout + c.id = msg.Sid + + /* + q.Set("sid", c.id) + c.request.URL.RawQuery = q.Encode() + + transport, err = creater.Client(c.request) + if err != nil { + return err + } + c.setCurrent("polling", transport) + + pack, err = c.getCurrent().NextReader() + if err != nil { + return err + } + + p2 := make([]byte, 4096) + l, err = pack.Read(p2) + if err != nil { + return err + } + //fmt.Println(string(p2)) + */ + + //upgrade creater, exists = creaters["websocket"] if !exists { return InvalidError } c.request.URL.Scheme = "ws" - q.Set("sid", "0") + q.Set("sid", c.id) q.Set("transport", "websocket") c.request.URL.RawQuery = q.Encode() @@ -254,21 +331,16 @@ func (c *clientConn) onOpen() error { if err != nil { return err } - c.setCurrent("websocket", transport) + c.setUpgrading("websocket", transport) - pack, err = c.getCurrent().NextReader() + w, err := c.getUpgrade().NextWriter(message.MessageText, parser.PING) if err != nil { return err } - fmt.Println(pack) + w.Write([]byte("probe")) + w.Close() - p2 := make([]byte, 1024) - l, err = pack.Read(p2) - fmt.Println(l) - fmt.Println(err) - fmt.Println(string(p2)) - - fmt.Println("end") + //fmt.Println("end") return nil } @@ -360,3 +432,26 @@ func (c *clientConn) pingLoop() { } } } + +func (c *clientConn) readLoop() { + + current := c.getCurrent() + + defer func() { + c.OnClose(current) + }() + + for { + current = c.getCurrent() + if c.getUpgrade() != nil { + current = c.getUpgrade() + } + + pack, err := current.NextReader() + if err != nil { + return + } + c.OnPacket(pack) + pack.Close() + } +} diff --git a/example/main.go b/example/main.go index 7905807..2a7afb3 100644 --- a/example/main.go +++ b/example/main.go @@ -1,5 +1,55 @@ package main +import ( + "github.com/zhouhui8915/go-socket.io-client" + "log" + "bufio" + "os" + "time" +) + func main() { -} + opts := &socketio_client.Options{ + Transport:"websocket", + Query:make(map[string]string), + } + opts.Query["uid"] = "1" + opts.Query["cid"] = "conf_123" + uri := "http://192.168.1.70:9090/socket.io/" + + client,err := socketio_client.NewClient(uri,opts) + if err != nil { + log.Printf("NewClient error:%v\n",err) + return + } + + client.On("error", func() { + log.Printf("on error\n") + }) + client.On("connection", func() { + log.Printf("on connect\n") + }) + client.On("message", func(msg string) { + log.Printf("on message:%v\n", msg) + }) + client.On("disconnection", func() { + log.Printf("on disconnect\n") + }) + + go func() { + authStr := "{\"uid\":\"" + opts.Query["uid"] + "\",\"cid\":\"" + opts.Query["cid"] + "\"}" + for { + client.Emit("authenticate", authStr) + time.Sleep(10 * time.Second) + } + }() + + reader := bufio.NewReader(os.Stdin) + for { + data, _, _ := reader.ReadLine() + command := string(data) + client.Emit("message",command) + log.Printf("send message:%v\n",command) + } +} \ No newline at end of file diff --git a/ioutil.go b/ioutil.go index ec2e8fb..cb2ceb6 100644 --- a/ioutil.go +++ b/ioutil.go @@ -48,3 +48,35 @@ func (w *connWriter) Close() error { }() return w.WriteCloser.Close() } + + +// add with github.com/googollee/go-socket.io/ioutil.go + +type writerHelper struct { + writer io.Writer + err error +} + +func newWriterHelper(w io.Writer) *writerHelper { + return &writerHelper{ + writer: w, + } +} + +func (h *writerHelper) Write(p []byte) { + if h.err != nil { + return + } + for len(p) > 0 { + n, err := h.writer.Write(p) + if err != nil { + h.err = err + return + } + p = p[n:] + } +} + +func (h *writerHelper) Error() error { + return h.err +} diff --git a/message_reader.go b/message_reader.go new file mode 100644 index 0000000..74eab0c --- /dev/null +++ b/message_reader.go @@ -0,0 +1,60 @@ +package socketio_client + +import ( + "bufio" +) + +type messageReader struct { + reader *bufio.Reader + message string + firstRead bool +} + +func newMessageReader(bufr *bufio.Reader) (*messageReader, error) { + if _, err := bufr.ReadBytes('"'); err != nil { + return nil, err + } + msg, err := bufr.ReadBytes('"') + if err != nil { + return nil, err + } + for { + b, err := bufr.Peek(1) + if err != nil { + return nil, err + } + if b[0] == ',' { + bufr.ReadByte() + break + } + if b[0] != ' ' { + break + } + bufr.ReadByte() + } + return &messageReader{ + reader: bufr, + message: string(msg[:len(msg)-1]), + firstRead: true, + }, nil +} + +func (r *messageReader) Message() string { + return r.message +} + +func (r *messageReader) Read(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + if r.firstRead { + r.firstRead = false + b[0] = '[' + n, err := r.reader.Read(b[1:]) + if err != nil { + return -1, err + } + return n + 1, err + } + return r.reader.Read(b) +} diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..1b77332 --- /dev/null +++ b/parser.go @@ -0,0 +1,336 @@ +package socketio_client + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "strconv" +) + +const Protocol = 4 + +type packetType int + +const ( + _CONNECT packetType = iota + _DISCONNECT + _EVENT + _ACK + _ERROR + _BINARY_EVENT + _BINARY_ACK +) + +func (t packetType) String() string { + switch t { + case _CONNECT: + return "connect" + case _DISCONNECT: + return "disconnect" + case _EVENT: + return "event" + case _ACK: + return "ack" + case _ERROR: + return "error" + case _BINARY_EVENT: + return "binary_event" + case _BINARY_ACK: + return "binary_ack" + } + return fmt.Sprintf("unknown(%d)", t) +} + +type frameReader interface { + NextReader() (MessageType, io.ReadCloser, error) +} + +type frameWriter interface { + NextWriter(MessageType) (io.WriteCloser, error) +} + +type packet struct { + Type packetType + NSP string + Id int + Data interface{} + attachNumber int +} + +type encoder struct { + w frameWriter + err error +} + +func newEncoder(w frameWriter) *encoder { + return &encoder{ + w: w, + } +} + +func (e *encoder) Encode(v packet) error { + attachments := encodeAttachments(v.Data) + v.attachNumber = len(attachments) + if v.attachNumber > 0 { + v.Type += _BINARY_EVENT - _EVENT + } + if err := e.encodePacket(v); err != nil { + return err + } + for _, a := range attachments { + if err := e.writeBinary(a); err != nil { + return err + } + } + return nil +} + +func (e *encoder) encodePacket(v packet) error { + writer, err := e.w.NextWriter(MessageText) + if err != nil { + return err + } + defer writer.Close() + + w := newTrimWriter(writer, "\n") + wh := newWriterHelper(w) + wh.Write([]byte{byte(v.Type) + '0'}) + if v.Type == _BINARY_EVENT || v.Type == _BINARY_ACK { + wh.Write([]byte(fmt.Sprintf("%d-", v.attachNumber))) + } + needEnd := false + if v.NSP != "" { + wh.Write([]byte(v.NSP)) + needEnd = true + } + if v.Id >= 0 { + f := "%d" + if needEnd { + f = ",%d" + needEnd = false + } + wh.Write([]byte(fmt.Sprintf(f, v.Id))) + } + if v.Data != nil { + if needEnd { + wh.Write([]byte{','}) + needEnd = false + } + if wh.Error() != nil { + return wh.Error() + } + encoder := json.NewEncoder(w) + return encoder.Encode(v.Data) + } + return wh.Error() +} + +func (e *encoder) writeBinary(r io.Reader) error { + writer, err := e.w.NextWriter(MessageBinary) + if err != nil { + return err + } + defer writer.Close() + + if _, err := io.Copy(writer, r); err != nil { + return err + } + return nil + +} + +type decoder struct { + reader frameReader + message string + current io.Reader + currentCloser io.Closer +} + +func newDecoder(r frameReader) *decoder { + return &decoder{ + reader: r, + } +} + +func (d *decoder) Close() { + if d != nil && d.currentCloser != nil { + d.currentCloser.Close() + d.current = nil + d.currentCloser = nil + } +} + +func (d *decoder) Decode(v *packet) error { + ty, r, err := d.reader.NextReader() + if err != nil { + return err + } + if d.current != nil { + d.Close() + } + defer func() { + if d.current == nil { + r.Close() + } + }() + + if ty != MessageText { + return fmt.Errorf("need text package") + } + reader := bufio.NewReader(r) + + v.Id = -1 + + t, err := reader.ReadByte() + if err != nil { + return err + } + v.Type = packetType(t - '0') + + if v.Type == _BINARY_EVENT || v.Type == _BINARY_ACK { + num, err := reader.ReadBytes('-') + if err != nil { + return err + } + numLen := len(num) + if numLen == 0 { + return fmt.Errorf("invalid packet") + } + n, err := strconv.ParseInt(string(num[:numLen-1]), 10, 64) + if err != nil { + return fmt.Errorf("invalid packet") + } + v.attachNumber = int(n) + } + + next, err := reader.Peek(1) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + if len(next) == 0 { + return fmt.Errorf("invalid packet") + } + + if next[0] == '/' { + path, err := reader.ReadBytes(',') + if err != nil && err != io.EOF { + return err + } + pathLen := len(path) + if pathLen == 0 { + return fmt.Errorf("invalid packet") + } + if err == nil { + path = path[:pathLen-1] + } + v.NSP = string(path) + if err == io.EOF { + return nil + } + } + + id := bytes.NewBuffer(nil) + finish := false + for { + next, err := reader.Peek(1) + if err == io.EOF { + finish = true + break + } + if err != nil { + return err + } + if '0' <= next[0] && next[0] <= '9' { + if err := id.WriteByte(next[0]); err != nil { + return err + } + } else { + break + } + reader.ReadByte() + } + if id.Len() > 0 { + id, err := strconv.ParseInt(id.String(), 10, 64) + if err != nil { + return err + } + v.Id = int(id) + } + if finish { + return nil + } + + switch v.Type { + case _EVENT: + fallthrough + case _BINARY_EVENT: + msgReader, err := newMessageReader(reader) + if err != nil { + return err + } + d.message = msgReader.Message() + d.current = msgReader + d.currentCloser = r + case _ACK: + fallthrough + case _BINARY_ACK: + d.current = reader + d.currentCloser = r + } + return nil +} + +func (d *decoder) Message() string { + return d.message +} + +func (d *decoder) DecodeData(v *packet) error { + if d.current == nil { + return nil + } + defer func() { + d.Close() + }() + decoder := json.NewDecoder(d.current) + if err := decoder.Decode(v.Data); err != nil { + return err + } + if v.Type == _BINARY_EVENT || v.Type == _BINARY_ACK { + binary, err := d.decodeBinary(v.attachNumber) + if err != nil { + return err + } + if err := decodeAttachments(v.Data, binary); err != nil { + return err + } + v.Type -= _BINARY_EVENT - _EVENT + } + return nil +} + +func (d *decoder) decodeBinary(num int) ([][]byte, error) { + ret := make([][]byte, num) + for i := 0; i < num; i++ { + d.currentCloser.Close() + t, r, err := d.reader.NextReader() + if err != nil { + return nil, err + } + d.currentCloser = r + if t == MessageText { + return nil, fmt.Errorf("need binary") + } + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + ret[i] = b + } + return ret, nil +} diff --git a/trim_writer.go b/trim_writer.go new file mode 100644 index 0000000..6475fbd --- /dev/null +++ b/trim_writer.go @@ -0,0 +1,45 @@ +package socketio_client + +import ( + "bytes" + "io" +) + +type trimWriter struct { + trimChars string + trimBuf []byte + output io.Writer +} + +func newTrimWriter(w io.Writer, trimChars string) *trimWriter { + return &trimWriter{ + trimChars: trimChars, + output: w, + } +} + +func (w *trimWriter) Write(p []byte) (int, error) { + out := bytes.TrimRight(p, w.trimChars) + buf := p[len(out):] + var written int + if (len(out) > 0) && (w.trimBuf != nil) { + var err error + if written, err = w.output.Write(w.trimBuf); err != nil { + return 0, err + } + w.trimBuf = nil + } + if w.trimBuf != nil { + w.trimBuf = append(w.trimBuf, buf...) + } else { + w.trimBuf = buf + } + if len(p) == 0 { + return written, nil + } + ret, err := w.output.Write(out) + if err != nil { + return 0, err + } + return written + ret, nil +}