diff --git a/handler.go b/handler.go index 5b0ef5f..30d23a0 100644 --- a/handler.go +++ b/handler.go @@ -147,7 +147,7 @@ func (t *Tmpauth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) t.DebugLog(fmt.Sprintf("proxying request to mini server: %s", u.String())) - resp, err := t.miniClient(req) + resp, err := t.miniClient(req, 0) if err != nil { return http.StatusInternalServerError, err } @@ -282,7 +282,7 @@ func (t *Tmpauth) StartAuth(w http.ResponseWriter, r *http.Request) (int, error) req.Header.Set(HostHeader, r.Host) req.Header.Set("Content-Type", "application/jwt") - resp, err := t.miniClient(req) + resp, err := t.miniClient(req, 0) if err != nil { return http.StatusInternalServerError, fmt.Errorf("StartAuth on mini server: %w", err) } @@ -400,7 +400,7 @@ func (t *Tmpauth) Whomst() (map[string]json.RawMessage, error) { req.Header.Set(ConfigIDHeader, t.miniConfigID) - resp, respErr = t.miniClient(req) + resp, respErr = t.miniClient(req, 0) } else { resp, respErr = t.HttpClient.Get("https://" + TmpAuthHost + "/whomst") } diff --git a/mini.go b/mini.go index a78a5fa..334a474 100644 --- a/mini.go +++ b/mini.go @@ -2,9 +2,7 @@ package tmpauth import ( "bytes" - "context" "encoding/json" - "errors" "fmt" "io" "log" @@ -123,7 +121,7 @@ func NewMini(config MiniConfig, next CaddyHandleFunc) (*Tmpauth, error) { tmpauth: t, } - t.miniClient = transport.RoundTrip + t.miniClient = transport.Do return t, nil } @@ -156,26 +154,18 @@ type MiniTransport struct { tmpauth *Tmpauth } -type roundTripDepthKey struct{} - -func (t *MiniTransport) RoundTrip(req *http.Request) (*http.Response, error) { - depthRaw := req.Context().Value(roundTripDepthKey{}) - var depth *int - if depthRaw != nil { - depth = depthRaw.(*int) - } - - if depth != nil && *depth > 10 { - return nil, errors.New("mini transport reached maximum reauth depth") - } +func (t *MiniTransport) Do(req *http.Request, depth int) (*http.Response, error) { + var body []byte + if req.Body != nil { + var err error + body, err = io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("mini transport read body: %w", err) + } - body, err := io.ReadAll(req.Body) - if err != nil { - return nil, fmt.Errorf("mini transport read body: %w", err) + req.Body = io.NopCloser(bytes.NewReader(body)) } - req.Body = io.NopCloser(bytes.NewReader(body)) - resp, err := t.base.RoundTrip(req) if resp.StatusCode == http.StatusPreconditionFailed { // our config ID is wrong @@ -184,17 +174,10 @@ func (t *MiniTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("tmpauth: mini server reauth failed %w", err) } - ctx := req.Context() - - if depth != nil { - *depth++ - } else { - one := 1 - ctx = context.WithValue(ctx, roundTripDepthKey{}, &one) + if body != nil { + req.Body = io.NopCloser(bytes.NewReader(body)) } - - req.Body = io.NopCloser(bytes.NewReader(body)) - return t.RoundTrip(req.WithContext(ctx)) + return t.Do(req, depth+1) } return resp, err diff --git a/setup.go b/setup.go index 547a1a0..6a3332b 100644 --- a/setup.go +++ b/setup.go @@ -55,7 +55,7 @@ type Tmpauth struct { miniServerHost string miniConfigID string miniConfigJSON []byte - miniClient func(req *http.Request) (*http.Response, error) + miniClient func(req *http.Request, depth int) (*http.Response, error) done chan struct{} doneOnce sync.Once diff --git a/token.go b/token.go index 6415ee2..365ec29 100644 --- a/token.go +++ b/token.go @@ -107,7 +107,7 @@ func (t *Tmpauth) ParseWrappedAuthJWT(tokenStr string) (*CachedToken, error) { req.Header.Set(ConfigIDHeader, t.miniConfigID) req.Header.Set("Content-Type", "application/jwt") - resp, err := t.miniClient(req) + resp, err := t.miniClient(req, 0) if err != nil { return nil, fmt.Errorf("ParseWrappedAuthJWT on mini server: %w", err) }