Skip to content

Commit 4cf8225

Browse files
committedNov 29, 2015
Added ability to filter out non-GET requests
1 parent b527709 commit 4cf8225

File tree

4 files changed

+62
-5
lines changed

4 files changed

+62
-5
lines changed
 

‎doc.go

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Fasthttp provides the following features:
2121
* Maximum request body size.
2222
* Maximum request execution time.
2323
* Maximum keep-alive connection lifetime.
24+
* Early filtering out non-GET requests.
2425
2526
* A lot of additional useful info is exposed to request handler:
2627

‎http.go

+13-4
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ type Response struct {
4242
// Copying Header by value is forbidden. Use pointer to Header instead.
4343
Header ResponseHeader
4444

45+
// Response.Read() skips reading body if set to true.
46+
// Use it for HEAD requests.
47+
SkipBody bool
48+
4549
body []byte
4650
w responseBodyWriter
4751

4852
bodyStream io.Reader
49-
50-
// If set to true, Response.Read() skips reading body.
51-
// Use it for HEAD requests.
52-
SkipBody bool
5353
}
5454

5555
// SetRequestURI sets RequestURI.
@@ -307,6 +307,8 @@ func (req *Request) Read(r *bufio.Reader) error {
307307

308308
const defaultMaxInMemoryFileSize = 16 * 1024 * 1024
309309

310+
var errGetOnly = errors.New("non-GET request received")
311+
310312
// ReadLimitBody reads request from the given r, limiting the body size.
311313
//
312314
// If maxBodySize > 0 and the body size exceeds maxBodySize,
@@ -316,11 +318,18 @@ const defaultMaxInMemoryFileSize = 16 * 1024 * 1024
316318
// reading multipart/form-data request in order to delete temporarily
317319
// uploaded files.
318320
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
321+
return req.readLimitBody(r, maxBodySize, false)
322+
}
323+
324+
func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool) error {
319325
req.clearSkipHeader()
320326
err := req.Header.Read(r)
321327
if err != nil {
322328
return err
323329
}
330+
if getOnly && !req.Header.IsGet() {
331+
return errGetOnly
332+
}
324333

325334
if req.Header.IsPost() || req.Header.IsPut() {
326335
contentLength := req.Header.ContentLength()

‎server.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@ type Server struct {
163163
// Aggressive memory usage reduction is disabled by default.
164164
ReduceMemoryUsage bool
165165

166+
// Rejects all non-GET requests if set to true.
167+
//
168+
// This option is useful as anti-DoS protection for servers
169+
// accepting only GET requests. When set the request size is limited
170+
// by ReadBufferSize.
171+
//
172+
// Server accepts all the requests by default.
173+
GetOnly bool
174+
166175
// Logger, which is used by RequestCtx.Logger().
167176
//
168177
// By default standard logger from log package is used.
@@ -869,7 +878,7 @@ func (s *Server) serveConn(c net.Conn) error {
869878
if br == nil {
870879
br = acquireReader(ctx)
871880
}
872-
err = ctx.Request.ReadLimitBody(br, s.MaxRequestBodySize)
881+
err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly)
873882
if br.Buffered() == 0 || err != nil {
874883
releaseReader(s, br)
875884
br = nil

‎server_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,44 @@ func TestTimeoutHandlerSuccess(t *testing.T) {
126126
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
127127
}
128128

129+
func TestServerGetOnly(t *testing.T) {
130+
h := func(ctx *RequestCtx) {
131+
if !ctx.IsGet() {
132+
t.Fatalf("non-get request: %q", ctx.Method())
133+
}
134+
ctx.Success("foo/bar", []byte("success"))
135+
}
136+
s := &Server{
137+
Handler: h,
138+
GetOnly: true,
139+
}
140+
141+
rw := &readWriter{}
142+
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345")
143+
144+
ch := make(chan error)
145+
go func() {
146+
ch <- s.ServeConn(rw)
147+
}()
148+
149+
select {
150+
case err := <-ch:
151+
if err == nil {
152+
t.Fatalf("expecting error")
153+
}
154+
if err != errGetOnly {
155+
t.Fatalf("Unexpected error from serveConn: %s. Expecting %s", err, errGetOnly)
156+
}
157+
case <-time.After(100 * time.Millisecond):
158+
t.Fatalf("timeout")
159+
}
160+
161+
resp := rw.w.Bytes()
162+
if len(resp) > 0 {
163+
t.Fatalf("unexpected response %q. Expecting zero", resp)
164+
}
165+
}
166+
129167
func TestTimeoutHandlerTimeout(t *testing.T) {
130168
h := func(ctx *RequestCtx) {
131169
time.Sleep(time.Second)

0 commit comments

Comments
 (0)
Please sign in to comment.