Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix.599.auth.deadlock #600

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions edge-apis/authwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ type ApiSession interface {

//RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections.
RequiresRouterTokenUpdate() bool

GetRequestHeaders() http.Header
}

var _ ApiSession = (*ApiSessionLegacy)(nil)
Expand All @@ -86,7 +88,12 @@ var _ ApiSession = (*ApiSessionOidc)(nil)
// ApiSessionLegacy represents OpenZiti's original authentication API Session Detail, supplied in the `zt-session` header.
// It has been supplanted by OIDC authentication represented by ApiSessionOidc.
type ApiSessionLegacy struct {
Detail *rest_model.CurrentAPISessionDetail
Detail *rest_model.CurrentAPISessionDetail
RequestHeaders http.Header
}

func (a *ApiSessionLegacy) GetRequestHeaders() http.Header {
return a.RequestHeaders
}

func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool {
Expand Down Expand Up @@ -119,8 +126,15 @@ func (a *ApiSessionLegacy) AuthenticateRequest(request runtime.ClientRequest, _
return errors.New("api session is nil")
}

header, val := a.GetAccessHeader()
for h, v := range a.RequestHeaders {
err := request.SetHeaderParam(h, v...)
if err != nil {
return err
}
}

//legacy does not support multiple zt-session headers, so we can it sfely
header, val := a.GetAccessHeader()
err := request.SetHeaderParam(header, val)
if err != nil {
return err
Expand Down Expand Up @@ -151,7 +165,12 @@ func (a *ApiSessionLegacy) GetExpiresAt() *time.Time {

// ApiSessionOidc represents an authenticated session backed by OIDC tokens.
type ApiSessionOidc struct {
OidcTokens *oidc.Tokens[*oidc.IDTokenClaims]
OidcTokens *oidc.Tokens[*oidc.IDTokenClaims]
RequestHeaders http.Header
}

func (a *ApiSessionOidc) GetRequestHeaders() http.Header {
return a.RequestHeaders
}

func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool {
Expand Down Expand Up @@ -203,9 +222,31 @@ func (a *ApiSessionOidc) AuthenticateRequest(request runtime.ClientRequest, _ st
return errors.New("api session is nil")
}

header, val := a.GetAccessHeader()
if a.RequestHeaders == nil {
a.RequestHeaders = http.Header{}
}

//multiple Authorization headers are allowed, obtain all auth header candidates
primaryAuthHeader, primaryAuthValue := a.GetAccessHeader()
altAuthValues := a.RequestHeaders.Get(primaryAuthHeader)

authValues := []string{primaryAuthValue}

if len(altAuthValues) > 0 {
authValues = append(authValues, altAuthValues)
}

//set request headers
for h, v := range a.RequestHeaders {
err := request.SetHeaderParam(h, v...)
if err != nil {
return err
}
}

//restore auth headers
err := request.SetHeaderParam(primaryAuthHeader, authValues...)

err := request.SetHeaderParam(header, val)
if err != nil {
return err
}
Expand Down Expand Up @@ -320,7 +361,9 @@ func (self *ZitiEdgeManagement) legacyAuth(credentials Credentials, configTypes
return nil, err
}

return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err
return &ApiSessionLegacy{
Detail: resp.GetPayload().Data,
RequestHeaders: credentials.GetRequestHeaders()}, err
}

func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) {
Expand Down Expand Up @@ -355,7 +398,8 @@ func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpCli
}

return &ApiSessionOidc{
OidcTokens: tokens,
OidcTokens: tokens,
RequestHeaders: apiSession.GetRequestHeaders(),
}, nil
}

Expand Down Expand Up @@ -453,7 +497,7 @@ func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []st
return nil, err
}

return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err
return &ApiSessionLegacy{Detail: resp.GetPayload().Data, RequestHeaders: credentials.GetRequestHeaders()}, err
}

func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) {
Expand All @@ -480,7 +524,8 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient
}

