Skip to content

Commit

Permalink
connection creation race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
Ondřej Benkovský committed Aug 11, 2023
1 parent b47adc7 commit bb2b781
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions doq/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import (
// Client encapsulates and provides logic for querying DNS servers over QUIC.
// The client should be thread-safe. The client reuses single QUIC connection to the server, while creating multiple parallel QUIC streams.
type Client struct {
lock sync.Mutex
connLock sync.RWMutex
conn quic.Connection

addr string
tlsconfig *tls.Config
conn quic.Connection
writeTimeout time.Duration
readTimeout time.Duration
connectTimeout time.Duration
Expand Down Expand Up @@ -55,8 +56,8 @@ func NewClient(addr string, options Options) *Client {
}

func (c *Client) dial(ctx context.Context) error {
c.lock.Lock()
defer c.lock.Unlock()
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
c.conn.ConnectionState()
if err := c.conn.Context().Err(); err == nil {
Expand Down Expand Up @@ -100,17 +101,8 @@ func (c *Client) dial(ctx context.Context) error {

// Send sends DNS request using DNS over QUIC.
func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if c.conn == nil {
// connection not yet created, create one
if err := c.dial(ctx); err != nil {
return nil, err
}
}
if err := c.conn.Context().Err(); err != nil {
// connection is not healthy, try to dial a new one
if err := c.dial(ctx); err != nil {
return nil, err
}
if err := c.dialIfNeeded(ctx); err != nil {
return nil, err
}

stream, err := c.conn.OpenStreamSync(ctx)
Expand All @@ -137,6 +129,31 @@ func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
return readMsg(readCtx, stream)
}

func (c *Client) dialIfNeeded(ctx context.Context) error {
c.connLock.RLock()
connNotCreated := c.conn == nil
c.connLock.RUnlock()

if connNotCreated {
// connection not yet created, create one
if err := c.dial(ctx); err != nil {
return err
}
}

c.connLock.RLock()
connFailed := c.conn.Context().Err() != nil
c.connLock.RUnlock()

if connFailed {
// connection is not healthy, try to dial a new one
if err := c.dial(ctx); err != nil {
return err
}
}
return nil
}

func writeMsg(ctx context.Context, stream quic.Stream, msg *dns.Msg) error {
pack, err := msg.Pack()
if err != nil {
Expand Down

0 comments on commit bb2b781

Please sign in to comment.