Skip to content

Commit

Permalink
Adding testing for Client funcs, some Client refactoring; resolve bra…
Browse files Browse the repository at this point in the history
…infart bug in IoClose
  • Loading branch information
momer committed May 2, 2024
1 parent a4f0471 commit e888ddb
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 24 deletions.
106 changes: 87 additions & 19 deletions bonsai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"golang.org/x/time/rate"
)

// Client representation configuration
const (
// Version reflects this API Client's version
Version = "1.0.0"
Expand All @@ -26,7 +27,10 @@ const (
// UserAgent is the internally used value for the User-Agent header
// in all outgoing HTTP requests
UserAgent = "bonsai-api-go/" + Version
)

// Client rate limiter configuration
const (
// DefaultClientBurstAllowance is the default Bonsai API request burst allowance
DefaultClientBurstAllowance = 60
// DefaultClientBurstDuration is the default interval for a token bucket of size DefaultClientBurstAllowance to be refilled.
Expand All @@ -35,12 +39,26 @@ const (
ProvisionClientBurstAllowance = 5
// ProvisionClientBurstDuration is the default interval for a token bucket of size ProvisionClientBurstAllowance to be refilled.
ProvisionClientBurstDuration = 1 * time.Minute
)

// Common API Response headers
const (
// HeaderRetryAfter holds the number of seconds to delay before making the next request
// ref: https://bonsai.io/docs/api-error-429-too-many-requests
HeaderRetryAfter = "Retry-After"
)

// HTTP Content Types and related Header
const (
HTTPHeaderContentType = "Content-Type"
HTTPContentTypeJSON string = "application/json"
)

// HTTP Status Response Errors
var (
ErrorHTTPStatusNotFound = errors.New("not found")
)

// ResponseError captures API response errors
// returned as JSON in supported scenarios.
//
Expand All @@ -56,7 +74,15 @@ type ResponseError struct {
// The community is as yet undecided on a great way to handle this
// ref: https://github.com/golang/go/issues/47811
func (r ResponseError) Error() string {
return strings.Join(r.Errors, "; ")
return fmt.Sprintf("%v (%d)", r.Errors, r.Status)
}

func (r ResponseError) Is(target error) bool {
switch r.Status {
case http.StatusNotFound:
return target == ErrorHTTPStatusNotFound
}
return false
}

// listOpts specifies options for listing resources.
Expand Down Expand Up @@ -170,10 +196,55 @@ type PaginatedResponse struct {
TotalRecords int `json:"total_records"`
}

type httpResponse = *http.Response
type Response struct {
*http.Response
httpResponse

PaginatedResponse
Body io.ReadCloser
PaginatedResponse `json:"pagination"`
}

func (r *Response) WithHTTPResponse(httpResp *http.Response) error {
var err error
bodyBuf := new(bytes.Buffer)
r.httpResponse = httpResp

if httpResp == nil {
return errors.New("received nil http.Response")
}

_, err = bodyBuf.ReadFrom(httpResp.Body)
if err != nil {
return fmt.Errorf("error reading response body: %w", err)
}

err = IoClose(httpResp.Body, err)
if err != nil {
return err
}

r.Body = io.NopCloser(bodyBuf)

switch httpResp.Header.Get("Content-Type") {
case HTTPContentTypeJSON:
err = json.Unmarshal(bodyBuf.Bytes(), r)
}
if err != nil {
return fmt.Errorf("error unmarshaling response body: %w", err)
}

return err
}

func (r *Response) MarkPaginationComplete() {
r.PaginatedResponse = PaginatedResponse{}
}

// NewResponse reserves this function signature, and is
// the recommended way to instantiate a Response, as its behavior
// may change.
func NewResponse() (*Response, error) {
return &Response{}, nil
}

type limiter = *rate.Limiter
Expand Down Expand Up @@ -239,17 +310,17 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
}

// Do performs an HTTP request against the API.
func (c *Client) Do(ctx context.Context, r *http.Request) (*Response, error) {
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
reqBuf := new(bytes.Buffer)

// Capture the original request body
if r.ContentLength > 0 {
_, err := reqBuf.ReadFrom(r.Body)
if req.ContentLength > 0 {
_, err := reqBuf.ReadFrom(req.Body)
if err != nil {
return nil, fmt.Errorf("error reading request body: %w", err)
}

err = IoClose(r.Body, err)
err = IoClose(req.Body, err)
if err != nil {
return nil, err
}
Expand All @@ -260,8 +331,8 @@ func (c *Client) Do(ctx context.Context, r *http.Request) (*Response, error) {
respBuf := new(bytes.Buffer)
// Wrap the buffer in a no-op Closer, such that
// it satisfies the ReadCloser interface
if r.ContentLength > 0 {
r.Body = io.NopCloser(reqBuf)
if req.ContentLength > 0 {
req.Body = io.NopCloser(reqBuf)
}

// Context cancelled, timed-out, burst issue, or other rate limit issue;
Expand All @@ -270,24 +341,21 @@ func (c *Client) Do(ctx context.Context, r *http.Request) (*Response, error) {
return nil, fmt.Errorf("failed while awaiting execution per rate-limit: %w", err)
}

httpResp, err := c.httpClient.Do(r)
resp := &Response{Response: httpResp}
httpResp, err := c.httpClient.Do(req)
if err != nil {
return resp, err
return nil, fmt.Errorf("http request failed: %w", err)
}

_, err = respBuf.ReadFrom(resp.Body)
resp, err := NewResponse()
if err != nil {
return resp, fmt.Errorf("error reading response body: %w", err)
return resp, fmt.Errorf("creating new Response")
}

err = IoClose(resp.Body, err)
err = resp.WithHTTPResponse(httpResp)
if err != nil {
return resp, err
return resp, fmt.Errorf("setting http response: %w", err)
}

resp.Body = io.NopCloser(respBuf)

if resp.StatusCode >= 400 {
respErr := ResponseError{}
if err = json.Unmarshal(respBuf.Bytes(), &respErr); err != nil {
Expand Down Expand Up @@ -335,7 +403,7 @@ func (c *Client) all(ctx context.Context, f func(int) (*Response, error)) error

// The caller is responsible for determining whether or not we've exhausted
// retries.
if reflect.ValueOf(resp.PaginatedResponse).IsZero() || resp.PageNumber == 0 {
if reflect.ValueOf(resp.PaginatedResponse).IsZero() || resp.PageNumber <= 0 {
return nil
}
// We should be fine with a straight increment, but let's play it safe
Expand Down
90 changes: 90 additions & 0 deletions bonsai/client_impl_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package bonsai

import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -17,9 +23,28 @@ type ClientImplTestSuite struct {
*require.Assertions
// Suite is the testify/suite used for all HTTP request tests
suite.Suite

// serveMux is the request multiplexer used for tests
serveMux *http.ServeMux
// server is the testing server on some local port
server *httptest.Server
// client allows each test to have a reachable *Client for testing
client *Client
}

func (s *ClientImplTestSuite) SetupSuite() {
// Configure http client and other miscellany
s.serveMux = http.NewServeMux()
s.server = httptest.NewServer(s.serveMux)
token, err := NewToken("TestToken")
if err != nil {
log.Fatal(fmt.Errorf("invalid token received: %w", err))
}
s.client = NewClient(
WithEndpoint(s.server.URL),
WithToken(token),
)

// configure testify
s.Assertions = require.New(s.T())
}
Expand Down Expand Up @@ -58,6 +83,71 @@ func (s *ClientImplTestSuite) TestListOptsValues() {
}
}

func (s *ClientImplTestSuite) TestClientAll() {
const expectedPageCount = 4
var (
ctx = context.Background()
expectedPage = 1
)

s.serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(HTTPHeaderContentType, HTTPContentTypeJSON)

respBody, _ := NewResponse()
respBody.PaginatedResponse = PaginatedResponse{
PageNumber: 3,
PageSize: 1,
TotalRecords: 3,
}

switch page := r.URL.Query().Get("page"); page {
case "", "1":
respBody.PaginatedResponse.PageNumber = 1
case "2":
respBody.PaginatedResponse.PageNumber = 2
case "3":
respBody.PaginatedResponse.PageNumber = 3
default:
s.FailNowf("invalid page parameter", "page parameter: %v", page)
}

err := json.NewEncoder(w).Encode(respBody)
s.Nil(err, "encode response body")
})

// The caller must track results against expected count
// A reminder to the reader: this is the caller.
var resultCount = 0
err := s.client.all(context.Background(), func(page int) (*Response, error) {
s.Equalf(expectedPage, page, "expected page number (%d) matches actual (%d)", expectedPage, page)

path := fmt.Sprintf("/?page=%d&size=1", page)

req, err := s.client.NewRequest(ctx, "GET", path, nil)
s.Nil(err, "new request for path")

resp, err := s.client.Do(context.Background(), req)
s.Nil(err, "do request")

expectedPage++
// A reference of how these funcs should handle this;
// recall, the response may be shorter than max.
//
// Ideally, this count wouldn't be derived from PageSize,
// but rather, from the total count of discovered items
// unmarshaled.
resultCount += max(resp.PageSize, 0)

if resultCount >= resp.TotalRecords {
resp.MarkPaginationComplete()
}
return resp, err
})
s.Nil(err, "client.all call")

s.Equalf(expectedPage, expectedPageCount, "expected page visit count (%d) matches actual visit count (%d)", expectedPageCount-1, expectedPage-1)
}

func TestClientImplTestSuite(t *testing.T) {
suite.Run(t, new(ClientImplTestSuite))
}
Loading

0 comments on commit e888ddb

Please sign in to comment.