diff --git a/edge-apis/authwrapper.go b/edge-apis/authwrapper.go index 2e69023e..b9f12f18 100644 --- a/edge-apis/authwrapper.go +++ b/edge-apis/authwrapper.go @@ -8,6 +8,7 @@ import ( "github.com/go-openapi/runtime" "github.com/go-openapi/strfmt" "github.com/go-resty/resty/v2" + "github.com/golang-jwt/jwt/v5" "github.com/openziti/edge-api/rest_client_api_client" clientAuth "github.com/openziti/edge-api/rest_client_api_client/authentication" clientControllers "github.com/openziti/edge-api/rest_client_api_client/controllers" @@ -74,6 +75,9 @@ type ApiSession interface { //GetId returns the id of the ApiSession GetId() string + + //RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections. + RequiresRouterTokenUpdate() bool } var _ ApiSession = (*ApiSessionLegacy)(nil) @@ -85,6 +89,10 @@ type ApiSessionLegacy struct { Detail *rest_model.CurrentAPISessionDetail } +func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool { + return false +} + func (a *ApiSessionLegacy) GetId() string { return stringz.OrEmpty(a.Detail.ID) } @@ -146,6 +154,10 @@ type ApiSessionOidc struct { OidcTokens *oidc.Tokens[*oidc.IDTokenClaims] } +func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool { + return true +} + func (a *ApiSessionOidc) GetAccessClaims() (*ApiAccessClaims, error) { claims := &ApiAccessClaims{} @@ -491,39 +503,60 @@ func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenC } func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + subjectToken := curTokens.RefreshToken + subjectTokenType := oidc.RefreshTokenType + + // if subjectToken is "", then we don't have a refresh token, attempt to exchange a non-expired access token + if subjectToken == "" { + if curTokens.Expiry.Before(time.Now()) { + return nil, errors.New("cannot exchange token: refresh token not found, access token expired") + } + + if curTokens.AccessToken == "" { + return nil, errors.New("cannot exchange token: refresh token not found, access token not found") + } + subjectToken = curTokens.AccessToken + subjectTokenType = oidc.AccessTokenType + } var outTokens *oidc.Tokens[*oidc.IDTokenClaims] _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { apiHost := transport.ApiUrl.Host - te, err := tokenexchange.NewTokenExchanger(apiHost, tokenexchange.WithHTTPClient(client)) + issuer := "https://" + apiHost + "/oidc" + tokenEndpoint := "https://" + apiHost + "/oidc/oauth/token" + + te, err := tokenexchange.NewTokenExchangerClientCredentials(issuer, "native", "", tokenexchange.WithHTTPClient(client), tokenexchange.WithStaticTokenEndpoint(issuer, tokenEndpoint)) if err != nil { return nil, err } - accessResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) + var tokenResponse *oidc.TokenExchangeResponse - if err != nil { - return nil, err - } + now := time.Now() - //TODO: be smarter, only refresh refresh token if the new access token lives beyond refresh - refreshResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) + switch subjectTokenType { + case oidc.RefreshTokenType: + tokenResponse, err = tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) + case oidc.AccessTokenType: + tokenResponse, err = tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) + } if err != nil { return nil, err } - idResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.IDTokenType) + idResp, err := tokenexchange.ExchangeToken(te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.IDTokenType) if err != nil { return nil, err } - idClaims := &oidc.IDTokenClaims{} + idClaims := &IdClaims{} - err = json.Unmarshal([]byte(idResp.AccessToken), idClaims) + //access token is used to hold id token per zitadel comments + _, _, err = jwt.NewParser().ParseUnverified(idResp.AccessToken, idClaims) if err != nil { return nil, err @@ -531,13 +564,13 @@ func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tok outTokens = &oidc.Tokens[*oidc.IDTokenClaims]{ Token: &oauth2.Token{ - AccessToken: accessResp.AccessToken, - TokenType: accessResp.TokenType, - RefreshToken: refreshResp.RefreshToken, - Expiry: time.Time{}, + AccessToken: tokenResponse.AccessToken, + TokenType: tokenResponse.TokenType, + RefreshToken: tokenResponse.RefreshToken, + Expiry: now.Add(time.Duration(tokenResponse.ExpiresIn)), }, - IDTokenClaims: idClaims, - IDToken: idResp.AccessToken, //access token is used to hold id token per zitadel comments + IDTokenClaims: &idClaims.IDTokenClaims, + IDToken: idResp.AccessToken, //access token field is used to hold id token per zitadel comments } return outTokens, nil diff --git a/edge-apis/oidc.go b/edge-apis/oidc.go index 7ed29732..4cef0775 100644 --- a/edge-apis/oidc.go +++ b/edge-apis/oidc.go @@ -39,6 +39,38 @@ type ApiAccessClaims struct { Scopes []string `json:"scopes,omitempty"` } +var _ jwt.Claims = (*IdClaims)(nil) + +// IdClaims wraps oidc.IDToken claims to fulfill the jwt.Claims interface +type IdClaims struct { + oidc.IDTokenClaims +} + +func (r *IdClaims) GetExpirationTime() (*jwt.NumericDate, error) { + return &jwt.NumericDate{Time: r.TokenClaims.GetExpiration()}, nil +} + +func (r *IdClaims) GetNotBefore() (*jwt.NumericDate, error) { + notBefore := r.TokenClaims.NotBefore.AsTime() + return &jwt.NumericDate{Time: notBefore}, nil +} + +func (r *IdClaims) GetIssuedAt() (*jwt.NumericDate, error) { + return &jwt.NumericDate{Time: r.TokenClaims.GetIssuedAt()}, nil +} + +func (r *IdClaims) GetIssuer() (string, error) { + return r.TokenClaims.Issuer, nil +} + +func (r *IdClaims) GetSubject() (string, error) { + return r.TokenClaims.Issuer, nil +} + +func (r *IdClaims) GetAudience() (jwt.ClaimStrings, error) { + return jwt.ClaimStrings(r.TokenClaims.Audience), nil +} + type localRpServer struct { Server *http.Server Port string diff --git a/example/chat/chat-server/chat-server.go b/example/chat/chat-server/chat-server.go index 990504a9..7b455156 100644 --- a/example/chat/chat-server/chat-server.go +++ b/example/chat/chat-server/chat-server.go @@ -165,4 +165,5 @@ func main() { logger.Infof("new connection") go server.handleChat(conn) } + } diff --git a/ziti/client.go b/ziti/client.go index 5b313ae3..e86371ec 100644 --- a/ziti/client.go +++ b/ziti/client.go @@ -25,14 +25,8 @@ import ( "crypto/x509/pkix" "encoding/pem" "fmt" - "github.com/golang-jwt/jwt/v5" - "github.com/openziti/foundation/v2/genext" - "github.com/openziti/transport/v2" - "github.com/pkg/errors" - "strings" - "time" - "github.com/go-openapi/strfmt" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_client_api_client/authentication" @@ -43,10 +37,14 @@ import ( "github.com/openziti/edge-api/rest_client_api_client/session" "github.com/openziti/edge-api/rest_model" "github.com/openziti/edge-api/rest_util" + "github.com/openziti/foundation/v2/genext" nfPem "github.com/openziti/foundation/v2/pem" "github.com/openziti/identity" apis "github.com/openziti/sdk-golang/edge-apis" "github.com/openziti/sdk-golang/ziti/edge/posture" + "github.com/openziti/transport/v2" + "github.com/pkg/errors" + "strings" ) // CtrlClient is a stateful version of ZitiEdgeClient that simplifies operations @@ -72,7 +70,7 @@ func (self *CtrlClient) GetCurrentApiSession() apis.ApiSession { } // Refresh will contact the controller extending the current ApiSession for legacy API Sessions -func (self *CtrlClient) Refresh() (*time.Time, error) { +func (self *CtrlClient) Refresh() (apis.ApiSession, error) { if apiSession := self.GetCurrentApiSession(); apiSession != nil { newApiSession, err := self.API.RefreshApiSession(apiSession, self.HttpClient) @@ -82,7 +80,7 @@ func (self *CtrlClient) Refresh() (*time.Time, error) { self.ApiSession.Store(&newApiSession) - return newApiSession.GetExpiresAt(), nil + return newApiSession, nil } return nil, errors.New("no api session") diff --git a/ziti/edge/conn.go b/ziti/edge/conn.go index b10097ce..12dc254b 100644 --- a/ziti/edge/conn.go +++ b/ziti/edge/conn.go @@ -42,7 +42,7 @@ type RouterClient interface { //UpdateToken will attempt to send token updates to the connected router. A success/failure response is expected //within the timeout period. - UpdateToken(token string, timeout time.Duration) error + UpdateToken(token []byte, timeout time.Duration) error } type RouterConn interface { diff --git a/ziti/edge/network/factory.go b/ziti/edge/network/factory.go index 5169f624..3008eda7 100644 --- a/ziti/edge/network/factory.go +++ b/ziti/edge/network/factory.go @@ -116,8 +116,8 @@ func (conn *routerConn) NewDialConn(service *rest_model.ServiceDetail) *edgeConn return edgeCh } -func (conn *routerConn) UpdateToken(token string, timeout time.Duration) error { - msg := edge.NewUpdateTokenMsg([]byte(token)) +func (conn *routerConn) UpdateToken(token []byte, timeout time.Duration) error { + msg := edge.NewUpdateTokenMsg(token) resp, err := msg.WithTimeout(timeout).SendForReply(conn.ch) if err != nil { diff --git a/ziti/ziti.go b/ziti/ziti.go index 3e32d41f..2545164e 100644 --- a/ziti/ziti.go +++ b/ziti/ziti.go @@ -743,6 +743,20 @@ func (context *ContextImpl) RefreshService(serviceName string) (*rest_model.Serv return serviceDetail, nil } +func (context *ContextImpl) updateTokenOnAllErs(apiSession apis.ApiSession) { + if apiSession.RequiresRouterTokenUpdate() { + for tpl := range context.routerConnections.IterBuffered() { + erConn := tpl.Val + erKey := tpl.Key + go func() { + if err := erConn.UpdateToken(apiSession.GetToken(), 10*time.Second); err != nil { + pfxlog.Logger().WithError(err).WithField("er", erKey).Warn("error updating apiSession token to connected ER") + } + }() + } + } +} + func (context *ContextImpl) runRefreshes() { log := pfxlog.Logger() svcRefreshInterval := context.options.RefreshInterval @@ -768,8 +782,9 @@ func (context *ContextImpl) runRefreshes() { defer sessionRefreshTick.Stop() refreshAt := time.Now().Add(30 * time.Second) + if currentApiSession := context.CtrlClt.GetCurrentApiSession(); currentApiSession != nil && currentApiSession.GetExpiresAt() != nil { - refreshAt = time.Time(*currentApiSession.GetExpiresAt()).Add(-10 * time.Second) + refreshAt = (*currentApiSession.GetExpiresAt()).Add(-10 * time.Second) } for { @@ -778,14 +793,25 @@ func (context *ContextImpl) runRefreshes() { return case <-time.After(time.Until(refreshAt)): - exp, err := context.CtrlClt.Refresh() + apiSession := context.CtrlClt.GetCurrentApiSession() + + if apiSession == nil { + pfxlog.Logger().Warn("could not refresh api session, current api session is nil") + continue + } + + newApiSession, err := context.CtrlClt.Refresh() + if err != nil { log.Errorf("could not refresh apiSession: %v", err) refreshAt = time.Now().Add(5 * time.Second) } else { + exp := newApiSession.GetExpiresAt() refreshAt = exp.Add(-10 * time.Second) log.Debugf("apiSession refreshed, new expiration[%s]", *exp) + + context.updateTokenOnAllErs(newApiSession) } case <-svcRefreshTick.C: @@ -926,8 +952,9 @@ func (context *ContextImpl) RefreshApiSessionWithBackoff() error { expBackoff.MaxElapsedTime = 24 * time.Hour operation := func() error { - _, err := context.CtrlClt.Refresh() + newApiSession, err := context.CtrlClt.Refresh() if err == nil { + context.updateTokenOnAllErs(newApiSession) return nil } @@ -990,9 +1017,13 @@ func (context *ContextImpl) authenticateMfa(code string) error { return err } - if _, err := context.CtrlClt.Refresh(); err != nil { + newApiSession, err := context.CtrlClt.Refresh() + + if err != nil { return err } + context.updateTokenOnAllErs(newApiSession) + apiSession := context.CtrlClt.GetCurrentApiSession() if apiSession != nil && len(apiSession.GetAuthQueries()) == 0 {