Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid getting request body as much as possible in the handler. #81

Merged
merged 10 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
run:
go: 1.21
modules-download-mode: mod
linters:
fast: false
linters-settings:
staticcheck:
go: 1.16
issues:
exclude:
- SA3000
18 changes: 18 additions & 0 deletions copybuf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package rc

import "sync"

// Copy from net/http/server.go
const copyBufPoolSize = 32 * 1024

var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }}

func getCopyBuf() []byte { //nostyle:getters
return copyBufPool.Get().(*[copyBufPoolSize]byte)[:]
}
func putCopyBuf(b []byte) {
if len(b) != copyBufPoolSize {
panic("trying to put back buffer of the wrong size in the copyBufPool") //nostyle:dontpanic
}
copyBufPool.Put((*[copyBufPoolSize]byte)(b))
}
121 changes: 64 additions & 57 deletions rc.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,28 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
now := time.Now()

// Copy the request so that it is not affected by the next handler.
req, preq := m.duplicateRequest(req)
// reqc is the request to be used for caching.
req, reqc := m.duplicateRequest(req)

cachedReq, cachedRes, err := m.cacher.Load(preq) //nostyle:handlerrors
cachedReq, cachedRes, err := m.cacher.Load(reqc) //nostyle:handlerrors
if err != nil {
switch {
case errors.Is(err, ErrCacheNotFound):
m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("cache not found", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
case errors.Is(err, ErrCacheExpired):
m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("cache expired", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
case errors.Is(err, ErrShouldNotUseCache):
m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("should not use cache", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
// Skip caching
next.ServeHTTP(w, req)
return
default:
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
}
}
cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, HandlerToRequester(next), now) //nostyle:handlerrors
cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, m.handlerToRequester(next, reqc, now), now) //nostyle:handlerrors
if err != nil {
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
}

// Response
Expand All @@ -131,50 +132,35 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
}
}
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body)
if err != nil {
m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
} else {
if _, err := w.Write(body); err != nil {
// Error as debug
// - os.ErrDeadlineExceeded: The request context has been canceled or has expired.
// - "client disconnected": The client disconnected. (net/http.http2errClientDisconnected)
// - "http2: stream closed": The client disconnected. (net/http.http2errStreamClosed)
// - syscall.ECONNRESET: The client disconnected. ("connection reset by peer")
// - syscall.EPIPE: The client disconnected. ("broken pipe")
// - http.ErrBodyNotAllowed: The request method does not allow a body.
switch {
case errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || contains([]string{"client disconnected", "http2: stream closed"}, err.Error()):
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
case errors.Is(err, http.ErrBodyNotAllowed):
// It is desirable that there should be no content body in the response, but the proxy server cannot handle it, so it is used as a debug log.
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
default:
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}

ww := w.(io.Writer)
buf := getCopyBuf()
defer putCopyBuf(buf)
if _, err := io.CopyBuffer(ww, res.Body, buf); err != nil {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use io.CopyBuffer instead of io.ReadAll

// Error as debug
// - os.ErrDeadlineExceeded: The request context has been canceled or has expired.
// - "client disconnected": The client disconnected. (net/http.http2errClientDisconnected)
// - "http2: stream closed": The client disconnected. (net/http.http2errStreamClosed)
// - syscall.ECONNRESET: The client disconnected. ("connection reset by peer")
// - syscall.EPIPE: The client disconnected. ("broken pipe")
// - http.ErrBodyNotAllowed: The request method does not allow a body.
switch {
case errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || contains([]string{"client disconnected", "http2: stream closed"}, err.Error()):
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
case errors.Is(err, http.ErrBodyNotAllowed):
// It is desirable that there should be no content body in the response, but the proxy server cannot handle it, so it is used as a debug log.
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
default:
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}
}
if err := res.Body.Close(); err != nil {
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}

if cacheUsed {
m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
return
}
ok, expires := m.cacher.Storable(preq, res, now)
if !ok {
m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
return
m.logger.Debug("cache used", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode))
}
// Restore response body
res.Body = io.NopCloser(bytes.NewReader(body))

// Store response as cache
if err := m.cacher.Store(preq, res, expires); err != nil {
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
}
m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
})
}

Expand All @@ -197,6 +183,32 @@ func (m *cacheMw) duplicateRequest(req *http.Request) (*http.Request, *http.Requ
return copy, req
}

func (m *cacheMw) handlerToRequester(h http.Handler, reqc *http.Request, now time.Time) func(*http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
rec := newRecorder()
defer rec.Reset()
h.ServeHTTP(rec, req)
res := rec.Result()
resc := rec.Result()

go func() {
ok, expires := m.cacher.Storable(reqc, resc, now)
if !ok {
m.logger.Debug("cache not storable", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(resc.Header)))
return
}

// Store response as cache
if err := m.cacher.Store(reqc, resc, expires); err != nil {
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", resc.StatusCode))
}
m.logger.Debug("cache stored", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", resc.StatusCode))
}()
Comment on lines +194 to +206
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use goroutine to avoid blocking due to the storing process.


return res, nil
}
}

func (m *cacheMw) maskHeader(h http.Header) http.Header {
const masked = "*****"
c := h.Clone()
Expand Down Expand Up @@ -237,17 +249,6 @@ func New(cacher Cacher, opts ...Option) func(next http.Handler) http.Handler {
return rl.Handler
}

// HandlerToRequester converts http.Handler to func(*http.Request) (*http.Response, error).
func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
rec := newRecorder()
h.ServeHTTP(rec, req)
res := rec.Result()
res.Header = rec.Header()
return res, nil
}
}

type recorder struct {
statusCode int
header http.Header
Expand Down Expand Up @@ -280,11 +281,17 @@ func (r *recorder) Result() *http.Response {
Status: http.StatusText(r.statusCode),
StatusCode: r.statusCode,
Header: r.header.Clone(),
Body: io.NopCloser(r.buf),
Body: io.NopCloser(bytes.NewReader(r.buf.Bytes())),
ContentLength: int64(r.buf.Len()),
}
}

func (r *recorder) Reset() {
r.statusCode = 0
r.header = make(http.Header)
r.buf.Reset()
}

func contains(s []string, e string) bool {
for _, v := range s {
if e == v {
Expand Down
Loading