diff --git a/segment.go b/segment.go index 932e03f..26ed868 100644 --- a/segment.go +++ b/segment.go @@ -82,10 +82,17 @@ func (l Logger) Segment(pctx context.Context, message string, options ...Option) // } // } // +//It is expected that errors in the returned function are actual problems, as it will log the error. It is not expected +//that segments will be used where the error is unimportant. func (l Logger) SegmentFn(pctx context.Context, message string, options ...Option) func(func(ctx context.Context) error) error { return func(f func(ctx context.Context) error) error { c, done := l.Segment(pctx, message, options...) + err := f(c) + if err != nil { + l.Error(c, "segment errored", Err(err)) + } + done() return err } diff --git a/segment_test.go b/segment_test.go index 37d9cdb..daf8370 100644 --- a/segment_test.go +++ b/segment_test.go @@ -81,7 +81,7 @@ func TestLogger_Segment(t *testing.T) { func TestLogger_SegmentFn(t *testing.T) { t.Run("starting a segment outputs a message, and closing a segment also outputs a message with fields indicating begin/end, verifies function is called and error returned", func(t *testing.T) { mockImpl := MockImpl{} - mockImpl.On("Impl", mock.Anything, mock.Anything).Times(3) + mockImpl.On("Impl", mock.Anything, mock.Anything).Times(4) expectedMessage := "message" expectedInnerMessage := "inner message" @@ -97,10 +97,11 @@ func TestLogger_SegmentFn(t *testing.T) { assert.True(t, mockImpl.AssertExpectations(t)) assert.Equal(t, io.ErrUnexpectedEOF, err) - var capturedMessage [3]Message + var capturedMessage [4]Message capturedMessage[0] = mockImpl.Calls[0].Arguments.Get(1).(Message) capturedMessage[1] = mockImpl.Calls[1].Arguments.Get(1).(Message) capturedMessage[2] = mockImpl.Calls[2].Arguments.Get(1).(Message) + capturedMessage[3] = mockImpl.Calls[3].Arguments.Get(1).(Message) assert.Equal(t, expectedMessage, capturedMessage[0].Message) assert.Equal(t, SegmentStartValue, capturedMessage[0].Data[SegmentField]) @@ -109,8 +110,10 @@ func TestLogger_SegmentFn(t *testing.T) { assert.Equal(t, expectedInnerMessage, capturedMessage[1].Message) assert.Equal(t, expectedValue, capturedMessage[1].Data[expectedKey]) - assert.Equal(t, expectedMessage, capturedMessage[2].Message) - assert.Equal(t, SegmentEndValue, capturedMessage[2].Data[SegmentField]) - assert.Equal(t, expectedValue, capturedMessage[2].Data[expectedKey]) + assert.Contains(t, capturedMessage[2].Message, "segment errored") + + assert.Equal(t, expectedMessage, capturedMessage[3].Message) + assert.Equal(t, SegmentEndValue, capturedMessage[3].Data[SegmentField]) + assert.Equal(t, expectedValue, capturedMessage[3].Data[expectedKey]) }) }