Skip to content

Commit

Permalink
Merge pull request #10 from dewonderstruck/rfc/token-revocation
Browse files Browse the repository at this point in the history
RFC 7009: Token Revocation
  • Loading branch information
vamsii777 authored Oct 29, 2024
2 parents 27e1d00 + 3161013 commit 90777b9
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 10 deletions.
9 changes: 9 additions & 0 deletions Sources/VaporOAuth/OAuth2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public struct OAuth2: LifecycleHandler {
scopeValidator: scopeValidator
)

let tokenRevocationHandler = TokenRevocationHandler(
clientValidator: clientValidator,
tokenManager: tokenManager
)

let resourceServerAuthenticator = ResourceServerAuthenticator(resourceServerRetriever: resourceServerRetriever)

// returning something like "Authenticate with GitHub page"
Expand All @@ -91,6 +96,10 @@ public struct OAuth2: LifecycleHandler {
// client requesting access/refresh token with code from POST /authorize endpoint
app.post("oauth", "token", use: tokenHandler.handleRequest)

// Revoke a token
app.post("oauth", "revoke", use: tokenRevocationHandler.handleRequest)


let tokenIntrospectionAuthMiddleware = TokenIntrospectionAuthMiddleware(resourceServerAuthenticator: resourceServerAuthenticator)
let resourceServerProtected = app.routes.grouped(tokenIntrospectionAuthMiddleware)
resourceServerProtected.post("oauth", "token_info", use: tokenIntrospectionHandler.handleRequest)
Expand Down
3 changes: 3 additions & 0 deletions Sources/VaporOAuth/Protocols/TokenManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ public protocol TokenManager: Sendable {
func getRefreshToken(_ refreshToken: String) async throws -> RefreshToken?
func getAccessToken(_ accessToken: String) async throws -> AccessToken?
func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) async throws

func revokeAccessToken(_ token: String) async throws
func revokeRefreshToken(_ token: String) async throws
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import Vapor

struct TokenRevocationHandler: Sendable {
let clientValidator: ClientValidator
let tokenManager: TokenManager

@Sendable
func handleRequest(_ request: Request) async throws -> Response {

// Validate content type
guard request.headers.contentType == .urlEncodedForm else {
return try createErrorResponse(
status: .badRequest,
errorMessage: OAuthResponseParameters.ErrorType.invalidRequest,
errorDescription: "Content-Type must be application/x-www-form-urlencoded"
)
}

let (errorResponse, requestObject) = try await validateRequest(request)

if let errorResponse = errorResponse {
return errorResponse
}

guard let requestObject = requestObject else {
throw Abort(.internalServerError)
}

// Client authentication
do {
try await clientValidator.authenticateClient(
clientID: requestObject.clientID,
clientSecret: request.content[String.self, at: OAuthRequestParameters.clientSecret],
grantType: nil
)
} catch {
return try createErrorResponse(
status: .unauthorized,
errorMessage: OAuthResponseParameters.ErrorType.invalidClient,
errorDescription: "Request had invalid client credentials"
)
}

// Attempt token revocation based on type hint
try await revokeToken(
token: requestObject.token,
typeHint: requestObject.tokenTypeHint,
clientID: requestObject.clientID
)

// RFC 7009 specifies returning 200 OK even for non-existent tokens
return createResponse()
}

private func validateRequest(_ request: Request) async throws -> (Response?, TokenRevocationRequest?) {
guard let token: String = request.content[OAuthRequestParameters.token] else {
return (try createErrorResponse(
status: .badRequest,
errorMessage: OAuthResponseParameters.ErrorType.invalidRequest,
errorDescription: "Request was missing the 'token' parameter"
), nil)
}

guard let clientID: String = request.content[OAuthRequestParameters.clientID] else {
return (try createErrorResponse(
status: .badRequest,
errorMessage: OAuthResponseParameters.ErrorType.invalidRequest,
errorDescription: "Request was missing the 'client_id' parameter"
), nil)
}

let tokenTypeHint: String? = request.content[OAuthRequestParameters.tokenTypeHint]

let requestObject = TokenRevocationRequest(
token: token,
tokenTypeHint: tokenTypeHint,
clientID: clientID
)

return (nil, requestObject)
}

private func revokeToken(token: String, typeHint: String?, clientID: String) async throws {
switch typeHint {
case "refresh_token":
if let refreshToken = try await tokenManager.getRefreshToken(token),
refreshToken.clientID == clientID {
try await tokenManager.revokeRefreshToken(token)
}

case "access_token", .none:
if let accessToken = try await tokenManager.getAccessToken(token),
accessToken.clientID == clientID {
try await tokenManager.revokeAccessToken(token)
}

default:
// RFC 7009: Unsupported token type hints are ignored
break
}
}

private func createErrorResponse(
status: HTTPStatus,
errorMessage: String,
errorDescription: String
) throws -> Response {
let response = Response(status: status)
try response.content.encode(ErrorResponse(
error: errorMessage,
errorDescription: errorDescription
))
return response
}

private func createResponse(status: HTTPStatus = .ok) -> Response {
let response = Response(status: status)
response.headers.replaceOrAdd(name: .cacheControl, value: "no-store")
response.headers.replaceOrAdd(name: .pragma, value: "no-cache")
return response
}
}

// MARK: - Request/Response Models
extension TokenRevocationHandler {
struct TokenRevocationRequest: Sendable {
let token: String
let tokenTypeHint: String?
let clientID: String
}

struct ErrorResponse: Content, Sendable {
let error: String
let errorDescription: String

enum CodingKeys: String, CodingKey {
case error
case errorDescription = "error_description"
}
}
}
6 changes: 4 additions & 2 deletions Sources/VaporOAuth/Utilities/StringDefines.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ struct OAuthRequestParameters: Sendable {
static let password = "password"
static let usernname = "username"
static let csrfToken = "csrfToken"
static let token = "token"
static let codeChallenge = "code_challenge"
static let codeChallengeMethod = "code_challenge_method"
static let codeVerifier = "code_verifier"
static let deviceCode = "device_code"
static let deviceCode = "device_code"
// Token Revocation parameters
public static let token = "token"
public static let tokenTypeHint = "token_type_hint"
}

struct OAuthResponseParameters: Sendable {
Expand Down
23 changes: 16 additions & 7 deletions Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,48 @@ import VaporOAuth
import Foundation

class FakeTokenManager: TokenManager, @unchecked Sendable {

var accessTokenToReturn = "ACCESS-TOKEN-STRING"
var refreshTokenToReturn = "REFRESH-TOKEN-STRING"
var refreshTokens: [String: RefreshToken] = [:]
var accessTokens: [String: AccessToken] = [:]
var currentTime = Date()

func getRefreshToken(_ refreshToken: String) -> RefreshToken? {
return refreshTokens[refreshToken]
}

func getAccessToken(_ accessToken: String) -> AccessToken? {
return accessTokens[accessToken]
}

func generateAccessRefreshTokens(clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int) throws -> (AccessToken, RefreshToken) {
let accessToken = FakeAccessToken(tokenString: accessTokenToReturn, clientID: clientID, userID: userID, scopes: scopes, expiryTime: currentTime.addingTimeInterval(TimeInterval(accessTokenExpiryTime)))
let refreshToken = FakeRefreshToken(tokenString: refreshTokenToReturn, clientID: clientID, userID: userID, scopes: scopes)

accessTokens[accessTokenToReturn] = accessToken
refreshTokens[refreshTokenToReturn] = refreshToken
return (accessToken, refreshToken)
}

func generateAccessToken(clientID: String, userID: String?, scopes: [String]?, expiryTime: Int) throws -> AccessToken {
let accessToken = FakeAccessToken(tokenString: accessTokenToReturn, clientID: clientID, userID: userID, scopes: scopes, expiryTime: currentTime.addingTimeInterval(TimeInterval(expiryTime)))
accessTokens[accessTokenToReturn] = accessToken
return accessToken
}

func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) {
var tempRefreshToken = refreshToken
tempRefreshToken.scopes = scopes
refreshTokens[refreshToken.tokenString] = tempRefreshToken
}

// MARK: - New Token Revocation Methods
func revokeAccessToken(_ token: String) async throws {
accessTokens.removeValue(forKey: token)
}

func revokeRefreshToken(_ token: String) async throws {
refreshTokens.removeValue(forKey: token)
}
}
8 changes: 7 additions & 1 deletion Tests/VaporOAuthTests/Fakes/StubTokenManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import VaporOAuth
import Foundation

class StubTokenManager: TokenManager, @unchecked Sendable {

var accessToken = "ABCDEF"
var refreshToken = "GHIJKL"

Expand All @@ -26,4 +26,10 @@ class StubTokenManager: TokenManager, @unchecked Sendable {

func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) {
}

func revokeAccessToken(_ token: String) async throws {
}

func revokeRefreshToken(_ token: String) async throws {
}
}
Loading

0 comments on commit 90777b9

Please sign in to comment.