diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go index 7c9e8c69a309..8eb0246ad971 100644 --- a/pkg/adt/interval_tree_test.go +++ b/pkg/adt/interval_tree_test.go @@ -19,6 +19,7 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -272,32 +273,18 @@ func TestIntervalTreeIntersects(t *testing.T) { ivt := NewIntervalTree() ivt.Insert(NewStringInterval("1", "3"), 123) - if ivt.Intersects(NewStringPoint("0")) { - t.Errorf("contains 0") - } - if !ivt.Intersects(NewStringPoint("1")) { - t.Errorf("missing 1") - } - if !ivt.Intersects(NewStringPoint("11")) { - t.Errorf("missing 11") - } - if !ivt.Intersects(NewStringPoint("2")) { - t.Errorf("missing 2") - } - if ivt.Intersects(NewStringPoint("3")) { - t.Errorf("contains 3") - } + assert.Falsef(t, ivt.Intersects(NewStringPoint("0")), "contains 0") + assert.Truef(t, ivt.Intersects(NewStringPoint("1")), "missing 1") + assert.Truef(t, ivt.Intersects(NewStringPoint("11")), "missing 11") + assert.Truef(t, ivt.Intersects(NewStringPoint("2")), "missing 2") + assert.Falsef(t, ivt.Intersects(NewStringPoint("3")), "contains 3") } func TestIntervalTreeStringAffine(t *testing.T) { ivt := NewIntervalTree() ivt.Insert(NewStringAffineInterval("8", ""), 123) - if !ivt.Intersects(NewStringAffinePoint("9")) { - t.Errorf("missing 9") - } - if ivt.Intersects(NewStringAffinePoint("7")) { - t.Errorf("contains 7") - } + assert.Truef(t, ivt.Intersects(NewStringAffinePoint("9")), "missing 9") + assert.Falsef(t, ivt.Intersects(NewStringAffinePoint("7")), "contains 7") } func TestIntervalTreeStab(t *testing.T) { @@ -310,27 +297,13 @@ func TestIntervalTreeStab(t *testing.T) { tr := ivt.(*intervalTree) require.Equalf(t, 0, tr.root.max.Compare(StringComparable("8")), "wrong root max got %v, expected 8", tr.root.max) - if x := len(ivt.Stab(NewStringPoint("0"))); x != 3 { - t.Errorf("got %d, expected 3", x) - } - if x := len(ivt.Stab(NewStringPoint("1"))); x != 2 { - t.Errorf("got %d, expected 2", x) - } - if x := len(ivt.Stab(NewStringPoint("2"))); x != 1 { - t.Errorf("got %d, expected 1", x) - } - if x := len(ivt.Stab(NewStringPoint("3"))); x != 0 { - t.Errorf("got %d, expected 0", x) - } - if x := len(ivt.Stab(NewStringPoint("5"))); x != 1 { - t.Errorf("got %d, expected 1", x) - } - if x := len(ivt.Stab(NewStringPoint("55"))); x != 1 { - t.Errorf("got %d, expected 1", x) - } - if x := len(ivt.Stab(NewStringPoint("6"))); x != 1 { - t.Errorf("got %d, expected 1", x) - } + assert.Len(t, ivt.Stab(NewStringPoint("0")), 3) + assert.Len(t, ivt.Stab(NewStringPoint("1")), 2) + assert.Len(t, ivt.Stab(NewStringPoint("2")), 1) + assert.Empty(t, ivt.Stab(NewStringPoint("3"))) + assert.Len(t, ivt.Stab(NewStringPoint("5")), 1) + assert.Len(t, ivt.Stab(NewStringPoint("55")), 1) + assert.Len(t, ivt.Stab(NewStringPoint("6")), 1) } type xy struct { @@ -368,15 +341,11 @@ func TestIntervalTreeRandom(t *testing.T) { require.NotEmptyf(t, ivt.Stab(NewInt64Point(v)), "expected %v stab non-zero for [%+v)", v, xy) require.Truef(t, ivt.Intersects(NewInt64Point(v)), "did not get %d as expected for [%+v)", v, xy) } - if !ivt.Delete(NewInt64Interval(ab.x, ab.y)) { - t.Errorf("did not delete %v as expected", ab) - } + assert.Truef(t, ivt.Delete(NewInt64Interval(ab.x, ab.y)), "did not delete %v as expected", ab) delete(ivs, ab) } - if ivt.Len() != 0 { - t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len()) - } + assert.Equalf(t, 0, ivt.Len(), "got ivt.Len() = %v, expected 0", ivt.Len()) } // TestIntervalTreeSortedVisit tests that intervals are visited in sorted order. @@ -417,17 +386,13 @@ func TestIntervalTreeSortedVisit(t *testing.T) { last := tt.ivls[0].Begin count := 0 chk := func(iv *IntervalValue) bool { - if last.Compare(iv.Ivl.Begin) > 0 { - t.Errorf("#%d: expected less than %d, got interval %+v", i, last, iv.Ivl) - } + assert.LessOrEqualf(t, last.Compare(iv.Ivl.Begin), 0, "#%d: expected less than %d, got interval %+v", i, last, iv.Ivl) last = iv.Ivl.Begin count++ return true } ivt.Visit(tt.visitRange, chk) - if count != len(tt.ivls) { - t.Errorf("#%d: did not cover all intervals. expected %d, got %d", i, len(tt.ivls), count) - } + assert.Lenf(t, tt.ivls, count, "#%d: did not cover all intervals. expected %d, got %d", i, len(tt.ivls), count) } } @@ -468,9 +433,7 @@ func TestIntervalTreeVisitExit(t *testing.T) { count++ return tt.f(n) }) - if count != tt.wcount { - t.Errorf("#%d: expected count %d, got %d", i, tt.wcount, count) - } + assert.Equalf(t, count, tt.wcount, "#%d: expected count %d, got %d", i, tt.wcount, count) } } @@ -530,8 +493,7 @@ func TestIntervalTreeContains(t *testing.T) { for _, ivl := range tt.ivls { ivt.Insert(ivl, struct{}{}) } - if v := ivt.Contains(tt.chkIvl); v != tt.wContains { - t.Errorf("#%d: ivt.Contains got %v, expected %v", i, v, tt.wContains) - } + v := ivt.Contains(tt.chkIvl) + assert.Equalf(t, v, tt.wContains, "#%d: ivt.Contains got %v, expected %v", i, v, tt.wContains) } } diff --git a/pkg/crc/crc_test.go b/pkg/crc/crc_test.go index 3c9cc3a280cc..38990fac68ce 100644 --- a/pkg/crc/crc_test.go +++ b/pkg/crc/crc_test.go @@ -9,6 +9,7 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,38 +22,22 @@ func TestHash32(t *testing.T) { // create a new hash with stdhash.Sum32() as initial crc hash := New(stdhash.Sum32(), crc32.IEEETable) - wsize := stdhash.Size() - if g := hash.Size(); g != wsize { - t.Errorf("size = %d, want %d", g, wsize) - } - wbsize := stdhash.BlockSize() - if g := hash.BlockSize(); g != wbsize { - t.Errorf("block size = %d, want %d", g, wbsize) - } - wsum32 := stdhash.Sum32() - if g := hash.Sum32(); g != wsum32 { - t.Errorf("Sum32 = %d, want %d", g, wsum32) - } + assert.Equalf(t, hash.Size(), stdhash.Size(), "size") + assert.Equalf(t, hash.BlockSize(), stdhash.BlockSize(), "block size") + assert.Equalf(t, hash.Sum32(), stdhash.Sum32(), "Sum32") wsum := stdhash.Sum(make([]byte, 32)) - if g := hash.Sum(make([]byte, 32)); !reflect.DeepEqual(g, wsum) { - t.Errorf("sum = %v, want %v", g, wsum) - } + g := hash.Sum(make([]byte, 32)) + assert.Truef(t, reflect.DeepEqual(g, wsum), "sum") // write something _, err = stdhash.Write([]byte("test data")) require.NoErrorf(t, err, "unexpected write error: %v", err) _, err = hash.Write([]byte("test data")) require.NoErrorf(t, err, "unexpected write error: %v", err) - wsum32 = stdhash.Sum32() - if g := hash.Sum32(); g != wsum32 { - t.Errorf("Sum32 after write = %d, want %d", g, wsum32) - } + assert.Equalf(t, hash.Sum32(), stdhash.Sum32(), "Sum32 after write") // reset stdhash.Reset() hash.Reset() - wsum32 = stdhash.Sum32() - if g := hash.Sum32(); g != wsum32 { - t.Errorf("Sum32 after reset = %d, want %d", g, wsum32) - } + assert.Equalf(t, hash.Sum32(), stdhash.Sum32(), "Sum32 after reset") } diff --git a/pkg/featuregate/feature_gate_test.go b/pkg/featuregate/feature_gate_test.go index 7411f363fa34..5dc5a86537d7 100644 --- a/pkg/featuregate/feature_gate_test.go +++ b/pkg/featuregate/feature_gate_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -213,16 +214,13 @@ func TestFeatureGateFlag(t *testing.T) { err := fs.Parse([]string{fmt.Sprintf("--%s=%s", defaultFlagName, test.arg)}) if test.parseError != "" { - if !strings.Contains(err.Error(), test.parseError) { - t.Errorf("%d: Parse() Expected %v, Got %v", i, test.parseError, err) - } + assert.Containsf(t, err.Error(), test.parseError, "%d: Parse() Expected %v, Got %v", i, test.parseError, err) } else if err != nil { t.Errorf("%d: Parse() Expected nil, Got %v", i, err) } for k, v := range test.expect { - if actual := f.enabled.Load().(map[Feature]bool)[k]; actual != v { - t.Errorf("%d: expected %s=%v, Got %v", i, k, v, actual) - } + actual := f.enabled.Load().(map[Feature]bool)[k] + assert.Equalf(t, actual, v, "%d: expected %s=%v, Got %v", i, k, v, actual) } }) } @@ -240,20 +238,12 @@ func TestFeatureGateOverride(t *testing.T) { }) f.Set("TestAlpha=true,TestBeta=true") - if f.Enabled(testAlphaGate) != true { - t.Errorf("Expected true") - } - if f.Enabled(testBetaGate) != true { - t.Errorf("Expected true") - } + assert.Truef(t, f.Enabled(testAlphaGate), "Expected true") + assert.Truef(t, f.Enabled(testBetaGate), "Expected true") f.Set("TestAlpha=false") - if f.Enabled(testAlphaGate) != false { - t.Errorf("Expected false") - } - if f.Enabled(testBetaGate) != true { - t.Errorf("Expected true") - } + assert.Falsef(t, f.Enabled(testAlphaGate), "Expected false") + assert.Truef(t, f.Enabled(testBetaGate), "Expected true") } func TestFeatureGateFlagDefaults(t *testing.T) { @@ -268,12 +258,8 @@ func TestFeatureGateFlagDefaults(t *testing.T) { testBetaGate: {Default: true, PreRelease: Beta}, }) - if f.Enabled(testAlphaGate) != false { - t.Errorf("Expected false") - } - if f.Enabled(testBetaGate) != true { - t.Errorf("Expected true") - } + assert.Falsef(t, f.Enabled(testAlphaGate), "Expected false") + assert.Truef(t, f.Enabled(testBetaGate), "Expected true") } func TestFeatureGateKnownFeatures(t *testing.T) { @@ -411,9 +397,8 @@ func TestFeatureGateSetFromMap(t *testing.T) { t.Errorf("%d: SetFromMap(%#v) Expected success, Got err:%v", i, test.setmap, err) } for k, v := range test.expect { - if actual := f.Enabled(k); actual != v { - t.Errorf("%d: SetFromMap(%#v) Expected %s=%v, Got %s=%v", i, test.setmap, k, v, k, actual) - } + actual := f.Enabled(k) + assert.Equalf(t, actual, v, "%d: SetFromMap(%#v) Expected %s=%v, Got %s=%v", i, test.setmap, k, v, k, actual) } }) } @@ -467,9 +452,7 @@ func TestFeatureGateString(t *testing.T) { f.Add(featuremap) f.SetFromMap(test.setmap) result := f.String() - if result != test.expect { - t.Errorf("%d: SetFromMap(%#v) Expected %s, Got %s", i, test.setmap, test.expect, result) - } + assert.Equalf(t, result, test.expect, "%d: SetFromMap(%#v) Expected %s, Got %s", i, test.setmap, test.expect, result) }) } } @@ -477,128 +460,90 @@ func TestFeatureGateString(t *testing.T) { func TestFeatureGateOverrideDefault(t *testing.T) { t.Run("overrides take effect", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{ + err := f.Add(map[Feature]FeatureSpec{ "TestFeature1": {Default: true}, "TestFeature2": {Default: false}, - }); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature1", false); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature2", true); err != nil { - t.Fatal(err) - } - if f.Enabled("TestFeature1") { - t.Error("expected TestFeature1 to have effective default of false") - } - if !f.Enabled("TestFeature2") { - t.Error("expected TestFeature2 to have effective default of true") - } + }) + require.NoError(t, err) + require.NoError(t, f.OverrideDefault("TestFeature1", false)) + require.NoError(t, f.OverrideDefault("TestFeature2", true)) + assert.Falsef(t, f.Enabled("TestFeature1"), "expected TestFeature1 to have effective default of false") + assert.Truef(t, f.Enabled("TestFeature2"), "expected TestFeature2 to have effective default of true") }) t.Run("overrides are preserved across deep copies", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{"TestFeature": {Default: false}}); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature", true); err != nil { - t.Fatal(err) - } + err := f.Add(map[Feature]FeatureSpec{"TestFeature": {Default: false}}) + require.NoError(t, err) + require.NoError(t, f.OverrideDefault("TestFeature", true)) fcopy := f.DeepCopy() - if !fcopy.Enabled("TestFeature") { - t.Error("default override was not preserved by deep copy") - } + assert.Truef(t, fcopy.Enabled("TestFeature"), "default override was not preserved by deep copy") }) t.Run("reflected in known features", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{"TestFeature": { + err := f.Add(map[Feature]FeatureSpec{"TestFeature": { Default: false, PreRelease: Alpha, - }}); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature", true); err != nil { - t.Fatal(err) - } + }}) + require.NoError(t, err) + require.NoError(t, f.OverrideDefault("TestFeature", true)) var found bool for _, s := range f.KnownFeatures() { if !strings.Contains(s, "TestFeature") { continue } found = true - if !strings.Contains(s, "default=true") { - t.Errorf("expected override of default to be reflected in known feature description %q", s) - } - } - if !found { - t.Error("found no entry for TestFeature in known features") + assert.Containsf(t, s, "default=true", "expected override of default to be reflected in known feature description %q", s) } + assert.Truef(t, found, "found no entry for TestFeature in known features") }) t.Run("may not change default for specs with locked defaults", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{ + err := f.Add(map[Feature]FeatureSpec{ "LockedFeature": { Default: true, LockToDefault: true, }, - }); err != nil { - t.Fatal(err) - } - if f.OverrideDefault("LockedFeature", false) == nil { - t.Error("expected error when attempting to override the default for a feature with a locked default") - } - if f.OverrideDefault("LockedFeature", true) == nil { - t.Error("expected error when attempting to override the default for a feature with a locked default") - } + }) + require.NoError(t, err) + require.Errorf(t, f.OverrideDefault("LockedFeature", false), "expected error when attempting to override the default for a feature with a locked default") + assert.Errorf(t, f.OverrideDefault("LockedFeature", true), "expected error when attempting to override the default for a feature with a locked default") }) t.Run("does not supersede explicitly-set value", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{"TestFeature": {Default: true}}); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature", false); err != nil { - t.Fatal(err) - } - if err := f.SetFromMap(map[string]bool{"TestFeature": true}); err != nil { - t.Fatal(err) - } - if !f.Enabled("TestFeature") { - t.Error("expected feature to be effectively enabled despite default override") - } + err := f.Add(map[Feature]FeatureSpec{"TestFeature": {Default: true}}) + require.NoError(t, err) + require.NoError(t, f.OverrideDefault("TestFeature", false)) + require.NoError(t, f.SetFromMap(map[string]bool{"TestFeature": true})) + assert.Truef(t, f.Enabled("TestFeature"), "expected feature to be effectively enabled despite default override") }) t.Run("prevents re-registration of feature spec after overriding default", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.Add(map[Feature]FeatureSpec{ + err := f.Add(map[Feature]FeatureSpec{ "TestFeature": { Default: true, PreRelease: Alpha, }, - }); err != nil { - t.Fatal(err) - } - if err := f.OverrideDefault("TestFeature", false); err != nil { - t.Fatal(err) - } - if err := f.Add(map[Feature]FeatureSpec{ + }) + require.NoError(t, err) + require.NoError(t, f.OverrideDefault("TestFeature", false)) + err = f.Add(map[Feature]FeatureSpec{ "TestFeature": { Default: true, PreRelease: Alpha, }, - }); err == nil { - t.Error("expected re-registration to return a non-nil error after overriding its default") - } + }) + assert.Errorf(t, err, "expected re-registration to return a non-nil error after overriding its default") }) t.Run("does not allow override for an unknown feature", func(t *testing.T) { f := New("test", zaptest.NewLogger(t)) - if err := f.OverrideDefault("TestFeature", true); err == nil { - t.Error("expected an error to be returned in attempt to override default for unregistered feature") - } + err := f.OverrideDefault("TestFeature", true) + assert.Errorf(t, err, "expected an error to be returned in attempt to override default for unregistered feature") }) t.Run("returns error if already added to flag set", func(t *testing.T) { @@ -606,8 +551,7 @@ func TestFeatureGateOverrideDefault(t *testing.T) { fs := flag.NewFlagSet("test", flag.ContinueOnError) f.AddFlag(fs, defaultFlagName) - if err := f.OverrideDefault("TestFeature", true); err == nil { - t.Error("expected a non-nil error to be returned") - } + err := f.OverrideDefault("TestFeature", true) + assert.Errorf(t, err, "expected a non-nil error to be returned") }) } diff --git a/pkg/flags/flag_test.go b/pkg/flags/flag_test.go index f2b8ce1eb716..bf256b58d37e 100644 --- a/pkg/flags/flag_test.go +++ b/pkg/flags/flag_test.go @@ -19,6 +19,7 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -33,9 +34,7 @@ func TestSetFlagsFromEnv(t *testing.T) { // flags should be settable using env vars t.Setenv("ETCD_A", "foo") // and command-line flags - if err := fs.Set("b", "bar"); err != nil { - t.Fatal(err) - } + require.NoError(t, fs.Set("b", "bar")) // first verify that flags are as expected before reading the env for f, want := range map[string]string{ @@ -48,16 +47,13 @@ func TestSetFlagsFromEnv(t *testing.T) { // now read the env and verify flags were updated as expected err := SetFlagsFromEnv(zaptest.NewLogger(t), "ETCD", fs) - if err != nil { - t.Errorf("err=%v, want nil", err) - } + require.NoErrorf(t, err, "err=%v, want nil", err) for f, want := range map[string]string{ "a": "foo", "b": "bar", } { - if got := fs.Lookup(f).Value.String(); got != want { - t.Errorf("flag %q=%q, want %q", f, got, want) - } + got := fs.Lookup(f).Value.String() + assert.Equalf(t, want, got, "flag %q=%q, want %q", f, got, want) } } @@ -66,9 +62,8 @@ func TestSetFlagsFromEnvBad(t *testing.T) { fs := flag.NewFlagSet("testing", flag.ExitOnError) fs.Int("x", 0, "") t.Setenv("ETCD_X", "not_a_number") - if err := SetFlagsFromEnv(zaptest.NewLogger(t), "ETCD", fs); err == nil { - t.Errorf("err=nil, want != nil") - } + err := SetFlagsFromEnv(zaptest.NewLogger(t), "ETCD", fs) + assert.Errorf(t, err, "err=nil, want != nil") } func TestSetFlagsFromEnvParsingError(t *testing.T) { diff --git a/pkg/flags/selective_string_test.go b/pkg/flags/selective_string_test.go index cc310ed63bfb..f317643a6a1a 100644 --- a/pkg/flags/selective_string_test.go +++ b/pkg/flags/selective_string_test.go @@ -16,6 +16,8 @@ package flags import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestSelectiveStringValue(t *testing.T) { @@ -35,13 +37,9 @@ func TestSelectiveStringValue(t *testing.T) { } for i, tt := range tests { sf := NewSelectiveStringValue(tt.vals...) - if sf.v != tt.vals[0] { - t.Errorf("#%d: want default val=%v,but got %v", i, tt.vals[0], sf.v) - } + assert.Equalf(t, sf.v, tt.vals[0], "#%d: want default val=%v,but got %v", i, tt.vals[0], sf.v) err := sf.Set(tt.val) - if tt.pass != (err == nil) { - t.Errorf("#%d: want pass=%t, but got err=%v", i, tt.pass, err) - } + assert.Equalf(t, tt.pass, (err == nil), "#%d: want pass=%t, but got err=%v", i, tt.pass, err) } } @@ -63,8 +61,6 @@ func TestSelectiveStringsValue(t *testing.T) { for i, tt := range tests { sf := NewSelectiveStringsValue(tt.vals...) err := sf.Set(tt.val) - if tt.pass != (err == nil) { - t.Errorf("#%d: want pass=%t, but got err=%v", i, tt.pass, err) - } + assert.Equalf(t, tt.pass, (err == nil), "#%d: want pass=%t, but got err=%v", i, tt.pass, err) } } diff --git a/pkg/flags/uint32_test.go b/pkg/flags/uint32_test.go index 949fbefb671c..fdb3ef1d0ade 100644 --- a/pkg/flags/uint32_test.go +++ b/pkg/flags/uint32_test.go @@ -56,13 +56,9 @@ func TestUint32Value(t *testing.T) { err := val.Set(tc.s) if tc.expectError { - if err == nil { - t.Errorf("Expected failure on parsing uint32 value from %s", tc.s) - } + assert.Errorf(t, err, "Expected failure on parsing uint32 value from %s", tc.s) } else { - if err != nil { - t.Errorf("Unexpected error when parsing %s: %v", tc.s, err) - } + require.NoErrorf(t, err, "Unexpected error when parsing %s: %v", tc.s, err) assert.Equal(t, tc.expectedVal, uint32(val)) } }) diff --git a/pkg/flags/unique_urls_test.go b/pkg/flags/unique_urls_test.go index a37e9ca35fa5..816ba6768fbc 100644 --- a/pkg/flags/unique_urls_test.go +++ b/pkg/flags/unique_urls_test.go @@ -104,7 +104,7 @@ func TestUniqueURLsFromFlag(t *testing.T) { fs.Var(u, name, "usage") uss := UniqueURLsFromFlag(fs, name) - require.Equal(t, len(u.Values), len(uss)) + require.Len(t, uss, len(u.Values)) um := make(map[string]struct{}) for _, x := range uss { diff --git a/pkg/flags/urls_test.go b/pkg/flags/urls_test.go index 4b8429264bfa..aafd58e833bd 100644 --- a/pkg/flags/urls_test.go +++ b/pkg/flags/urls_test.go @@ -19,6 +19,7 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -40,9 +41,8 @@ func TestValidateURLsValueBad(t *testing.T) { } for i, in := range tests { u := URLsValue{} - if err := u.Set(in); err == nil { - t.Errorf(`#%d: unexpected nil error for in=%q`, i, in) - } + err := u.Set(in) + assert.Errorf(t, err, `#%d: unexpected nil error for in=%q`, i, in) } } diff --git a/pkg/httputil/httputil_test.go b/pkg/httputil/httputil_test.go index f14d597ed9c4..9fb7880f9094 100644 --- a/pkg/httputil/httputil_test.go +++ b/pkg/httputil/httputil_test.go @@ -17,6 +17,8 @@ package httputil import ( "net/http" "testing" + + "github.com/stretchr/testify/assert" ) func TestGetHostname(t *testing.T) { @@ -43,8 +45,6 @@ func TestGetHostname(t *testing.T) { } for i := range tt { hv := GetHostname(tt[i].req) - if hv != tt[i].host { - t.Errorf("#%d: %q expected host %q, got '%v'", i, tt[i].req.Host, tt[i].host, hv) - } + assert.Equalf(t, hv, tt[i].host, "#%d: %q expected host %q, got '%v'", i, tt[i].req.Host, tt[i].host, hv) } } diff --git a/pkg/idutil/id_test.go b/pkg/idutil/id_test.go index 92be7fb3569b..9bdeac93dda5 100644 --- a/pkg/idutil/id_test.go +++ b/pkg/idutil/id_test.go @@ -17,15 +17,15 @@ package idutil import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestNewGenerator(t *testing.T) { g := NewGenerator(0x12, time.Unix(0, 0).Add(0x3456*time.Millisecond)) id := g.Next() wid := uint64(0x12000000345601) - if id != wid { - t.Errorf("id = %x, want %x", id, wid) - } + assert.Equalf(t, id, wid, "id = %x, want %x", id, wid) } func TestNewGeneratorUnique(t *testing.T) { @@ -33,14 +33,12 @@ func TestNewGeneratorUnique(t *testing.T) { id := g.Next() // different server generates different ID g1 := NewGenerator(1, time.Time{}) - if gid := g1.Next(); id == gid { - t.Errorf("generate the same id %x using different server ID", id) - } + gid := g1.Next() + assert.NotEqualf(t, id, gid, "generate the same id %x using different server ID", id) // restarted server generates different ID g2 := NewGenerator(0, time.Now()) - if gid := g2.Next(); id == gid { - t.Errorf("generate the same id %x after restart", id) - } + gid = g2.Next() + assert.NotEqualf(t, id, gid, "generate the same id %x after restart", id) } func TestNext(t *testing.T) { @@ -48,9 +46,7 @@ func TestNext(t *testing.T) { wid := uint64(0x12000000345601) for i := 0; i < 1000; i++ { id := g.Next() - if id != wid+uint64(i) { - t.Errorf("id = %x, want %x", id, wid+uint64(i)) - } + assert.Equalf(t, id, wid+uint64(i), "id = %x, want %x", id, wid+uint64(i)) } } diff --git a/pkg/ioutil/pagewriter_test.go b/pkg/ioutil/pagewriter_test.go index 77f1336bbe83..1a53bf11f1d5 100644 --- a/pkg/ioutil/pagewriter_test.go +++ b/pkg/ioutil/pagewriter_test.go @@ -32,9 +32,7 @@ func TestPageWriterRandom(t *testing.T) { n := 0 for i := 0; i < 4096; i++ { c, err := w.Write(buf[:rand.Intn(len(buf))]) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) n += c } require.LessOrEqualf(t, cw.writeBytes, n, "wrote %d bytes to io.Writer, but only wrote %d bytes", cw.writeBytes, n) @@ -53,26 +51,20 @@ func TestPageWriterPartialSlack(t *testing.T) { cw := &checkPageWriter{pageBytes: 64, t: t} w := NewPageWriter(cw, pageBytes, 0) // put writer in non-zero page offset - if _, err := w.Write(buf[:64]); err != nil { - t.Fatal(err) - } - if err := w.Flush(); err != nil { - t.Fatal(err) - } + _, err := w.Write(buf[:64]) + require.NoError(t, err) + require.NoError(t, w.Flush()) require.Equalf(t, 1, cw.writes, "got %d writes, expected 1", cw.writes) // nearly fill buffer - if _, err := w.Write(buf[:1022]); err != nil { - t.Fatal(err) - } + _, err = w.Write(buf[:1022]) + require.NoError(t, err) // overflow buffer, but without enough to write as aligned - if _, err := w.Write(buf[:8]); err != nil { - t.Fatal(err) - } + _, err = w.Write(buf[:8]) + require.NoError(t, err) require.Equalf(t, 1, cw.writes, "got %d writes, expected 1", cw.writes) // finish writing slack space - if _, err := w.Write(buf[:128]); err != nil { - t.Fatal(err) - } + _, err = w.Write(buf[:128]) + require.NoError(t, err) require.Equalf(t, 2, cw.writes, "got %d writes, expected 2", cw.writes) } @@ -83,21 +75,15 @@ func TestPageWriterOffset(t *testing.T) { buf := make([]byte, defaultBufferBytes) cw := &checkPageWriter{pageBytes: 64, t: t} w := NewPageWriter(cw, pageBytes, 0) - if _, err := w.Write(buf[:64]); err != nil { - t.Fatal(err) - } - if err := w.Flush(); err != nil { - t.Fatal(err) - } + _, err := w.Write(buf[:64]) + require.NoError(t, err) + require.NoError(t, w.Flush()) require.Equalf(t, 64, w.pageOffset, "w.pageOffset expected 64, got %d", w.pageOffset) w = NewPageWriter(cw, w.pageOffset, pageBytes) - if _, err := w.Write(buf[:64]); err != nil { - t.Fatal(err) - } - if err := w.Flush(); err != nil { - t.Fatal(err) - } + _, err = w.Write(buf[:64]) + require.NoError(t, err) + require.NoError(t, w.Flush()) require.Equalf(t, 0, w.pageOffset, "w.pageOffset expected 0, got %d", w.pageOffset) } diff --git a/pkg/ioutil/reader_test.go b/pkg/ioutil/reader_test.go index 06ff2906c40a..84460b93da2c 100644 --- a/pkg/ioutil/reader_test.go +++ b/pkg/ioutil/reader_test.go @@ -17,6 +17,9 @@ package ioutil import ( "bytes" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLimitedBufferReaderRead(t *testing.T) { @@ -24,10 +27,6 @@ func TestLimitedBufferReaderRead(t *testing.T) { ln := 1 lr := NewLimitedBufferReader(buf, ln) n, err := lr.Read(make([]byte, 10)) - if err != nil { - t.Fatalf("unexpected read error: %v", err) - } - if n != ln { - t.Errorf("len(data read) = %d, want %d", n, ln) - } + require.NoErrorf(t, err, "unexpected read error: %v", err) + assert.Equalf(t, n, ln, "len(data read) = %d, want %d", n, ln) } diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index 5b4551e1fc7d..47a9a4df0988 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -136,14 +138,10 @@ func TestResolveTCPAddrs(t *testing.T) { urls, err := resolveTCPAddrs(ctx, zaptest.NewLogger(t), tt.urls) cancel() if tt.hasError { - if err == nil { - t.Errorf("expected error") - } + require.Errorf(t, err, "expected error") continue } - if !reflect.DeepEqual(urls, tt.expected) { - t.Errorf("expected: %v, got %v", tt.expected, urls) - } + assert.Truef(t, reflect.DeepEqual(urls, tt.expected), "expected: %v, got %v", tt.expected, urls) } } @@ -308,9 +306,7 @@ func TestURLsEqual(t *testing.T) { for i, test := range tests { result, err := urlsEqual(context.TODO(), zaptest.NewLogger(t), test.a, test.b) - if result != test.expect { - t.Errorf("idx=%d #%d: a:%v b:%v, expected %v but %v", i, test.n, test.a, test.b, test.expect, result) - } + assert.Equalf(t, result, test.expect, "idx=%d #%d: a:%v b:%v, expected %v but %v", i, test.n, test.a, test.b, test.expect, result) if test.err != nil { if err.Error() != test.err.Error() { t.Errorf("idx=%d #%d: err expected %v but %v", i, test.n, test.err, err) @@ -347,11 +343,7 @@ func TestURLStringsEqual(t *testing.T) { t.Logf("TestURLStringsEqual, case #%d", idx) resolveTCPAddr = c.resolver result, err := URLStringsEqual(context.TODO(), zaptest.NewLogger(t), c.urlsA, c.urlsB) - if !result { - t.Errorf("unexpected result %v", result) - } - if err != nil { - t.Errorf("unexpected error %v", err) - } + assert.Truef(t, result, "unexpected result %v", result) + assert.NoErrorf(t, err, "unexpected error %v", err) } } diff --git a/pkg/netutil/routes_linux_test.go b/pkg/netutil/routes_linux_test.go index a0056e990e7a..16b34bc74a4c 100644 --- a/pkg/netutil/routes_linux_test.go +++ b/pkg/netutil/routes_linux_test.go @@ -16,20 +16,20 @@ package netutil -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestGetDefaultInterface(t *testing.T) { ifc, err := GetDefaultInterfaces() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Logf("default network interfaces: %+v\n", ifc) } func TestGetDefaultHost(t *testing.T) { ip, err := GetDefaultHost() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Logf("default ip: %v", ip) } diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go index 28fcc7288b10..01f303226189 100644 --- a/pkg/osutil/osutil_test.go +++ b/pkg/osutil/osutil_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -29,9 +30,7 @@ func init() { setDflSignal = func(syscall.Signal) {} } func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) { select { case s := <-c: - if s != sig { - t.Fatalf("signal was %v, want %v", s, sig) - } + require.Equalf(t, s, sig, "signal was %v, want %v", s, sig) case <-time.After(1 * time.Second): t.Fatalf("timeout waiting for %v", sig) } @@ -54,12 +53,8 @@ func TestHandleInterrupts(t *testing.T) { waitSig(t, c, sig) waitSig(t, c, sig) - if n == 3 { - t.Fatalf("interrupt handlers were called in wrong order") - } - if n != 4 { - t.Fatalf("interrupt handlers were not called properly") - } + require.NotEqualf(t, 3, n, "interrupt handlers were called in wrong order") + require.Equalf(t, 4, n, "interrupt handlers were not called properly") // reset interrupt handlers interruptHandlers = interruptHandlers[:0] interruptExitMu.Unlock() diff --git a/pkg/pbutil/pbutil_test.go b/pkg/pbutil/pbutil_test.go index 5a8dd9a17ef4..c773facde7f8 100644 --- a/pkg/pbutil/pbutil_test.go +++ b/pkg/pbutil/pbutil_test.go @@ -18,21 +18,21 @@ import ( "errors" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestMarshaler(t *testing.T) { data := []byte("test data") m := &fakeMarshaler{data: data} - if g := MustMarshal(m); !reflect.DeepEqual(g, data) { - t.Errorf("data = %s, want %s", g, m.data) - } + g := MustMarshal(m) + assert.Truef(t, reflect.DeepEqual(g, data), "data = %s, want %s", g, m.data) } func TestMarshalerPanic(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Errorf("recover = nil, want error") - } + r := recover() + assert.NotNilf(t, r, "recover = nil, want error") }() m := &fakeMarshaler{err: errors.New("blah")} MustMarshal(m) @@ -42,16 +42,13 @@ func TestUnmarshaler(t *testing.T) { data := []byte("test data") m := &fakeUnmarshaler{} MustUnmarshal(m, data) - if !reflect.DeepEqual(m.data, data) { - t.Errorf("data = %s, want %s", m.data, data) - } + assert.Truef(t, reflect.DeepEqual(m.data, data), "data = %s, want %s", m.data, data) } func TestUnmarshalerPanic(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Errorf("recover = nil, want error") - } + r := recover() + assert.NotNilf(t, r, "recover = nil, want error") }() m := &fakeUnmarshaler{err: errors.New("blah")} MustUnmarshal(m, nil) @@ -69,12 +66,8 @@ func TestGetBool(t *testing.T) { } for i, tt := range tests { b, set := GetBool(tt.b) - if b != tt.wb { - t.Errorf("#%d: value = %v, want %v", i, b, tt.wb) - } - if set != tt.wset { - t.Errorf("#%d: set = %v, want %v", i, set, tt.wset) - } + assert.Equalf(t, b, tt.wb, "#%d: value = %v, want %v", i, b, tt.wb) + assert.Equalf(t, set, tt.wset, "#%d: set = %v, want %v", i, set, tt.wset) } } diff --git a/pkg/proxy/server_test.go b/pkg/proxy/server_test.go index c1d3805cfb02..2739c76858c0 100644 --- a/pkg/proxy/server_test.go +++ b/pkg/proxy/server_test.go @@ -84,14 +84,14 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { go func() { defer close(donec) for data := range writec { - send(t, data, scheme, srcAddr, tlsInfo) + send(t, data, scheme, srcAddr, tlsInfo) //nolint:testifylint //FIXME } }() recvc := make(chan []byte, 1) go func() { for i := 0; i < 2; i++ { - recvc <- receive(t, ln) + recvc <- receive(t, ln) //nolint:testifylint //FIXME } }() @@ -146,9 +146,7 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { default: } - if err := p.Close(); err != nil { - t.Fatal(err) - } + require.NoError(t, p.Close()) select { case <-p.Done(): @@ -207,31 +205,25 @@ func testServerDelayAccept(t *testing.T, secure bool) { now := time.Now() send(t, data, scheme, srcAddr, tlsInfo) - if d := receive(t, ln); !bytes.Equal(data, d) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + d := receive(t, ln) + require.Truef(t, bytes.Equal(data, d), "expected %q, got %q", string(data), string(d)) took1 := time.Since(now) t.Logf("took %v with no latency", took1) lat, rv := 700*time.Millisecond, 10*time.Millisecond p.DelayAccept(lat, rv) defer p.UndelayAccept() - if err := p.ResetListener(); err != nil { - t.Fatal(err) - } + require.NoError(t, p.ResetListener()) time.Sleep(200 * time.Millisecond) now = time.Now() send(t, data, scheme, srcAddr, tlsInfo) - if d := receive(t, ln); !bytes.Equal(data, d) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + d = receive(t, ln) + require.Truef(t, bytes.Equal(data, d), "expected %q, got %q", string(data), string(d)) took2 := time.Since(now) t.Logf("took %v with latency %v±%v", took2, lat, rv) - if took1 >= took2 { - t.Fatalf("expected took1 %v < took2 %v", took1, took2) - } + require.Lessf(t, took1, took2, "expected took1 %v < took2 %v", took1, took2) } func TestServer_PauseTx(t *testing.T) { @@ -262,7 +254,7 @@ func TestServer_PauseTx(t *testing.T) { recvc := make(chan []byte, 1) go func() { - recvc <- receive(t, ln) + recvc <- receive(t, ln) //nolint:testifylint //FIXME }() select { @@ -275,9 +267,7 @@ func TestServer_PauseTx(t *testing.T) { select { case d := <-recvc: - if !bytes.Equal(data, d) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + require.Truef(t, bytes.Equal(data, d), "expected %q, got %q", string(data), string(d)) case <-time.After(2 * time.Second): t.Fatal("took too long to receive after unpause") } @@ -310,15 +300,13 @@ func TestServer_ModifyTx_corrupt(t *testing.T) { }) data := []byte("Hello World!") send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); bytes.Equal(d, data) { - t.Fatalf("expected corrupted data, got %q", string(d)) - } + d := receive(t, ln) + require.Falsef(t, bytes.Equal(d, data), "expected corrupted data, got %q", string(d)) p.UnmodifyTx() send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected uncorrupted data, got %q", string(d)) - } + d = receive(t, ln) + require.Truef(t, bytes.Equal(d, data), "expected uncorrupted data, got %q", string(d)) } func TestServer_ModifyTx_packet_loss(t *testing.T) { @@ -349,15 +337,13 @@ func TestServer_ModifyTx_packet_loss(t *testing.T) { }) data := []byte("Hello World!") send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); bytes.Equal(d, data) { - t.Fatalf("expected corrupted data, got %q", string(d)) - } + d := receive(t, ln) + require.Falsef(t, bytes.Equal(d, data), "expected corrupted data, got %q", string(d)) p.UnmodifyTx() send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected uncorrupted data, got %q", string(d)) - } + d = receive(t, ln) + require.Truef(t, bytes.Equal(d, data), "expected uncorrupted data, got %q", string(d)) } func TestServer_BlackholeTx(t *testing.T) { @@ -388,7 +374,7 @@ func TestServer_BlackholeTx(t *testing.T) { recvc := make(chan []byte, 1) go func() { - recvc <- receive(t, ln) + recvc <- receive(t, ln) //nolint:testifylint //FIXME }() select { @@ -405,9 +391,7 @@ func TestServer_BlackholeTx(t *testing.T) { select { case d := <-recvc: - if !bytes.Equal(data, d) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + require.Truef(t, bytes.Equal(data, d), "expected %q, got %q", string(data), string(d)) case <-time.After(2 * time.Second): t.Fatal("took too long to receive after unblackhole") } @@ -440,9 +424,8 @@ func TestServer_Shutdown(t *testing.T) { data := []byte("Hello World!") send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + d := receive(t, ln) + require.Truef(t, bytes.Equal(d, data), "expected %q, got %q", string(data), string(d)) } func TestServer_ShutdownListener(t *testing.T) { @@ -476,9 +459,8 @@ func TestServer_ShutdownListener(t *testing.T) { data := []byte("Hello World!") send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } + d := receive(t, ln) + require.Truef(t, bytes.Equal(d, data), "expected %q, got %q", string(data), string(d)) } func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) } @@ -497,20 +479,15 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) { d, err := io.ReadAll(req.Body) req.Body.Close() - if err != nil { - t.Fatal(err) - } - if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil { - t.Fatal(err) - } + require.NoError(t, err) //nolint:testifylint //FIXME + _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))) + require.NoError(t, err) //nolint:testifylint //FIXME }) tlsInfo := createTLSInfo(lg, secure) var tlsConfig *tls.Config if secure { _, err := tlsInfo.ServerConfig() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) } srv := &http.Server{ Addr: dstAddr, @@ -570,18 +547,14 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { } require.NoError(t, err) d, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp.Body.Close() took1 := time.Since(now) t.Logf("took %v with no latency", took1) rs1 := string(d) exp := fmt.Sprintf("%q(confirmed)", data) - if rs1 != exp { - t.Fatalf("got %q, expected %q", rs1, exp) - } + require.Equalf(t, exp, rs1, "got %q, expected %q", rs1, exp) lat, rv := 100*time.Millisecond, 10*time.Millisecond if delayTx { @@ -595,9 +568,7 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { now = time.Now() if secure { tp, terr := transport.NewTransport(tlsInfo, 3*time.Second) - if terr != nil { - t.Fatal(terr) - } + require.NoError(t, terr) cli := &http.Client{Transport: tp} resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data)) defer cli.CloseIdleConnections() @@ -606,24 +577,16 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data)) defer http.DefaultClient.CloseIdleConnections() } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) d, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp.Body.Close() took2 := time.Since(now) t.Logf("took %v with latency %v±%v", took2, lat, rv) rs2 := string(d) - if rs2 != exp { - t.Fatalf("got %q, expected %q", rs2, exp) - } - if took1 > took2 { - t.Fatalf("expected took1 %v < took2 %v", took1, took2) - } + require.Equalf(t, exp, rs2, "got %q, expected %q", rs2, exp) + require.LessOrEqualf(t, took1, took2, "expected took1 %v < took2 %v", took1, took2) } func newUnixAddr() string { @@ -640,9 +603,7 @@ func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln ne } else { ln, err = net.Listen(scheme, addr) } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return ln } @@ -651,36 +612,25 @@ func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSI var err error if !tlsInfo.Empty() { tp, terr := transport.NewTransport(tlsInfo, 3*time.Second) - if terr != nil { - t.Fatal(terr) - } + require.NoError(t, terr) out, err = tp.DialContext(context.Background(), scheme, addr) } else { out, err = net.Dial(scheme, addr) } - if err != nil { - t.Fatal(err) - } - if _, err = out.Write(data); err != nil { - t.Fatal(err) - } - if err = out.Close(); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + _, err = out.Write(data) + require.NoError(t, err) + require.NoError(t, out.Close()) } func receive(t *testing.T, ln net.Listener) (data []byte) { buf := bytes.NewBuffer(make([]byte, 0, 1024)) for { in, err := ln.Accept() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) var n int64 n, err = buf.ReadFrom(in) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if n > 0 { break } diff --git a/pkg/report/report_test.go b/pkg/report/report_test.go index d6bdc3bf95c9..0c9e370289ba 100644 --- a/pkg/report/report_test.go +++ b/pkg/report/report_test.go @@ -17,10 +17,10 @@ package report import ( "fmt" "reflect" - "strings" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -28,16 +28,12 @@ func TestPercentiles(t *testing.T) { nums := make([]float64, 100) nums[99] = 1 // 99-percentile (1 out of 100) data := percentiles(nums) - if data[len(pctls)-2] != 1 { - t.Fatalf("99-percentile expected 1, got %f", data[len(pctls)-2]) - } + require.InDeltaf(t, 1, data[len(pctls)-2], 0.0, "99-percentile expected 1, got %f", data[len(pctls)-2]) nums = make([]float64, 1000) nums[999] = 1 // 99.9-percentile (1 out of 1000) data = percentiles(nums) - if data[len(pctls)-1] != 1 { - t.Fatalf("99.9-percentile expected 1, got %f", data[len(pctls)-1]) - } + require.InDeltaf(t, 1, data[len(pctls)-1], 0.0, "99.9-percentile expected 1, got %f", data[len(pctls)-1]) } func TestReport(t *testing.T) { @@ -76,9 +72,7 @@ func TestReport(t *testing.T) { } ss := <-r.Run() for i, ws := range wstrs { - if !strings.Contains(ss, ws) { - t.Errorf("#%d: stats string missing %s", i, ws) - } + assert.Containsf(t, ss, ws, "#%d: stats string missing %s", i, ws) } } diff --git a/pkg/schedule/schedule_test.go b/pkg/schedule/schedule_test.go index af0b5e613eac..898dc0ee91d4 100644 --- a/pkg/schedule/schedule_test.go +++ b/pkg/schedule/schedule_test.go @@ -19,6 +19,7 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -53,7 +54,5 @@ func TestFIFOSchedule(t *testing.T) { } s.WaitFinish(100) - if s.Finished() != 100 { - t.Errorf("finished = %d, want %d", s.Finished(), 100) - } + assert.Equalf(t, 100, s.Finished(), "finished = %d, want %d", s.Finished(), 100) } diff --git a/pkg/traceutil/trace_test.go b/pkg/traceutil/trace_test.go index 4d6d3513f3a0..7b4ae28e61ce 100644 --- a/pkg/traceutil/trace_test.go +++ b/pkg/traceutil/trace_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.etcd.io/etcd/client/pkg/v3/logutil" @@ -50,9 +51,7 @@ func TestGet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { trace := Get(tt.inputCtx) - if trace == nil { - t.Errorf("Expected %v; Got nil", tt.outputTrace) - } + assert.NotNilf(t, trace, "Expected %v; Got nil", tt.outputTrace) if tt.outputTrace == nil || trace.operation != tt.outputTrace.operation { t.Errorf("Expected %v; Got %v", tt.outputTrace, trace) } @@ -75,16 +74,10 @@ func TestCreate(t *testing.T) { ) trace := New(op, nil, fields[0], fields[1]) - if trace.operation != op { - t.Errorf("Expected %v; Got %v", op, trace.operation) - } + assert.Equalf(t, trace.operation, op, "Expected %v; Got %v", op, trace.operation) for i, f := range trace.fields { - if f.Key != fields[i].Key { - t.Errorf("Expected %v; Got %v", fields[i].Key, f.Key) - } - if f.Value != fields[i].Value { - t.Errorf("Expected %v; Got %v", fields[i].Value, f.Value) - } + assert.Equalf(t, f.Key, fields[i].Key, "Expected %v; Got %v", fields[i].Key, f.Key) + assert.Equalf(t, f.Value, fields[i].Value, "Expected %v; Got %v", fields[i].Value, f.Value) } for i, v := range steps { @@ -92,15 +85,9 @@ func TestCreate(t *testing.T) { } for i, v := range trace.steps { - if steps[i] != v.msg { - t.Errorf("Expected %v; Got %v", steps[i], v.msg) - } - if stepFields[i].Key != v.fields[0].Key { - t.Errorf("Expected %v; Got %v", stepFields[i].Key, v.fields[0].Key) - } - if stepFields[i].Value != v.fields[0].Value { - t.Errorf("Expected %v; Got %v", stepFields[i].Value, v.fields[0].Value) - } + assert.Equalf(t, steps[i], v.msg, "Expected %v; Got %v", steps[i], v.msg) + assert.Equalf(t, stepFields[i].Key, v.fields[0].Key, "Expected %v; Got %v", stepFields[i].Key, v.fields[0].Key) + assert.Equalf(t, stepFields[i].Value, v.fields[0].Value, "Expected %v; Got %v", stepFields[i].Value, v.fields[0].Value) } } @@ -220,9 +207,7 @@ func TestLog(t *testing.T) { require.NoError(t, err) for _, msg := range tt.expectedMsg { - if !bytes.Contains(data, []byte(msg)) { - t.Errorf("Expected to find %v in log", msg) - } + assert.Truef(t, bytes.Contains(data, []byte(msg)), "Expected to find %v in log", msg) } }) } @@ -295,9 +280,7 @@ func TestLogIfLong(t *testing.T) { data, err := os.ReadFile(logPath) require.NoError(t, err) for _, msg := range tt.expectedMsg { - if !bytes.Contains(data, []byte(msg)) { - t.Errorf("Expected to find %v in log", msg) - } + assert.Truef(t, bytes.Contains(data, []byte(msg)), "Expected to find %v in log", msg) } }) } diff --git a/pkg/wait/wait_test.go b/pkg/wait/wait_test.go index 54395cb360c8..482e9b646576 100644 --- a/pkg/wait/wait_test.go +++ b/pkg/wait/wait_test.go @@ -18,6 +18,8 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestWait(t *testing.T) { @@ -26,9 +28,8 @@ func TestWait(t *testing.T) { ch := wt.Register(eid) wt.Trigger(eid, "foo") v := <-ch - if g, w := fmt.Sprintf("%v (%T)", v, v), "foo (string)"; g != w { - t.Errorf("<-ch = %v, want %v", g, w) - } + g, w := fmt.Sprintf("%v (%T)", v, v), "foo (string)" + assert.Equalf(t, g, w, "<-ch = %v, want %v", g, w) if g := <-ch; g != nil { t.Errorf("unexpected non-nil value: %v (%T)", g, g) @@ -69,9 +70,8 @@ func TestTriggerDupSuppression(t *testing.T) { wt.Trigger(eid, "bar") v := <-ch - if g, w := fmt.Sprintf("%v (%T)", v, v), "foo (string)"; g != w { - t.Errorf("<-ch = %v, want %v", g, w) - } + g, w := fmt.Sprintf("%v (%T)", v, v), "foo (string)" + assert.Equalf(t, g, w, "<-ch = %v, want %v", g, w) if g := <-ch; g != nil { t.Errorf("unexpected non-nil value: %v (%T)", g, g) @@ -86,17 +86,11 @@ func TestIsRegistered(t *testing.T) { wt.Register(2) for i := uint64(0); i < 3; i++ { - if !wt.IsRegistered(i) { - t.Errorf("event ID %d isn't registered", i) - } + assert.Truef(t, wt.IsRegistered(i), "event ID %d isn't registered", i) } - if wt.IsRegistered(4) { - t.Errorf("event ID 4 shouldn't be registered") - } + assert.Falsef(t, wt.IsRegistered(4), "event ID 4 shouldn't be registered") wt.Trigger(0, "foo") - if wt.IsRegistered(0) { - t.Errorf("event ID 0 is already triggered, shouldn't be registered") - } + assert.Falsef(t, wt.IsRegistered(0), "event ID 0 is already triggered, shouldn't be registered") }