Skip to content

Commit

Permalink
Merge pull request #43 from davidchen0310/log_panic
Browse files Browse the repository at this point in the history
Log the correct status code when panic
  • Loading branch information
maditya authored Jul 7, 2020
2 parents ac66b39 + 2c879a6 commit 668504e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 48 deletions.
21 changes: 11 additions & 10 deletions api/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (s *SigningService) GetBlobAvailableSigningKeys(ctx context.Context, e *emp
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

var keys []*proto.KeyMeta
for id := range s.KeyUsages[config.BlobEndpoint] {
Expand All @@ -46,10 +46,10 @@ func (s *SigningService) GetBlobSigningKey(ctx context.Context, keyMeta *proto.K
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

if keyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down Expand Up @@ -79,10 +79,11 @@ func (s *SigningService) PostSignBlob(ctx context.Context, request *proto.BlobSi
start := time.Now()
var err error

defer func() {
log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`, methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
f := func(statusCode int, err error) {
log.Printf(`m=%s,digest=%q,hash=%q,st=%d,et=%d,err="%v"`,
methodName, request.GetDigest(), request.HashAlgorithm.String(), statusCode, timeElapsedSince(start), err)
}
defer logWithCheckingPanic(f, statusCode, err)

if request.KeyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down
22 changes: 22 additions & 0 deletions api/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package api

import (
"fmt"
"net/http"
)

type logFunc func(statusCode int, err error)

// logWithCheckingPanic attemps to recover from a possible panic,
// modifies statusCode and err if there was indeed a panic,
// passes the possibly updated status and err to the logFunc,
// then panics again if there was indeed a panic to
// make UnaryInterceptor in server/server.go return "internal server error" to the client.
func logWithCheckingPanic(f logFunc, statusCode int, err error) {
if r := recover(); r != nil {
statusCode = http.StatusInternalServerError
err = fmt.Errorf("panic: %v", r)
defer panic(r)
}
f(statusCode, err)
}
61 changes: 61 additions & 0 deletions api/log_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package api

import (
"errors"
"fmt"
"net/http"
"testing"
)

func TestLogWithCheckingPanic(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input interface{}
want string // See logStr below for the format
}{
{
name: "panic with string",
input: "string",
want: "st: 500, err: panic: string",
},
{
name: "panic with error",
input: errors.New("error"),
want: "st: 500, err: panic: error",
},
{
name: "no panic",
input: nil,
want: "st: 200, err: <nil>", // See inputStatusCode below for 200
},
}
const (
logStr = "st: %d, err: %v"
inputStatusCode = http.StatusOK
)
var inputError error

for _, tc := range testCases {
// https://github.com/golang/go/wiki/CommonMistakes#using-goroutines-on-loop-iterator-variables
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got := ""
f := func(statusCode int, err error) {
got = fmt.Sprintf(logStr, statusCode, err)
}

defer func() {
// Capture the panic thrown from logWithCheckingPanic.
recover()
if got != tc.want {
t.Errorf("got: %q, want: %q", got, tc.want)
}
}()
defer logWithCheckingPanic(f, inputStatusCode, inputError)
panic(tc.input)
})
}
}
9 changes: 0 additions & 9 deletions api/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package api

import (
"fmt"
"log"
"time"

"github.com/theparanoids/crypki"
Expand All @@ -19,14 +18,6 @@ type SigningService struct {
MaxValidity map[string]uint64
}

// recoverIfPanicked recovers from panic and logs the error.
func recoverIfPanicked(method string) {
if r := recover(); r != nil {
log.Printf("%s: recovered from panic, panic: %v", method, r)
panic(r)
}
}

// timeElapsedSince returns time elapsed since start time in microseconds.
func timeElapsedSince(start time.Time) int64 {
return time.Since(start).Nanoseconds() / time.Microsecond.Nanoseconds()
Expand Down
21 changes: 11 additions & 10 deletions api/sshhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (s *SigningService) GetHostSSHCertificateAvailableSigningKeys(ctx context.C
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

var keys []*proto.KeyMeta
for id := range s.KeyUsages[config.SSHHostCertEndpoint] {
Expand All @@ -46,10 +46,10 @@ func (s *SigningService) GetHostSSHCertificateSigningKey(ctx context.Context, ke
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

if keyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down Expand Up @@ -79,14 +79,15 @@ func (s *SigningService) PostHostSSHCertificate(ctx context.Context, request *pr
var err error
var cert *ssh.Certificate

defer func() {
f := func(statusCode int, err error) {
kid := ""
if cert != nil {
kid = cert.KeyId
}
log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`,
methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err)
}
defer logWithCheckingPanic(f, statusCode, err)

if request.KeyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down
21 changes: 11 additions & 10 deletions api/sshuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (s *SigningService) GetUserSSHCertificateAvailableSigningKeys(ctx context.C
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

var keys []*proto.KeyMeta
for id := range s.KeyUsages[config.SSHUserCertEndpoint] {
Expand All @@ -46,10 +46,10 @@ func (s *SigningService) GetUserSSHCertificateSigningKey(ctx context.Context, ke
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

if keyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down Expand Up @@ -79,14 +79,15 @@ func (s *SigningService) PostUserSSHCertificate(ctx context.Context, request *pr
var err error
var cert *ssh.Certificate

defer func() {
f := func(statusCode int, err error) {
kid := ""
if cert != nil {
kid = cert.KeyId
}
log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`, methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
log.Printf(`m=%s,id=%q,principals=%q,st=%d,et=%d,err="%v"`,
methodName, kid, request.Principals, statusCode, timeElapsedSince(start), err)
}
defer logWithCheckingPanic(f, statusCode, err)

if request.KeyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down
18 changes: 9 additions & 9 deletions api/x509cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (s *SigningService) GetX509CertificateAvailableSigningKeys(ctx context.Cont
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

var keys []*proto.KeyMeta
for id := range s.KeyUsages[config.X509CertEndpoint] {
Expand All @@ -45,10 +45,10 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot
start := time.Now()
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,st=%d,et=%d,err="%v"`, methodName, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

if keyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down Expand Up @@ -78,10 +78,10 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto
subject := pkix.Name{}
var err error

defer func() {
f := func(statusCode int, err error) {
log.Printf(`m=%s,sub=%q,st=%d,et=%d,err="%v"`, methodName, subject, statusCode, timeElapsedSince(start), err)
}()
defer recoverIfPanicked(methodName)
}
defer logWithCheckingPanic(f, statusCode, err)

if request.KeyMeta == nil {
statusCode = http.StatusBadRequest
Expand Down

0 comments on commit 668504e

Please sign in to comment.