Skip to content

Commit

Permalink
final fix
Browse files Browse the repository at this point in the history
  • Loading branch information
1lann committed Jan 5, 2024
1 parent 7208bca commit 1167684
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 35 deletions.
6 changes: 3 additions & 3 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (t *Tmpauth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)

t.DebugLog(fmt.Sprintf("proxying request to mini server: %s", u.String()))

resp, err := t.miniClient(req)
resp, err := t.miniClient(req, 0)
if err != nil {
return http.StatusInternalServerError, err
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func (t *Tmpauth) StartAuth(w http.ResponseWriter, r *http.Request) (int, error)
req.Header.Set(HostHeader, r.Host)
req.Header.Set("Content-Type", "application/jwt")

resp, err := t.miniClient(req)
resp, err := t.miniClient(req, 0)
if err != nil {
return http.StatusInternalServerError, fmt.Errorf("StartAuth on mini server: %w", err)
}
Expand Down Expand Up @@ -400,7 +400,7 @@ func (t *Tmpauth) Whomst() (map[string]json.RawMessage, error) {

req.Header.Set(ConfigIDHeader, t.miniConfigID)

resp, respErr = t.miniClient(req)
resp, respErr = t.miniClient(req, 0)
} else {
resp, respErr = t.HttpClient.Get("https://" + TmpAuthHost + "/whomst")
}
Expand Down
43 changes: 13 additions & 30 deletions mini.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package tmpauth

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -123,7 +121,7 @@ func NewMini(config MiniConfig, next CaddyHandleFunc) (*Tmpauth, error) {
tmpauth: t,
}

t.miniClient = transport.RoundTrip
t.miniClient = transport.Do

return t, nil
}
Expand Down Expand Up @@ -156,26 +154,18 @@ type MiniTransport struct {
tmpauth *Tmpauth
}

type roundTripDepthKey struct{}

func (t *MiniTransport) RoundTrip(req *http.Request) (*http.Response, error) {
depthRaw := req.Context().Value(roundTripDepthKey{})
var depth *int
if depthRaw != nil {
depth = depthRaw.(*int)
}

if depth != nil && *depth > 10 {
return nil, errors.New("mini transport reached maximum reauth depth")
}
func (t *MiniTransport) Do(req *http.Request, depth int) (*http.Response, error) {
var body []byte
if req.Body != nil {
var err error
body, err = io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("mini transport read body: %w", err)
}

body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("mini transport read body: %w", err)
req.Body = io.NopCloser(bytes.NewReader(body))
}

req.Body = io.NopCloser(bytes.NewReader(body))

resp, err := t.base.RoundTrip(req)
if resp.StatusCode == http.StatusPreconditionFailed {
// our config ID is wrong
Expand All @@ -184,17 +174,10 @@ func (t *MiniTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("tmpauth: mini server reauth failed %w", err)
}

ctx := req.Context()

if depth != nil {
*depth++
} else {
one := 1
ctx = context.WithValue(ctx, roundTripDepthKey{}, &one)
if body != nil {
req.Body = io.NopCloser(bytes.NewReader(body))
}

req.Body = io.NopCloser(bytes.NewReader(body))
return t.RoundTrip(req.WithContext(ctx))
return t.Do(req, depth+1)
}

return resp, err
Expand Down
2 changes: 1 addition & 1 deletion setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type Tmpauth struct {
miniServerHost string
miniConfigID string
miniConfigJSON []byte
miniClient func(req *http.Request) (*http.Response, error)
miniClient func(req *http.Request, depth int) (*http.Response, error)

done chan struct{}
doneOnce sync.Once
Expand Down
2 changes: 1 addition & 1 deletion token.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (t *Tmpauth) ParseWrappedAuthJWT(tokenStr string) (*CachedToken, error) {
req.Header.Set(ConfigIDHeader, t.miniConfigID)
req.Header.Set("Content-Type", "application/jwt")

resp, err := t.miniClient(req)
resp, err := t.miniClient(req, 0)
if err != nil {
return nil, fmt.Errorf("ParseWrappedAuthJWT on mini server: %w", err)
}
Expand Down

0 comments on commit 1167684

Please sign in to comment.