Skip to content

Commit

Permalink
feat: Support optional field 'audience' in Identity Provider #650 (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay authored Jan 27, 2025
1 parent cbf3f88 commit 91d0c02
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Priority order:
| identityProviders.*.negativeCacheExpirationMs | 10000 | No |How long to retain JWKS response in the cache in case of failed response.
| identityProviders.*.issuerPattern | - | No |Regexp to match the claim "iss" to identity provider.
| identityProviders.*.disableJwtVerification | false | No |The flag disables JWT verification. *Note*. `userInfoEndpoint` must be unset if the flag is set to `true`.
| identityProviders.*.audience | - | No |If the setting is set it will be validated against the claim `aud` in JWT
| vertx.* | - | No |Vertx settings. Refer to [vertx.io](https://vertx.io/docs/apidocs/io/vertx/core/VertxOptions.html) to learn more.
| server.* | - | No |Vertx HTTP server settings for incoming requests.
| client.* | - | No |Vertx HTTP client settings for outbound requests.
Expand Down
7 changes: 7 additions & 0 deletions sample/aidial.settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,37 @@
"jwksUrl": "https://login.microsoftonline.com/path/discovery/keys",
"rolePath": "groups",
"projectPath": "aud",
"audience": "dial",
"issuerPattern": "^https:\\/\\/some\\.windows\\.net.+$"
},
"keycloak": {
"jwksUrl": "https://host.com/realms/your/protocol/openid-connect/certs",
"rolePath": "resource_access.your.roles",
"projectPath": "azp",
"audience": "dial",
"issuerPattern": "^https:\\/\\/some-keycloak.com.+$"
},
"google": {
"rolePath": "fn:getGoogleWorkspaceGroups",
"projectPath": "aud",
"userInfoEndpoint": "https://openidconnect.googleapis.com/v1/userinfo",
"loggingKey": "email",
"audience": "dial",
"loggingSalt": "salt"
},
"cognito": {
"loggingKey": "email",
"issuerPattern": "^https:\\/\\/cognito-idp\\.eu-north-1\\.amazonaws\\.com.+$",
"rolePath": "roles",
"projectPath": "aud",
"audience": "dial",
"jwksUrl": "https://cognito-idp.eu-north-1.amazonaws.com/eu-north-1_PWSAjo4OY/.well-known/jwks.json",
"loggingSalt": "loggingSalt"
},
"gitlab": {
"rolePath": "groups",
"projectPath": "aud",
"audience": "dial",
"userInfoEndpoint": "https://gitlab.com/oauth/userinfo",
"loggingKey": "email",
"loggingSalt": "salt"
Expand All @@ -48,6 +53,7 @@
"issuerPattern": "^https:\\/\\/chatbot-ui-staging\\.eu\\.auth0\\.com.+$",
"rolePath": "dial_roles",
"projectPath": "aud",
"audience": "dial",
"jwksUrl": "https://<your_domain>.auth0.com/.well-known/jwks.json",
"loggingSalt": "loggingSalt"
},
Expand All @@ -56,6 +62,7 @@
"issuerPattern": "^https:\\/\\/<your_domain>\\.okta\\.com.*$",
"rolePath": "Groups",
"projectPath": "aud",
"audience": "dial",
"jwksUrl": "https://<your_domain>.okta.com/oauth2/default/v1/keys",
"loggingSalt": "loggingSalt"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
Expand Down Expand Up @@ -84,6 +85,8 @@ public class IdentityProvider {

private final GetUserRoleFn getUserRoleFn;

private final String audience;

public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client,
Function<String, JwkProvider> jwkProviderSupplier, GetUserRoleFunctionFactory factory) {
if (settings == null) {
Expand Down Expand Up @@ -153,6 +156,8 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client,
}
obfuscateUserEmail = settings.getBoolean("obfuscateUserEmail", true);

audience = settings.getString("audience", null);

long period = Math.min(negativeCacheExpirationMs, positiveCacheExpirationMs);
vertx.setPeriodic(0, period, event -> evictExpiredJwks());
}
Expand Down Expand Up @@ -235,7 +240,11 @@ private DecodedJWT verifyJwt(DecodedJWT jwt, JwkResult jwkResult) {
}
Jwk jwk = jwkResult.jwk();
try {
return JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null)).build().verify(jwt);
Verification verification = JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null));
if (audience != null) {
verification.withAudience(audience);
}
return verification.build().verify(jwt);
} catch (JwkException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,57 @@ public void testExtractClaims_29() throws JwkException {
});
}

@Test
public void testExtractClaims_30() throws JwkException {
settings.put("audience", "dial");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("aud", "dial").withClaim("roles", List.of("manager")).sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("manager"), claims.userRoles());
});
}

@Test
public void testExtractClaims_31() throws JwkException {
settings.put("audience", "dial");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("aud", "wrong_aud").withClaim("roles", List.of("manager")).sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertFalse(res.succeeded());
ExtractedClaims claims = res.result();
assertNull(claims);
});
}

@Test
public void testExtractClaims_FromUserInfo_01() {
settings.remove("jwksUrl");
Expand Down

0 comments on commit 91d0c02

Please sign in to comment.