From bb2b7811921c71156e02c3a2f5a7196eb96afa78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Benkovsk=C3=BD?= Date: Fri, 11 Aug 2023 15:41:10 +0200 Subject: [PATCH] connection creation race condition --- doq/client.go | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/doq/client.go b/doq/client.go index a9dfb60..57f49dc 100644 --- a/doq/client.go +++ b/doq/client.go @@ -15,10 +15,11 @@ import ( // Client encapsulates and provides logic for querying DNS servers over QUIC. // The client should be thread-safe. The client reuses single QUIC connection to the server, while creating multiple parallel QUIC streams. type Client struct { - lock sync.Mutex + connLock sync.RWMutex + conn quic.Connection + addr string tlsconfig *tls.Config - conn quic.Connection writeTimeout time.Duration readTimeout time.Duration connectTimeout time.Duration @@ -55,8 +56,8 @@ func NewClient(addr string, options Options) *Client { } func (c *Client) dial(ctx context.Context) error { - c.lock.Lock() - defer c.lock.Unlock() + c.connLock.Lock() + defer c.connLock.Unlock() if c.conn != nil { c.conn.ConnectionState() if err := c.conn.Context().Err(); err == nil { @@ -100,17 +101,8 @@ func (c *Client) dial(ctx context.Context) error { // Send sends DNS request using DNS over QUIC. func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - if c.conn == nil { - // connection not yet created, create one - if err := c.dial(ctx); err != nil { - return nil, err - } - } - if err := c.conn.Context().Err(); err != nil { - // connection is not healthy, try to dial a new one - if err := c.dial(ctx); err != nil { - return nil, err - } + if err := c.dialIfNeeded(ctx); err != nil { + return nil, err } stream, err := c.conn.OpenStreamSync(ctx) @@ -137,6 +129,31 @@ func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { return readMsg(readCtx, stream) } +func (c *Client) dialIfNeeded(ctx context.Context) error { + c.connLock.RLock() + connNotCreated := c.conn == nil + c.connLock.RUnlock() + + if connNotCreated { + // connection not yet created, create one + if err := c.dial(ctx); err != nil { + return err + } + } + + c.connLock.RLock() + connFailed := c.conn.Context().Err() != nil + c.connLock.RUnlock() + + if connFailed { + // connection is not healthy, try to dial a new one + if err := c.dial(ctx); err != nil { + return err + } + } + return nil +} + func writeMsg(ctx context.Context, stream quic.Stream, msg *dns.Msg) error { pack, err := msg.Pack() if err != nil {