From fe546c421ae740f0299086b7fbc4c36a70c4c802 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 10 Feb 2025 06:32:01 +0100 Subject: [PATCH] feat(netcore): write tests for dialer.go This commit adds tests for dialer.go. While there, add some extra comments regarding things to do next. --- netcore/dialer.go | 8 +- netcore/dialer_test.go | 293 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+), 1 deletion(-) create mode 100644 netcore/dialer_test.go diff --git a/netcore/dialer.go b/netcore/dialer.go index 14528ce..80194de 100644 --- a/netcore/dialer.go +++ b/netcore/dialer.go @@ -52,6 +52,9 @@ func (nx *Network) sequentialDial( } errv = append(errv, err) } + if len(errv) <= 0 { + return nil, errors.New("no endpoints to dial") + } return nil, errors.Join(errv...) } @@ -68,7 +71,8 @@ func (nx *Network) dialLog(ctx context.Context, network, address string) (net.Co // Emit structured event after the dial nx.emitConnectDone(ctx, network, address, t0, conn, err) - // Maybe wrap the connection if it's not nil and it makes sense to wrap it + // Maybe wrap the connection if it's not nil and it makes sense + // to wrap it (i.e., we have logging enabled) conn = nx.maybeWrapConn(ctx, conn) // Return the connection and error to the caller @@ -83,6 +87,8 @@ func (nx *Network) dialNet(ctx context.Context, network, address string) (net.Co } // otherwise use the net package + // TODO(bassosimone): either make multipath TCP configurable + // or document that we disable it by default child := &net.Dialer{} child.SetMultipathTCP(false) return child.DialContext(ctx, network, address) diff --git a/netcore/dialer_test.go b/netcore/dialer_test.go new file mode 100644 index 0000000..c62741a --- /dev/null +++ b/netcore/dialer_test.go @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package netcore + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/rbmk-project/common/mocks" + "github.com/rbmk-project/common/runtimex" + "github.com/stretchr/testify/assert" +) + +func TestNetwork_DialContext(t *testing.T) { + t.Run("lookup failure", func(t *testing.T) { + expectedErr := errors.New("mocked lookup error") + nx := &Network{ + LookupHostFunc: func(ctx context.Context, domain string) ([]string, error) { + return nil, expectedErr + }, + } + conn, err := nx.DialContext(context.Background(), "tcp", "example.com:80") + assert.ErrorIs(t, err, expectedErr) + assert.Nil(t, conn) + }) + + t.Run("dial failure", func(t *testing.T) { + expectedErr := errors.New("mocked dial error") + nx := &Network{ + LookupHostFunc: func(ctx context.Context, domain string) ([]string, error) { + return []string{"1.2.3.4"}, nil + }, + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expectedErr + }, + } + conn, err := nx.DialContext(context.Background(), "tcp", "example.com:80") + assert.ErrorIs(t, err, expectedErr) + assert.Nil(t, conn) + }) + + t.Run("successful dial", func(t *testing.T) { + mockConn := &mocks.Conn{ + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 80} + }, + } + nx := &Network{ + LookupHostFunc: func(ctx context.Context, domain string) ([]string, error) { + return []string{"1.2.3.4"}, nil + }, + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return mockConn, nil + }, + } + conn, err := nx.DialContext(context.Background(), "tcp", "example.com:80") + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) +} + +func TestNetwork_sequentialDial(t *testing.T) { + t.Run("empty endpoints list", func(t *testing.T) { + nx := &Network{} + conn, err := nx.sequentialDial(context.Background(), "tcp", nx.dialLog) + assert.Error(t, err) + assert.Nil(t, conn) + }) + + t.Run("all endpoints fail", func(t *testing.T) { + expectedErr1 := errors.New("error 1") + expectedErr2 := errors.New("error 2") + dialAttempts := 0 + nx := &Network{ + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + dialAttempts++ + if address == "1.1.1.1:80" { + return nil, expectedErr1 + } + return nil, expectedErr2 + }, + } + conn, err := nx.sequentialDial(context.Background(), "tcp", nx.dialLog, "1.1.1.1:80", "2.2.2.2:80") + assert.Error(t, err) + assert.Nil(t, conn) + assert.Equal(t, 2, dialAttempts) + assert.ErrorIs(t, err, expectedErr1) + assert.ErrorIs(t, err, expectedErr2) + }) + + t.Run("first endpoint succeeds", func(t *testing.T) { + mockConn := &mocks.Conn{} + nx := &Network{ + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return mockConn, nil + }, + } + conn, err := nx.sequentialDial(context.Background(), "tcp", nx.dialLog, "1.1.1.1:80", "2.2.2.2:80") + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("second endpoint succeeds", func(t *testing.T) { + mockConn := &mocks.Conn{} + expectedErr := errors.New("first endpoint fails") + dialAttempts := 0 + nx := &Network{ + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + dialAttempts++ + if dialAttempts == 1 { + return nil, expectedErr + } + return mockConn, nil + }, + } + conn, err := nx.sequentialDial(context.Background(), "tcp", nx.dialLog, "1.1.1.1:80", "2.2.2.2:80") + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + assert.Equal(t, 2, dialAttempts) + }) +} + +func TestNetwork_dialLog(t *testing.T) { + t.Run("successful dial with logging", func(t *testing.T) { + var buf bytes.Buffer + fixedTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + return slog.Attr{} + } + return a + }, + })) + + mockConn := &mocks.Conn{ + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("1.1.1.1"), Port: 80} + }, + } + + nx := &Network{ + Logger: logger, + TimeNow: func() time.Time { + return fixedTime + }, + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return mockConn, nil + }, + } + + conn, err := nx.dialLog(context.Background(), "tcp", "1.1.1.1:80") + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + + logs := strings.Split(strings.TrimSpace(buf.String()), "\n") + assert.Len(t, logs, 2) + + // Verify connectStart log + var startLog map[string]interface{} + err = json.Unmarshal([]byte(logs[0]), &startLog) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{ + "level": "INFO", + "msg": "connectStart", + "protocol": "tcp", + "remoteAddr": "1.1.1.1:80", + "t": fixedTime.Format(time.RFC3339Nano), + }, startLog) + + // Verify connectDone log + var doneLog map[string]interface{} + err = json.Unmarshal([]byte(logs[1]), &doneLog) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{ + "level": "INFO", + "msg": "connectDone", + "err": nil, + "errClass": "", + "localAddr": "127.0.0.1:1234", + "protocol": "tcp", + "remoteAddr": "1.1.1.1:80", + "t0": fixedTime.Format(time.RFC3339Nano), + "t": fixedTime.Format(time.RFC3339Nano), + }, doneLog) + }) + + t.Run("dial failure with logging", func(t *testing.T) { + var buf bytes.Buffer + fixedTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + return slog.Attr{} + } + return a + }, + })) + + expectedErr := errors.New("mocked dial error") + nx := &Network{ + Logger: logger, + TimeNow: func() time.Time { + return fixedTime + }, + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expectedErr + }, + } + + conn, err := nx.dialLog(context.Background(), "tcp", "1.1.1.1:80") + assert.ErrorIs(t, err, expectedErr) + assert.Nil(t, conn) + + logs := strings.Split(strings.TrimSpace(buf.String()), "\n") + assert.Len(t, logs, 2) + + // Verify connectStart log + var startLog map[string]interface{} + err = json.Unmarshal([]byte(logs[0]), &startLog) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{ + "level": "INFO", + "msg": "connectStart", + "protocol": "tcp", + "remoteAddr": "1.1.1.1:80", + "t": fixedTime.Format(time.RFC3339Nano), + }, startLog) + + // Verify connectDone log + var doneLog map[string]interface{} + err = json.Unmarshal([]byte(logs[1]), &doneLog) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{ + "level": "INFO", + "msg": "connectDone", + "err": expectedErr.Error(), + "errClass": "EGENERIC", + "localAddr": "", + "protocol": "tcp", + "remoteAddr": "1.1.1.1:80", + "t0": fixedTime.Format(time.RFC3339Nano), + "t": fixedTime.Format(time.RFC3339Nano), + }, doneLog) + }) +} + +func TestNetwork_dialNet(t *testing.T) { + t.Run("using custom dialer", func(t *testing.T) { + mockConn := &mocks.Conn{} + nx := &Network{ + DialContextFunc: func(ctx context.Context, network, address string) (net.Conn, error) { + return mockConn, nil + }, + } + conn, err := nx.dialNet(context.Background(), "tcp", "1.1.1.1:80") + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("using net package", func(t *testing.T) { + // create a server using localhost to test against + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + nx := &Network{} + parsed := runtimex.Try1(url.Parse(server.URL)) + conn, err := nx.dialNet(context.Background(), "tcp", parsed.Host) + assert.NoError(t, err) + assert.NotNil(t, conn) + conn.Close() + }) +}