From 87ad469c9402fef3632a864514f7e238fa28b24a Mon Sep 17 00:00:00 2001 From: rosahaj <141790572+rosahaj@users.noreply.github.com> Date: Sun, 1 Dec 2024 21:27:43 +0100 Subject: [PATCH] Add DefaultRedirectPolicy --- client.go | 16 +++------------- client_test.go | 4 ++++ redirect.go | 5 +++++ 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 1b79591..b1eea7d 100644 --- a/client.go +++ b/client.go @@ -321,20 +321,10 @@ func (c *Client) GetTLSClientConfig() *tls.Config { return c.TLSClientConfig } -func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") - } - if c.DebugLog { - c.log.Debugf("<redirect> %s %s", req.Method, req.URL.String()) - } - return nil -} - // SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect // responses (usually responses with 301 and 302 status code), see the predefined -// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, -// SameDomainRedirectPolicy and SameHostRedirectPolicy. +// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, DefaultRedirectPolicy, MaxRedirectPolicy, +// NoRedirectPolicy, SameDomainRedirectPolicy and SameHostRedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -1565,7 +1555,7 @@ func C() *Client { xmlUnmarshal: xml.Unmarshal, cookiejarFactory: memoryCookieJarFactory, } - httpClient.CheckRedirect = c.defaultCheckRedirect + c.SetRedirectPolicy(DefaultRedirectPolicy()) c.initCookieJar() c.initTransport() diff --git a/client_test.go b/client_test.go index e9e9f75..7a6aeeb 100644 --- a/client_test.go +++ b/client_test.go @@ -369,6 +369,10 @@ func TestRedirect(t *testing.T) { tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) + _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(20)).SetRedirectPolicy(DefaultRedirectPolicy()).R().Get("/unlimited-redirect") + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "stopped after 10 redirects", true) + _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) diff --git a/redirect.go b/redirect.go index f1cc433..fcc13e4 100644 --- a/redirect.go +++ b/redirect.go @@ -21,6 +21,11 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { } } +// DefaultRedirectPolicy allows up to 10 redirects +func DefaultRedirectPolicy() RedirectPolicy { + return MaxRedirectPolicy(10) +} + // NoRedirectPolicy disable redirect behaviour func NoRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error {