From 8c5c5d097871abb90149d02883cfc84dcb17981d Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 13:07:21 +0800 Subject: [PATCH 1/9] Add logWithCheckingPanic and its unit tests --- api/log.go | 22 +++++++++++++++++ api/log_test.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 api/log.go create mode 100644 api/log_test.go diff --git a/api/log.go b/api/log.go new file mode 100644 index 00000000..a98de6e2 --- /dev/null +++ b/api/log.go @@ -0,0 +1,22 @@ +package api + +import ( + "fmt" + "net/http" +) + +type logFunc func(statusCode int, err error) + +// logWithCheckingPanic attemps to recover from a possible panic, +// modifies statusCode and err if there was indeed a panic, +// passes the possibly updated status and err to the logFunc, +// then panics again if there was indeed a panic to +// make UnaryInterceptor in server/server.go return "internal server error" to the client. +func logWithCheckingPanic(f logFunc, statusCode int, err error) { + if r := recover(); r != nil { + statusCode = http.StatusInternalServerError + err = fmt.Errorf("panic: %v", r) + defer panic(r) + } + f(statusCode, err) +} diff --git a/api/log_test.go b/api/log_test.go new file mode 100644 index 00000000..99cabc34 --- /dev/null +++ b/api/log_test.go @@ -0,0 +1,65 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "testing" +) + +type logTestCase struct { + name string + panicInput interface{} +} + +func TestLogWithCheckingPanic(t *testing.T) { + t.Parallel() + testCases := []*logTestCase{ + { + name: "panic with string", + panicInput: "string", + }, + { + name: "panic with error", + panicInput: errors.New("error"), + }, + { + name: "no panic", + panicInput: nil, + }, + } + + for _, tc := range testCases { + testLogWithCheckingPanic(t, tc) + } +} + +func testLogWithCheckingPanic(t *testing.T, tc *logTestCase) { + const ( + logStr = "st: %d, err: %v" + inputStatusCode = http.StatusOK + ) + var inputError error + + want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: " + fmt.Sprintf("%s", tc.panicInput)) + if tc.panicInput == nil { + want = fmt.Sprintf(logStr, inputStatusCode, inputError) + } + + // Make the channel buffered to avoid deadlock. + logC := make(chan string, 1) + f := func(statusCode int, err error) { + logC <- fmt.Sprintf(logStr, statusCode, err) + } + + defer func() { + // Capture the panic thrown from logWithCheckingPanic. + recover() + got := <-logC + if get != want { + t.Errorf("%s failed, got: %s, want: %s", tc.name, got, want) + } + }() + defer logWithCheckingPanic(f, inputStatusCode, inputError) + panic(tc.panicInput) +} \ No newline at end of file From bfd22eae8d6c494492a6f6317f8021cad1954f6b Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 13:20:58 +0800 Subject: [PATCH 2/9] Replace recoverIfPanicked with logWithCheckingPanic --- api/blob.go | 19 +++++++++++++------ api/sshhost.go | 19 +++++++++++++------ api/sshuser.go | 19 +++++++++++++------ api/x509cert.go | 18 ++++++++++++------ 4 files changed, 51 insertions(+), 24 deletions(-) diff --git a/api/blob.go b/api/blob.go index dde82a72..033ea792 100644 --- a/api/blob.go +++ b/api/blob.go @@ -27,9 +27,11 @@ func (s *SigningService) GetBlobAvailableSigningKeys(ctx context.Context, e *emp var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.BlobEndpoint] { @@ -47,9 +49,11 @@ func (s *SigningService) GetBlobSigningKey(ctx context.Context, keyMeta *proto.K var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -80,9 +84,12 @@ func (s *SigningService) PostSignBlob(ctx context.Context, request *proto.BlobSi var err error defer func() { - log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`, methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`, + methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/sshhost.go b/api/sshhost.go index 09c9a432..885237d5 100644 --- a/api/sshhost.go +++ b/api/sshhost.go @@ -27,9 +27,11 @@ func (s *SigningService) GetHostSSHCertificateAvailableSigningKeys(ctx context.C var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.SSHHostCertEndpoint] { @@ -47,9 +49,11 @@ func (s *SigningService) GetHostSSHCertificateSigningKey(ctx context.Context, ke var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -84,9 +88,12 @@ func (s *SigningService) PostHostSSHCertificate(ctx context.Context, request *pr if cert != nil { kid = cert.KeyId } - log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, + methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/sshuser.go b/api/sshuser.go index e3df3385..2b966bca 100644 --- a/api/sshuser.go +++ b/api/sshuser.go @@ -27,9 +27,11 @@ func (s *SigningService) GetUserSSHCertificateAvailableSigningKeys(ctx context.C var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.SSHUserCertEndpoint] { @@ -47,9 +49,11 @@ func (s *SigningService) GetUserSSHCertificateSigningKey(ctx context.Context, ke var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -84,9 +88,12 @@ func (s *SigningService) PostUserSSHCertificate(ctx context.Context, request *pr if cert != nil { kid = cert.KeyId } - log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, + methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/x509cert.go b/api/x509cert.go index 0a794052..d67be974 100644 --- a/api/x509cert.go +++ b/api/x509cert.go @@ -27,9 +27,11 @@ func (s *SigningService) GetX509CertificateAvailableSigningKeys(ctx context.Cont var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.X509CertEndpoint] { @@ -46,9 +48,11 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot var err error defer func() { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -79,9 +83,11 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto var err error defer func() { - log.Printf(`m=%s,sub=%q,st=%d,et=%d,err="%v"`, methodName, subject, statusCode, timeElapsedSince(start), err) + f := func(statusCode int, err error) { + log.Printf(`m=%s,sub=%q,st=%d,et=%d,err="%v"`, methodName, subject, statusCode, timeElapsedSince(start), err) + } + logWithCheckingPanic(f, statusCode, err) }() - defer recoverIfPanicked(methodName) if request.KeyMeta == nil { statusCode = http.StatusBadRequest From 05d8db6f940ced41411498ca8f59df9a0687159e Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 13:21:16 +0800 Subject: [PATCH 3/9] Remove recoverIfPanicked --- api/sign.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/api/sign.go b/api/sign.go index 76252a75..7c9d83a5 100644 --- a/api/sign.go +++ b/api/sign.go @@ -5,7 +5,6 @@ package api import ( "fmt" - "log" "time" "github.com/theparanoids/crypki" @@ -19,14 +18,6 @@ type SigningService struct { MaxValidity map[string]uint64 } -// recoverIfPanicked recovers from panic and logs the error. -func recoverIfPanicked(method string) { - if r := recover(); r != nil { - log.Printf("%s: recovered from panic, panic: %v", method, r) - panic(r) - } -} - // timeElapsedSince returns time elapsed since start time in microseconds. func timeElapsedSince(start time.Time) int64 { return time.Since(start).Nanoseconds() / time.Microsecond.Nanoseconds() From 184f72105cdd14a3864eb07b9667d08ff1174541 Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 13:35:16 +0800 Subject: [PATCH 4/9] Fix format and typo --- api/log_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/api/log_test.go b/api/log_test.go index 99cabc34..efd5003f 100644 --- a/api/log_test.go +++ b/api/log_test.go @@ -8,7 +8,7 @@ import ( ) type logTestCase struct { - name string + name string panicInput interface{} } @@ -16,19 +16,19 @@ func TestLogWithCheckingPanic(t *testing.T) { t.Parallel() testCases := []*logTestCase{ { - name: "panic with string", + name: "panic with string", panicInput: "string", }, { - name: "panic with error", + name: "panic with error", panicInput: errors.New("error"), }, { - name: "no panic", + name: "no panic", panicInput: nil, }, } - + for _, tc := range testCases { testLogWithCheckingPanic(t, tc) } @@ -36,12 +36,12 @@ func TestLogWithCheckingPanic(t *testing.T) { func testLogWithCheckingPanic(t *testing.T, tc *logTestCase) { const ( - logStr = "st: %d, err: %v" + logStr = "st: %d, err: %v" inputStatusCode = http.StatusOK ) var inputError error - want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: " + fmt.Sprintf("%s", tc.panicInput)) + want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: "+fmt.Sprintf("%s", tc.panicInput)) if tc.panicInput == nil { want = fmt.Sprintf(logStr, inputStatusCode, inputError) } @@ -56,10 +56,10 @@ func testLogWithCheckingPanic(t *testing.T, tc *logTestCase) { // Capture the panic thrown from logWithCheckingPanic. recover() got := <-logC - if get != want { + if got != want { t.Errorf("%s failed, got: %s, want: %s", tc.name, got, want) } }() defer logWithCheckingPanic(f, inputStatusCode, inputError) panic(tc.panicInput) -} \ No newline at end of file +} From bef978866f8301750be53b698acfbc00fc27963c Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 13:43:03 +0800 Subject: [PATCH 5/9] Remove the usage of channel because it's unnecessary --- api/log_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api/log_test.go b/api/log_test.go index efd5003f..26cdcd5e 100644 --- a/api/log_test.go +++ b/api/log_test.go @@ -46,16 +46,14 @@ func testLogWithCheckingPanic(t *testing.T, tc *logTestCase) { want = fmt.Sprintf(logStr, inputStatusCode, inputError) } - // Make the channel buffered to avoid deadlock. - logC := make(chan string, 1) + got := "" f := func(statusCode int, err error) { - logC <- fmt.Sprintf(logStr, statusCode, err) + got = fmt.Sprintf(logStr, statusCode, err) } defer func() { // Capture the panic thrown from logWithCheckingPanic. recover() - got := <-logC if got != want { t.Errorf("%s failed, got: %s, want: %s", tc.name, got, want) } From 8433b2d9e788f70e2a04892ca57712633428b67b Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 14:10:12 +0800 Subject: [PATCH 6/9] Invoke logWithCheckingPanic directly with defer This change is necessary because the extra function call will prevent logWithCheckingPanic from catching the panic. --- api/blob.go | 32 +++++++++++++------------------- api/sshhost.go | 32 +++++++++++++------------------- api/sshuser.go | 32 +++++++++++++------------------- api/x509cert.go | 30 ++++++++++++------------------ 4 files changed, 51 insertions(+), 75 deletions(-) diff --git a/api/blob.go b/api/blob.go index 033ea792..1531fb20 100644 --- a/api/blob.go +++ b/api/blob.go @@ -26,12 +26,10 @@ func (s *SigningService) GetBlobAvailableSigningKeys(ctx context.Context, e *emp start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.BlobEndpoint] { @@ -48,12 +46,10 @@ func (s *SigningService) GetBlobSigningKey(ctx context.Context, keyMeta *proto.K start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -83,13 +79,11 @@ func (s *SigningService) PostSignBlob(ctx context.Context, request *proto.BlobSi start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`, - methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`, + methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/sshhost.go b/api/sshhost.go index 885237d5..fbc418ca 100644 --- a/api/sshhost.go +++ b/api/sshhost.go @@ -26,12 +26,10 @@ func (s *SigningService) GetHostSSHCertificateAvailableSigningKeys(ctx context.C start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.SSHHostCertEndpoint] { @@ -48,12 +46,10 @@ func (s *SigningService) GetHostSSHCertificateSigningKey(ctx context.Context, ke start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -83,17 +79,15 @@ func (s *SigningService) PostHostSSHCertificate(ctx context.Context, request *pr var err error var cert *ssh.Certificate - defer func() { + f := func(statusCode int, err error) { kid := "" if cert != nil { kid = cert.KeyId } - f := func(statusCode int, err error) { - log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, - methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, + methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/sshuser.go b/api/sshuser.go index 2b966bca..cc01b0a4 100644 --- a/api/sshuser.go +++ b/api/sshuser.go @@ -26,12 +26,10 @@ func (s *SigningService) GetUserSSHCertificateAvailableSigningKeys(ctx context.C start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.SSHUserCertEndpoint] { @@ -48,12 +46,10 @@ func (s *SigningService) GetUserSSHCertificateSigningKey(ctx context.Context, ke start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -83,17 +79,15 @@ func (s *SigningService) PostUserSSHCertificate(ctx context.Context, request *pr var err error var cert *ssh.Certificate - defer func() { + f := func(statusCode int, err error) { kid := "" if cert != nil { kid = cert.KeyId } - f := func(statusCode int, err error) { - log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, - methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, + methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if request.KeyMeta == nil { statusCode = http.StatusBadRequest diff --git a/api/x509cert.go b/api/x509cert.go index d67be974..c4020bbf 100644 --- a/api/x509cert.go +++ b/api/x509cert.go @@ -26,12 +26,10 @@ func (s *SigningService) GetX509CertificateAvailableSigningKeys(ctx context.Cont start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) var keys []*proto.KeyMeta for id := range s.KeyUsages[config.X509CertEndpoint] { @@ -47,12 +45,10 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot start := time.Now() var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if keyMeta == nil { statusCode = http.StatusBadRequest @@ -82,12 +78,10 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto subject := pkix.Name{} var err error - defer func() { - f := func(statusCode int, err error) { - log.Printf(`m=%s,sub=%q,st=%d,et=%d,err="%v"`, methodName, subject, statusCode, timeElapsedSince(start), err) - } - logWithCheckingPanic(f, statusCode, err) - }() + f := func(statusCode int, err error) { + log.Printf(`m=%s,sub=%q,st=%d,et=%d,err="%v"`, methodName, subject, statusCode, timeElapsedSince(start), err) + } + defer logWithCheckingPanic(f, statusCode, err) if request.KeyMeta == nil { statusCode = http.StatusBadRequest From 5399d3974f2f00a939c1e36a2b8e5c797b6a365c Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 14:45:08 +0800 Subject: [PATCH 7/9] Make testLogWithCheckingPanic an anonymous function It's not used by other functions, so it's good to limit the visibility of it. --- api/log_test.go | 53 +++++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/api/log_test.go b/api/log_test.go index 26cdcd5e..065796e2 100644 --- a/api/log_test.go +++ b/api/log_test.go @@ -30,34 +30,35 @@ func TestLogWithCheckingPanic(t *testing.T) { } for _, tc := range testCases { - testLogWithCheckingPanic(t, tc) - } -} + // https://github.com/golang/go/wiki/CommonMistakes#using-goroutines-on-loop-iterator-variables + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + const ( + logStr = "st: %d, err: %v" + inputStatusCode = http.StatusOK + ) + var inputError error -func testLogWithCheckingPanic(t *testing.T, tc *logTestCase) { - const ( - logStr = "st: %d, err: %v" - inputStatusCode = http.StatusOK - ) - var inputError error + want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: "+fmt.Sprintf("%s", tc.panicInput)) + if tc.panicInput == nil { + want = fmt.Sprintf(logStr, inputStatusCode, inputError) + } - want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: "+fmt.Sprintf("%s", tc.panicInput)) - if tc.panicInput == nil { - want = fmt.Sprintf(logStr, inputStatusCode, inputError) - } + got := "" + f := func(statusCode int, err error) { + got = fmt.Sprintf(logStr, statusCode, err) + } - got := "" - f := func(statusCode int, err error) { - got = fmt.Sprintf(logStr, statusCode, err) + defer func() { + // Capture the panic thrown from logWithCheckingPanic. + recover() + if got != want { + t.Errorf("got: %q, want: %q", got, want) + } + }() + defer logWithCheckingPanic(f, inputStatusCode, inputError) + panic(tc.panicInput) + }) } - - defer func() { - // Capture the panic thrown from logWithCheckingPanic. - recover() - if got != want { - t.Errorf("%s failed, got: %s, want: %s", tc.name, got, want) - } - }() - defer logWithCheckingPanic(f, inputStatusCode, inputError) - panic(tc.panicInput) } From 2636e791dc61d93a4d348308270e7ef4b1e19571 Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 15:01:01 +0800 Subject: [PATCH 8/9] Add panicRecoveryPrefix --- api/log.go | 4 +++- api/log_test.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/api/log.go b/api/log.go index a98de6e2..e09cb44e 100644 --- a/api/log.go +++ b/api/log.go @@ -7,6 +7,8 @@ import ( type logFunc func(statusCode int, err error) +const panicRecoveryPrefix = "panic: " + // logWithCheckingPanic attemps to recover from a possible panic, // modifies statusCode and err if there was indeed a panic, // passes the possibly updated status and err to the logFunc, @@ -15,7 +17,7 @@ type logFunc func(statusCode int, err error) func logWithCheckingPanic(f logFunc, statusCode int, err error) { if r := recover(); r != nil { statusCode = http.StatusInternalServerError - err = fmt.Errorf("panic: %v", r) + err = fmt.Errorf("%s%v", panicRecoveryPrefix, r) defer panic(r) } f(statusCode, err) diff --git a/api/log_test.go b/api/log_test.go index 065796e2..2c8c34e7 100644 --- a/api/log_test.go +++ b/api/log_test.go @@ -40,7 +40,7 @@ func TestLogWithCheckingPanic(t *testing.T) { ) var inputError error - want := fmt.Sprintf(logStr, http.StatusInternalServerError, "panic: "+fmt.Sprintf("%s", tc.panicInput)) + want := fmt.Sprintf(logStr, http.StatusInternalServerError, panicRecoveryPrefix+fmt.Sprintf("%s", tc.panicInput)) if tc.panicInput == nil { want = fmt.Sprintf(logStr, inputStatusCode, inputError) } From 2c879a6e295682a0234028e674c6c8f6dcf13e29 Mon Sep 17 00:00:00 2001 From: hchen12 Date: Tue, 7 Jul 2020 15:41:21 +0800 Subject: [PATCH 9/9] Make logTestCase an anonymous struct and 'want' plaintext --- api/log.go | 4 +--- api/log_test.go | 47 ++++++++++++++++++++++------------------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/api/log.go b/api/log.go index e09cb44e..a98de6e2 100644 --- a/api/log.go +++ b/api/log.go @@ -7,8 +7,6 @@ import ( type logFunc func(statusCode int, err error) -const panicRecoveryPrefix = "panic: " - // logWithCheckingPanic attemps to recover from a possible panic, // modifies statusCode and err if there was indeed a panic, // passes the possibly updated status and err to the logFunc, @@ -17,7 +15,7 @@ const panicRecoveryPrefix = "panic: " func logWithCheckingPanic(f logFunc, statusCode int, err error) { if r := recover(); r != nil { statusCode = http.StatusInternalServerError - err = fmt.Errorf("%s%v", panicRecoveryPrefix, r) + err = fmt.Errorf("panic: %v", r) defer panic(r) } f(statusCode, err) diff --git a/api/log_test.go b/api/log_test.go index 2c8c34e7..3deac325 100644 --- a/api/log_test.go +++ b/api/log_test.go @@ -7,43 +7,40 @@ import ( "testing" ) -type logTestCase struct { - name string - panicInput interface{} -} - func TestLogWithCheckingPanic(t *testing.T) { t.Parallel() - testCases := []*logTestCase{ + testCases := []struct { + name string + input interface{} + want string // See logStr below for the format + }{ { - name: "panic with string", - panicInput: "string", + name: "panic with string", + input: "string", + want: "st: 500, err: panic: string", }, { - name: "panic with error", - panicInput: errors.New("error"), + name: "panic with error", + input: errors.New("error"), + want: "st: 500, err: panic: error", }, { - name: "no panic", - panicInput: nil, + name: "no panic", + input: nil, + want: "st: 200, err: ", // See inputStatusCode below for 200 }, } + const ( + logStr = "st: %d, err: %v" + inputStatusCode = http.StatusOK + ) + var inputError error for _, tc := range testCases { // https://github.com/golang/go/wiki/CommonMistakes#using-goroutines-on-loop-iterator-variables tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - const ( - logStr = "st: %d, err: %v" - inputStatusCode = http.StatusOK - ) - var inputError error - - want := fmt.Sprintf(logStr, http.StatusInternalServerError, panicRecoveryPrefix+fmt.Sprintf("%s", tc.panicInput)) - if tc.panicInput == nil { - want = fmt.Sprintf(logStr, inputStatusCode, inputError) - } got := "" f := func(statusCode int, err error) { @@ -53,12 +50,12 @@ func TestLogWithCheckingPanic(t *testing.T) { defer func() { // Capture the panic thrown from logWithCheckingPanic. recover() - if got != want { - t.Errorf("got: %q, want: %q", got, want) + if got != tc.want { + t.Errorf("got: %q, want: %q", got, tc.want) } }() defer logWithCheckingPanic(f, inputStatusCode, inputError) - panic(tc.panicInput) + panic(tc.input) }) } }