newApiSession := &ApiSessionLegacy{
Detail: newApiSessionDetail.Payload.Data,
Detail: newApiSessionDetail.Payload.Data,
RequestHeaders: apiSession.GetRequestHeaders(),
}

return newApiSession, nil
Expand All @@ -492,7 +537,8 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient
}

return &ApiSessionOidc{
OidcTokens: tokens,
OidcTokens: tokens,
RequestHeaders: apiSession.GetRequestHeaders(),
}, nil
}

Expand Down Expand Up @@ -748,7 +794,8 @@ func oidcAuth(clientTransportPool ClientTransportPool, credentials Credentials,
}

return &ApiSessionOidc{
OidcTokens: outTokens,
OidcTokens: outTokens,
RequestHeaders: credentials.GetRequestHeaders(),
}, nil
}

Expand Down
77 changes: 60 additions & 17 deletions edge-apis/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ type Credentials interface {
// Method returns the authentication necessary to complete an authentication request.
Method() string

// AddHeader adds a header to the request.
AddHeader(key, value string)
// AddAuthHeader adds a header for all authentication requests.
AddAuthHeader(key, value string)

// AddRequestHeader adds a header for all requests after authentication
AddRequestHeader(key, value string)

// AddJWT adds additional JWTs to the credentials. Used to satisfy secondary authentication/MFA requirements. The
// provided token should be the base64 encoded version of the token.
Expand All @@ -38,6 +41,9 @@ type Credentials interface {
// ClientAuthInfoWriter is used to pass a Credentials instance to the openapi runtime to authenticate outgoing
//requests.
runtime.ClientAuthInfoWriter

// GetRequestHeaders returns a set of headers to use after authentication during normal HTTP operations
GetRequestHeaders() http.Header
}

// IdentityProvider is a sentinel interface used to determine whether the backing Credentials instance can provide
Expand Down Expand Up @@ -83,8 +89,11 @@ type BaseCredentials struct {
// ConfigTypes is used to set the configuration types for services during authentication
ConfigTypes []string

// Headers is a map of strings to string arrays of headers to send with auth requests.
Headers *http.Header
// AuthHeaders is a map of strings to string arrays of headers to send with auth requests.
AuthHeaders http.Header

// RequestHeaders is a map of string to string arrays of headers to send on non-authentication requests.
RequestHeaders http.Header

// EnvInfo is provided during authentication to set environmental information about the client.
EnvInfo *rest_model.EnvInfo
Expand Down Expand Up @@ -121,41 +130,75 @@ func (c *BaseCredentials) GetCaPool() *x509.CertPool {
return c.CaPool
}

// AddHeader provides a base implementation to add a header to the request.
func (c *BaseCredentials) AddHeader(key, value string) {
if c.Headers == nil {
c.Headers = &http.Header{}
// AddAuthHeader provides a base implementation to add a header to authentication requests.
func (c *BaseCredentials) AddAuthHeader(key, value string) {
if c.AuthHeaders == nil {
c.AuthHeaders = http.Header{}
}
c.Headers.Add(key, value)
c.AuthHeaders.Add(key, value)
}

// AddRequestHeader provides a base implementation to add a header to all requests after authentication.
func (c *BaseCredentials) AddRequestHeader(key, value string) {
if c.RequestHeaders == nil {
c.RequestHeaders = http.Header{}
}

c.RequestHeaders.Add(key, value)
}

// AddJWT adds additional JWTs to the credentials. Used to satisfy secondary authentication/MFA requirements. The
// provided token should be the base64 encoded version of the token. Convenience function for AddHeader.
func (c *BaseCredentials) AddJWT(token string) {
c.AddHeader("Authorization", "Bearer "+token)
c.AddAuthHeader("Authorization", "Bearer "+token)
c.AddRequestHeader("Authorization", "Bearer "+token)
}

// AuthenticateRequest provides a base implementation to authenticate an outgoing request. This is provided here
// for authentication methods such as `cert` which do not have to provide any more request level information.
func (c *BaseCredentials) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error {
var errors []error

if c.Headers != nil {
for hName, hVals := range *c.Headers {
for _, hVal := range hVals {
err := request.SetHeaderParam(hName, hVal)
if err != nil {
errors = append(errors, err)
}
for hName, hVals := range c.AuthHeaders {
for _, hVal := range hVals {
err := request.SetHeaderParam(hName, hVal)
if err != nil {
errors = append(errors, err)
}
}
}

if len(errors) > 0 {
return network.MultipleErrors(errors)
}
return nil
}

// ProcessRequest proves a base implemmentation mutate runtime.ClientRequests as they are sent out after
// authentication. Useful for adding headers.
func (c *BaseCredentials) ProcessRequest(request runtime.ClientRequest, _ strfmt.Registry) error {
var errors []error

for hName, hVals := range c.RequestHeaders {
for _, hVal := range hVals {
err := request.SetHeaderParam(hName, hVal)
if err != nil {
errors = append(errors, err)
}
}
}

if len(errors) > 0 {
return network.MultipleErrors(errors)
}
return nil
}

// GetRequestHeaders returns headers that should be sent on requests post authentication.
func (c *BaseCredentials) GetRequestHeaders() http.Header {
return c.RequestHeaders
}

// TlsCerts provides a base implementation of returning the tls.Certificate array that will be used to setup
// mTLS connections. This is provided here for authentication methods that do not initially require mTLS (e.g. JWTs).
func (c *BaseCredentials) TlsCerts() []tls.Certificate {
Expand Down
2 changes: 1 addition & 1 deletion ziti/sdkinfo/build_info.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions ziti/ziti.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ type Context interface {
Close()

// Deprecated: AddZitiMfaHandler adds a Ziti MFA handler, invoked during authentication.
// Replaced with event functionality. Use `zitiContext.AddListener(MfaTotpCode, handler)` instead.
// Replaced with event functionality. Use `zitiContext.Events().AddMfaTotpCodeListener(func(Context, *rest_model.AuthQueryDetail, MfaCodeResponse))` instead.
AddZitiMfaHandler(handler func(query *rest_model.AuthQueryDetail, resp MfaCodeResponse) error)

// EnrollZitiMfa will attempt to enable TOTP 2FA on the currently authenticating identity if not already enrolled.
Expand Down Expand Up @@ -193,7 +193,6 @@ type ContextImpl struct {
authQueryHandlers map[string]func(query *rest_model.AuthQueryDetail, response MfaCodeResponse) error

events.EventEmmiter
apiSessionLock sync.Mutex
lastSuccessfulApiSessionRefresh time.Time
}

Expand Down Expand Up @@ -928,9 +927,6 @@ func (context *ContextImpl) Reauthenticate() error {
}

func (context *ContextImpl) Authenticate() error {
context.apiSessionLock.Lock()
defer context.apiSessionLock.Unlock()

if context.CtrlClt.GetCurrentApiSession() != nil {
if time.Since(context.lastSuccessfulApiSessionRefresh) < 5*time.Second {
return nil
Expand Down Expand Up @@ -1040,6 +1036,10 @@ func (context *ContextImpl) authenticateMfa(code string) error {
func (context *ContextImpl) handleAuthQuery(authQuery *rest_model.AuthQueryDetail) error {
context.Emit(EventAuthQuery, authQuery)

if authQuery.Provider == nil {
return fmt.Errorf("unhandled response from controller: authentication query has no provider specified")
}

if *authQuery.Provider == rest_model.MfaProvidersZiti {
handler := context.authQueryHandlers[string(rest_model.MfaProvidersZiti)]

Expand All @@ -1054,7 +1054,7 @@ func (context *ContextImpl) handleAuthQuery(authQuery *rest_model.AuthQueryDetai
return nil
}

return fmt.Errorf("unsupported MFA provider: %v", authQuery.Provider)
return fmt.Errorf("unsupported MFA provider: %v", *authQuery.Provider)
}

func (context *ContextImpl) Dial(serviceName string) (edge.Conn, error) {
Expand Down
Loading