diff --git a/caller.go b/caller.go index a109fa0..c6e0a28 100644 --- a/caller.go +++ b/caller.go @@ -4,11 +4,13 @@ import ( "errors" "fmt" "reflect" + "sync" ) type caller struct { - Func reflect.Value - Args []reflect.Type + sync.RWMutex + Func reflect.Value + Args []reflect.Type } func newCaller(f interface{}) (*caller, error) { @@ -28,12 +30,15 @@ func newCaller(f interface{}) (*caller, error) { } return &caller{ - Func: fv, - Args: args, + Func: fv, + Args: args, }, nil } func (c *caller) GetArgs() []interface{} { + c.RLock() + defer c.RUnlock() + ret := make([]interface{}, len(c.Args)) for i, argT := range c.Args { if argT.Kind() == reflect.Ptr { @@ -46,6 +51,8 @@ func (c *caller) GetArgs() []interface{} { } func (c *caller) Call(args []interface{}) []reflect.Value { + c.RLock() + defer c.RUnlock() var a []reflect.Value diff := 0 diff --git a/client.go b/client.go index ad9176e..eaf72e1 100644 --- a/client.go +++ b/client.go @@ -2,9 +2,10 @@ package socketio_client import ( "net/url" - "reflect" "path" + "reflect" "strings" + "sync" ) type Options struct { @@ -17,10 +18,11 @@ type Client struct { conn *clientConn - events map[string]*caller - acks map[int]*caller - id int - namespace string + eventsLock sync.RWMutex + events map[string]*caller + acks map[int]*caller + id int + namespace string } func NewClient(uri string, opts *Options) (client *Client, err error) { @@ -29,10 +31,10 @@ func NewClient(uri string, opts *Options) (client *Client, err error) { if err != nil { return } - url.Path = path.Join("/socket.io",url.Path) + url.Path = path.Join("/socket.io", url.Path) url.Path = url.EscapedPath() - if strings.HasSuffix(url.Path,"socket.io"){ - url.Path+="/" + if strings.HasSuffix(url.Path, "socket.io") { + url.Path += "/" } q := url.Query() for k, v := range opts.Query { @@ -63,7 +65,9 @@ func (client *Client) On(message string, f interface{}) (err error) { if err != nil { return } + client.eventsLock.Lock() client.events[message] = c + client.eventsLock.Unlock() return } @@ -147,7 +151,9 @@ func (client *Client) onPacket(decoder *decoder, packet *packet) ([]interface{}, default: message = decoder.Message() } + client.eventsLock.RLock() c, ok := client.events[message] + client.eventsLock.RUnlock() if !ok { // If the message is not recognized by the server, the decoder.currentCloser // needs to be closed otherwise the server will be stuck until the e diff --git a/client_conn.go b/client_conn.go index 118a984..000735a 100644 --- a/client_conn.go +++ b/client_conn.go @@ -181,6 +181,7 @@ func (c *clientConn) OnPacket(r *parser.PacketDecoder) { t := c.getCurrent() u := c.getUpgrade() newWriter := t.NextWriter + c.writerLocker.Lock() if u != nil { if w, _ := t.NextWriter(message.MessageText, parser.NOOP); w != nil { w.Close() @@ -191,6 +192,7 @@ func (c *clientConn) OnPacket(r *parser.PacketDecoder) { io.Copy(w, r) w.Close() } + c.writerLocker.Unlock() fallthrough case parser.PONG: c.pingChan <- true