Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RoundTripper method to ghttp.Server #770

Merged
merged 2 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,33 @@ To bring it all together: there are three ways to instruct a `ghttp` server to h

When a `ghttp` server receives a request it first checks against the set of handlers registered via `RouteToHandler` if there is no such handler it proceeds to pop an `AppendHandlers` handler off the stack, if the stack of ordered handlers is empty, it will check whether `GetAllowUnhandledRequests` returns `true` or `false`. If `false` the test fails. If `true`, a response is sent with whatever `GetUnhandledRequestStatusCode` returns.

### Using a RoundTripper to route requests to the test Server

So far you have seen examples of using `server.URL()` to get the string URL of the test server. This is ok if you are testing code where you can pass the URL. In some cases you might need to pass a `http.Client` or similar.

You can use `server.RoundTripper(nil)` to create a `http.RoundTripper` which will redirect requests to the test server.

The method takes another `http.RoundTripper` to make the request to the test server, this allows chaining `http.Transports` or otherwise.

If passed `nil`, then `http.DefaultTransport` is used to make the request.

```go
Describe("The http client", func() {
var server *ghttp.Server
var httpClient *http.Client

BeforeEach(func() {
server = ghttp.NewServer()
httpClient = &http.Client{Transport: server.RoundTripper(nil)}
})

AfterEach(func() {
//shut down the server between tests
server.Close()
})
})
```

## `gbytes`: Testing Streaming Buffers

`gbytes` implements `gbytes.Buffer` - an `io.WriteCloser` that captures all input to an in-memory buffer.
Expand Down
79 changes: 50 additions & 29 deletions ghttp/test_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,26 +186,26 @@ type Server struct {
calls int
}

//Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
// Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
func (s *Server) Start() {
s.HTTPTestServer.Start()
}

//URL() returns a url that will hit the server
// URL() returns a url that will hit the server
func (s *Server) URL() string {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
return s.HTTPTestServer.URL
}

//Addr() returns the address on which the server is listening.
// Addr() returns the address on which the server is listening.
func (s *Server) Addr() string {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
return s.HTTPTestServer.Listener.Addr().String()
}

//Close() should be called at the end of each test. It spins down and cleans up the test server.
// Close() should be called at the end of each test. It spins down and cleans up the test server.
func (s *Server) Close() {
s.rwMutex.Lock()
server := s.HTTPTestServer
Expand All @@ -217,14 +217,14 @@ func (s *Server) Close() {
}
}

//ServeHTTP() makes Server an http.Handler
//When the server receives a request it handles the request in the following order:
// ServeHTTP() makes Server an http.Handler
// When the server receives a request it handles the request in the following order:
//
//1. If the request matches a handler registered with RouteToHandler, that handler is called.
//2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
//3. If all registered handlers have been called then:
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
// 1. If the request matches a handler registered with RouteToHandler, that handler is called.
// 2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
// 3. If all registered handlers have been called then:
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.rwMutex.Lock()
defer func() {
Expand Down Expand Up @@ -280,18 +280,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}

//ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
// ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
func (s *Server) ReceivedRequests() []*http.Request {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.receivedRequests
}

