Skip to content

Commit

Permalink
add recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
hslam committed Jan 29, 2021
1 parent f35f9de commit eb110ff
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import (
func main() {
m := rum.New()
m.SetPoll(true)
m.Recovery(rum.Recovery)
m.NotFound(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not Found : "+r.URL.String(), http.StatusNotFound)
})
Expand Down
63 changes: 47 additions & 16 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package rum

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
Expand All @@ -28,11 +30,18 @@ var ErrGroupExisted = errors.New("Group Existed")
// ErrParamsKeyEmpty is the error returned by HandleFunc when the params key is empty.
var ErrParamsKeyEmpty = errors.New("Params key must be not empty")

// ContextKey represents the context key.
type ContextKey string

// ContextKeyRecovery is the context key of recovery.
const ContextKeyRecovery = ContextKey("mux:context:recovery")

// Mux is an HTTP request multiplexer.
type Mux struct {
mut sync.RWMutex
prefixes map[string]*prefix
middlewares []http.Handler
recovery http.Handler
notFound http.Handler
group string
groups map[string]*Mux
Expand Down Expand Up @@ -133,7 +142,24 @@ func (m *Mux) serveEntry(entry *Entry, w http.ResponseWriter, r *http.Request) {
}
}

// Recovery returns a recovery handler function that recovers from any panics and writes a 500 status code.
func Recovery(w http.ResponseWriter, r *http.Request) {
err := r.Context().Value(ContextKeyRecovery)
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "500 Internal Server Error : %v\n", err)
}

func (m *Mux) serveHandler(handler http.Handler, w http.ResponseWriter, r *http.Request) {
if m.recovery != nil {
defer func() {
if err := recover(); err != nil {
ctx := context.WithValue(r.Context(), ContextKeyRecovery, err)
m.recovery.ServeHTTP(w, r.WithContext(ctx))
}
}()
}
m.middleware(w, r)
if handler != nil {
handler.ServeHTTP(w, r)
Expand All @@ -149,14 +175,12 @@ func (m *Mux) getHandlerFunc(path string) *Entry {
return nil
}

// HandleFunc registers the handler function for the given pattern
// in the Mux.
// HandleFunc registers a handler function with the given pattern to the Mux.
func (m *Mux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) *Entry {
return m.Handle(pattern, http.HandlerFunc(handler))
}

// Handle registers the handler for the given pattern
// in the Mux.
// Handle registers a handler with the given pattern to the Mux.
func (m *Mux) Handle(pattern string, handler http.Handler) *Entry {
m.mut.Lock()
defer m.mut.Unlock()
Expand Down Expand Up @@ -189,7 +213,7 @@ func (m *Mux) Handle(pattern string, handler http.Handler) *Entry {
return entry
}

// Group registers a group for the given pattern in the Mux.
// Group registers a group with the given pattern to the Mux.
func (m *Mux) Group(group string, f func(m *Mux)) {
m.mut.Lock()
defer m.mut.Unlock()
Expand All @@ -203,13 +227,20 @@ func (m *Mux) Group(group string, f func(m *Mux)) {
m.groups[group] = groupMux
}

// NotFound registers the not found handler function in the Mux.
// NotFound registers a not found handler function to the Mux.
func (m *Mux) NotFound(handler http.HandlerFunc) {
m.mut.Lock()
defer m.mut.Unlock()
m.notFound = handler
}

// Recovery registers a recovery handler function to the Mux.
func (m *Mux) Recovery(handler http.HandlerFunc) {
m.mut.Lock()
defer m.mut.Unlock()
m.recovery = handler
}

// Use uses middleware.
func (m *Mux) Use(handler http.HandlerFunc) {
m.mut.Lock()
Expand Down Expand Up @@ -317,70 +348,70 @@ func (m *Mux) replace(s string) string {
return s
}

// GET adds a GET HTTP method for the entry.
// GET adds a GET HTTP method to the entry.
func (entry *Entry) GET() *Entry {
entry.method |= get
entry.get = entry.handler
return entry
}

// POST adds a POST HTTP method for the entry.
// POST adds a POST HTTP method to the entry.
func (entry *Entry) POST() *Entry {
entry.method |= post
entry.post = entry.handler
return entry
}

// PUT adds a PUT HTTP method for the entry.
// PUT adds a PUT HTTP method to the entry.
func (entry *Entry) PUT() *Entry {
entry.method |= put
entry.put = entry.handler
return entry
}

// DELETE adds a DELETE HTTP method for the entry.
// DELETE adds a DELETE HTTP method to the entry.
func (entry *Entry) DELETE() *Entry {
entry.method |= delete
entry.delete = entry.handler
return entry
}

// PATCH adds a PATCH HTTP method for the entry.
// PATCH adds a PATCH HTTP method to the entry.
func (entry *Entry) PATCH() *Entry {
entry.method |= patch
entry.patch = entry.handler
return entry
}

// HEAD adds a HEAD HTTP method for the entry.
// HEAD adds a HEAD HTTP method to the entry.
func (entry *Entry) HEAD() *Entry {
entry.method |= head
entry.head = entry.handler
return entry
}

// OPTIONS adds a OPTIONS HTTP method for the entry.
// OPTIONS adds a OPTIONS HTTP method to the entry.
func (entry *Entry) OPTIONS() *Entry {
entry.method |= options
entry.options = entry.handler
return entry
}

// TRACE adds a TRACE HTTP method for the entry.
// TRACE adds a TRACE HTTP method to the entry.
func (entry *Entry) TRACE() *Entry {
entry.method |= trace
entry.trace = entry.handler
return entry
}

// CONNECT adds a CONNECT HTTP method for the entry.
// CONNECT adds a CONNECT HTTP method to the entry.
func (entry *Entry) CONNECT() *Entry {
entry.method |= connect
entry.connect = entry.handler
return entry
}

// All adds all HTTP method for the entry.
// All adds all HTTP method to the entry.
func (entry *Entry) All() {
entry.GET()
entry.POST()
Expand Down
19 changes: 19 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,25 @@ func TestDefaultNotFound(t *testing.T) {
httpServer.Close()
}

func TestDefaultRecovery(t *testing.T) {
m := New()
m.Recovery(Recovery)
msg := "panic test"
m.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
panic(msg)
w.Write([]byte("hello world Method:GET\n"))
}).GET()
addr := ":8080"
httpServer := &http.Server{
Addr: addr,
Handler: m,
}
l, _ := net.Listen("tcp", addr)
go httpServer.Serve(l)
testHTTP("GET", "http://"+addr+"/hello", http.StatusInternalServerError, "500 Internal Server Error : "+msg+"\n", t)
httpServer.Close()
}

func TestHandleFunc(t *testing.T) {
m := NewMux()
m.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit eb110ff

Please sign in to comment.