//RouteToHandler can be used to register handlers that will always handle requests that match
//the passed in method and path.
// RouteToHandler can be used to register handlers that will always handle requests that match
// the passed in method and path.
//
//The path may be either a string object or a *regexp.Regexp.
// The path may be either a string object or a *regexp.Regexp.
func (s *Server) RouteToHandler(method string, path interface{}, handler http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()
Expand Down Expand Up @@ -337,25 +337,25 @@ func (s *Server) handlerForRoute(method string, path string) (http.HandlerFunc,
return nil, false
}

//AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
// AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
func (s *Server) AppendHandlers(handlers ...http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.requestHandlers = append(s.requestHandlers, handlers...)
}

//SetHandler overrides the registered handler at the passed in index with the passed in handler
//This is useful, for example, when a server has been set up in a shared context, but must be tweaked
//for a particular test.
// SetHandler overrides the registered handler at the passed in index with the passed in handler
// This is useful, for example, when a server has been set up in a shared context, but must be tweaked
// for a particular test.
func (s *Server) SetHandler(index int, handler http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.requestHandlers[index] = handler
}

//GetHandler returns the handler registered at the passed in index.
// GetHandler returns the handler registered at the passed in index.
func (s *Server) GetHandler(index int) http.HandlerFunc {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
Expand All @@ -374,12 +374,12 @@ func (s *Server) Reset() {
s.routedHandlers = nil
}

//WrapHandler combines the passed in handler with the handler registered at the passed in index.
//This is useful, for example, when a server has been set up in a shared context but must be tweaked
//for a particular test.
// WrapHandler combines the passed in handler with the handler registered at the passed in index.
// This is useful, for example, when a server has been set up in a shared context but must be tweaked
// for a particular test.
//
//If the currently registered handler is A, and the new passed in handler is B then
//WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
// If the currently registered handler is A, and the new passed in handler is B then
// WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
func (s *Server) WrapHandler(index int, handler http.HandlerFunc) {
existingHandler := s.GetHandler(index)
s.SetHandler(index, CombineHandlers(existingHandler, handler))
Expand All @@ -392,34 +392,55 @@ func (s *Server) CloseClientConnections() {
s.HTTPTestServer.CloseClientConnections()
}

//SetAllowUnhandledRequests enables the server to accept unhandled requests.
// SetAllowUnhandledRequests enables the server to accept unhandled requests.
func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.AllowUnhandledRequests = allowUnhandledRequests
}

//GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
// GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
func (s *Server) GetAllowUnhandledRequests() bool {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.AllowUnhandledRequests
}

//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
// SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
func (s *Server) SetUnhandledRequestStatusCode(statusCode int) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.UnhandledRequestStatusCode = statusCode
}

//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
// GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
func (s *Server) GetUnhandledRequestStatusCode() int {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.UnhandledRequestStatusCode
}

// RoundTripper returns a RoundTripper which updates requests to point to the server.
// This is useful when you want to use the server as a RoundTripper in an http.Client.
// If rt is nil, http.DefaultTransport is used.
func (s *Server) RoundTripper(rt http.RoundTripper) http.RoundTripper {
if rt == nil {
rt = http.DefaultTransport
}
return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
r.URL.Scheme = "http"
r.URL.Host = s.Addr()
return rt.RoundTrip(r)
})
}

// Helper type for creating a RoundTripper from a function
type RoundTripperFunc func(*http.Request) (*http.Response, error)

func (fn RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return fn(r)
}
59 changes: 59 additions & 0 deletions ghttp/test_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1190,4 +1190,63 @@ var _ = Describe("TestServer", func() {
})
})
})

Describe("RoundTripper", func() {
var called []string
BeforeEach(func() {
called = []string{}
s.RouteToHandler("GET", "/routed", func(w http.ResponseWriter, req *http.Request) {
called = append(called, "get")
})
s.RouteToHandler("POST", "/routed", func(w http.ResponseWriter, req *http.Request) {
called = append(called, "post")
})
})

It("should send http traffic to test server with default transport", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("http://example.com/routed")
client.Post("http://example.com/routed", "application/json", nil)
client.Get("http://foo.bar/routed")
client.Post("http://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send https traffic to test server with default transport", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("https://example.com/routed")
client.Post("https://example.com/routed", "application/json", nil)
client.Get("https://foo.bar/routed")
client.Post("https://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send http traffic to test server with default transport", func() {
transport := http.Transport{}
client := http.Client{Transport: s.RoundTripper(&transport)}
client.Get("http://example.com/routed")
client.Post("http://example.com/routed", "application/json", nil)
client.Get("http://foo.bar/routed")
client.Post("http://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send http traffic to test server with default transport", func() {
transport := http.Transport{}
client := http.Client{Transport: s.RoundTripper(&transport)}
client.Get("https://example.com/routed")
client.Post("https://example.com/routed", "application/json", nil)
client.Get("https://foo.bar/routed")
client.Post("https://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should not change the path of the request", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("https://example.com/routed")
Expect(called).Should(Equal([]string{"get"}))
Expect(s.ReceivedRequests()).Should(HaveLen(1))
Expect(s.ReceivedRequests()[0].URL.Path).Should(Equal("/routed"))
})
})
})
Loading