Skip to content

Commit

Permalink
review from daniel
Browse files Browse the repository at this point in the history
  • Loading branch information
strehle committed Jan 14, 2025
1 parent 32cf038 commit 01d3a45
Show file tree
Hide file tree
Showing 19 changed files with 694 additions and 827 deletions.
2 changes: 1 addition & 1 deletion docs/UAA-Client-Authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ This should allow a continuous trust between a UAA to UAA communication, e.g. us

The new parameter for federated Credentials in UAA clients is (Work in progress parameter):

* fed_creds
* jwt_creds

### tls_client_auth (Planned Feature)
Not yet defined a release date.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,23 @@ public String getIssuer() {
return this.issuer;
}

public void setIssuer(final String issuer) {
public void setIssuer(String issuer) {
this.issuer = issuer;
}

public String getSubject() {
return this.subject;
}

public void setSubject(final String subject) {
public void setSubject(String subject) {
this.subject = subject;
}

public String getAudience() {
return this.audience;
}

public void setAudience(final String audience) {
public void setAudience(String audience) {
this.audience = audience;
}

Expand All @@ -130,6 +130,6 @@ public boolean isFederated() {

@JsonIgnore
public ClientJwtCredential getFederation() {
return ClientJwtCredential.builder().issuer(issuer).subject(subject).audience(audience).build();
return new ClientJwtCredential(subject, issuer, audience);
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
package org.cloudfoundry.identity.uaa.oauth.client;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Builder;
import lombok.Data;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.springframework.util.StringUtils;

import java.util.List;

@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
@JsonIgnoreProperties(ignoreUnknown = true)
@Builder(toBuilder = true)
@Data
public class ClientJwtCredential {

Expand All @@ -25,21 +23,20 @@ public class ClientJwtCredential {
@JsonProperty("aud")
private String audience;

public ClientJwtCredential() {
}

public ClientJwtCredential(String subject, String issuer, String audience) {
@JsonCreator
public ClientJwtCredential(@JsonProperty("sub") String subject, @JsonProperty("iss") String issuer, @JsonProperty("aud") String audience) {
this.subject = subject;
this.issuer = issuer;
this.audience = audience;
if (!isValid()) {
throw new IllegalArgumentException("Invalid federated jwt credentials");
}
}

@JsonIgnore
public boolean isValid() {
private boolean isValid() {
return StringUtils.hasText(subject) && StringUtils.hasText(issuer);
}

@JsonIgnore
public static List<ClientJwtCredential> parse(String clientJwtCredentials) {
try {
return JsonUtils.readValue(clientJwtCredentials, new TypeReference<>() {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,40 @@

import java.util.List;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;


class ClientJwtCredentialTest {

@Test
void parse() {
assertDoesNotThrow(() -> ClientJwtCredential.parse("[{\"iss\":\"http://localhost:8080/uaa\",\"sub\":\"client_with_jwks_trust\"}]"));
List<ClientJwtCredential> federationList = ClientJwtCredential.parse("[{\"iss\":\"http://localhost:8080/uaa\",\"sub\":\"client_with_jwks_trust\"},{\"iss\":\"http://localhost:8080/uaa\"}]");
assertThat(ClientJwtCredential.parse("[{\"iss\":\"http://localhost:8080/uaa\",\"sub\":\"client_with_jwks_trust\"}]")).isInstanceOf(List.class);
List<ClientJwtCredential> federationList = ClientJwtCredential.parse("[{\"iss\":\"http://localhost:8080/uaa\",\"sub\":\"client_with_jwks_trust\"},{\"iss\":\"http://localhost:8080/uaa\", \"sub\":\"another_client\"}]");
assertThat(federationList).hasSize(2);
}

@Test
void testConstructor() {
ClientJwtCredential jwtCredential = new ClientJwtCredential("subject", "issuer", "audience");
assertEquals("subject", jwtCredential.getSubject());
assertEquals("issuer", jwtCredential.getIssuer());
assertEquals("audience", jwtCredential.getAudience());
assertTrue(jwtCredential.isValid());
jwtCredential = new ClientJwtCredential();
assertFalse(jwtCredential.isValid());
assertThat(jwtCredential.getSubject()).isEqualTo("subject");
assertThat(jwtCredential.getIssuer()).isEqualTo("issuer");
assertThat(jwtCredential.getAudience()).isEqualTo("audience");
}

@Test
void testDeserializer() {
assertFalse(ClientJwtCredential.parse("[{\"iss\":\"issuer\"}]").iterator().next().isValid());
void testDeserializerConstructorException() {
assertThatThrownBy(() -> ClientJwtCredential.parse("[{\"iss\":\"http://localhost:8080/uaa\",\"sub\":\"client_with_jwks_trust\"},{\"iss\":\"http://localhost:8080/uaa\"}]"))
.isInstanceOf(IllegalArgumentException.class).hasMessage("Client jwt configuration cannot be parsed");
assertThatThrownBy(() -> ClientJwtCredential.parse("[{\"sub\":\"client_with_jwks_trust\"}]"))
.isInstanceOf(IllegalArgumentException.class).hasMessage("Client jwt configuration cannot be parsed");
assertThatThrownBy(() -> ClientJwtCredential.parse("[{\"unknown\":\"client_with_jwks_trust\"}]"))
.isInstanceOf(IllegalArgumentException.class).hasMessage("Client jwt configuration cannot be parsed");
}

@Test
void testDeserializerException() {
assertThrows(IllegalArgumentException.class, () -> ClientJwtCredential.parse("[\"iss\":\"issuer\"]"));
void testDeserializerParserException() {
assertThatThrownBy(() -> ClientJwtCredential.parse("[\"iss\":\"issuer\"]"))
.isInstanceOf(IllegalArgumentException.class).hasMessage("Client jwt configuration cannot be parsed");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha

try {
if (clientId == null) {
clientId = Optional.ofNullable(loginInfo.get(CLIENT_ASSERTION)).map(JwtClientAuthentication::getClientId).orElse(null);
clientId = Optional.ofNullable(loginInfo.get(CLIENT_ASSERTION)).map(JwtClientAuthentication::getClientIdOidcAssertion).orElse(null);
}
wrapClientCredentialLogin(req, res, loginInfo, clientId);
} catch (AuthenticationException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public ClientJwtConfiguration(final String jwksUri, final JsonWebKeySet<JsonWebK
}
}

public ClientJwtConfiguration(final List<ClientJwtCredential> clientJwtCredentials) {
public ClientJwtConfiguration(List<ClientJwtCredential> clientJwtCredentials) {
this.setClientJwtCredentials(clientJwtCredentials);
}

Expand Down Expand Up @@ -96,13 +96,8 @@ public void addJwtCredentials(final List<ClientJwtCredential> additionalCredenti
}

private static void validateClientJwtCredentials(List<ClientJwtCredential> additionalCredentials, HashMap<String, ClientJwtCredential> clientJwtCredentialHashMap) {
additionalCredentials.forEach(jwtEntry -> {
if (jwtEntry.isValid()) {
clientJwtCredentialHashMap.putIfAbsent(jwtEntry.getSubject() + jwtEntry.getIssuer(), jwtEntry);
} else {
throw new InvalidClientDetailsException("Invalid federated jwt credentials");
}
});
additionalCredentials.forEach(jwtEntry ->
clientJwtCredentialHashMap.putIfAbsent(jwtEntry.getSubject() + jwtEntry.getIssuer(), jwtEntry));
if (clientJwtCredentialHashMap.isEmpty() || clientJwtCredentialHashMap.size() > MAX_KEY_SIZE) {
throw new InvalidClientDetailsException("Invalid private_key_jwt: federated jwt credentials exceeds the maximum of keys. max: + " + MAX_KEY_SIZE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,20 @@ public boolean validateClientJwt(Map<String, String[]> requestParameters, Client
String clientAssertion = UaaStringUtils.getSafeParameterValue(requestParameters.get(CLIENT_ASSERTION));
JWT clientJWT = parseClientAssertion(clientAssertion);
JWTClaimsSet clientClaims = parseClientJWT(clientJWT);
if (!clientId.equals(getClientId(clientClaims))) {
// check if we found trust for private_key_jwt with RFC 7523
// Check if OIDC complaint client_assertion: client_id (from request) == sub (client_assertion) == iss (client_assertion)
if (clientId.equals(getClientIdOidcAssertion(clientClaims))) {
// Validate token according to private_key_jwt with OIDC
return clientId.equals(validateClientJWToken(clientJWT, oidcMetadataFetcher == null ? new JWKSet() :
JWKSet.parse(oidcMetadataFetcher.fetchWebKeySet(clientJwtConfiguration).getKeySetMap()),
clientId, clientId, keyInfoService.getTokenEndpointUrl()).getSubject());
} else {
// Check if we found trust for private_key_jwt with RFC 7523. We allow client_id (from request) != sub (client_assertion)
ClientJwtCredential jwtFederation = getClientJwtFederation(clientJwtConfiguration, clientClaims);
if (jwtFederation != null) {
return validateFederatedClientWT(clientJWT, clientClaims, jwtFederation);
}
throw new BadCredentialsException("Wrong client_assertion");
}
// validate token according to private_key_jwt with OIDC
return clientId.equals(validateClientJWToken(clientJWT, oidcMetadataFetcher == null ? new JWKSet() :
JWKSet.parse(oidcMetadataFetcher.fetchWebKeySet(clientJwtConfiguration).getKeySetMap()),
clientId, clientId, keyInfoService.getTokenEndpointUrl()).getSubject());
} catch (ParseException | URISyntaxException | InvalidTokenException | OidcMetadataFetchingException e) {
throw new BadCredentialsException("Bad client_assertion", e);
}
Expand All @@ -176,7 +178,7 @@ private static ClientJwtCredential getClientJwtFederation(ClientJwtConfiguration
if (clientJwtConfiguration.getClientJwtCredentials() == null) {
return null;
}
return clientJwtConfiguration.getClientJwtCredentials().stream().filter(e -> e.isValid() &&
return clientJwtConfiguration.getClientJwtCredentials().stream().filter(e ->
e.getSubject().equals(clientClaims.getSubject()) &&
e.getIssuer().equals(clientClaims.getIssuer()) &&
isAudienceSupported(e.getAudience(), clientClaims.getAudience())).findFirst().orElse(null);
Expand All @@ -194,7 +196,7 @@ private static JWTClaimsSet parseClientJWT(JWT clientJWT) throws ParseException
return clientJWT != null ? clientJWT.getJWTClaimsSet() : null;
}

private static String getClientId(JWTClaimsSet clientToken) {
private static String getClientIdOidcAssertion(JWTClaimsSet clientToken) {
if (clientToken != null && clientToken.getSubject() != null && clientToken.getIssuer() != null &&
clientToken.getSubject().equals(clientToken.getIssuer()) && clientToken.getAudience() != null && clientToken.getJWTID() != null &&
clientToken.getExpirationTime() != null) {
Expand All @@ -204,9 +206,9 @@ private static String getClientId(JWTClaimsSet clientToken) {
return null;
}

public static String getClientId(String clientAssertion) {
public static String getClientIdOidcAssertion(String clientAssertion) {
try {
return getClientId(parseClientJWT(parseClientAssertion(clientAssertion)));
return getClientIdOidcAssertion(parseClientJWT(parseClientAssertion(clientAssertion)));
} catch (ParseException e) {
throw new BadCredentialsException("Bad client_assertion", e);
}
Expand All @@ -215,8 +217,8 @@ public static String getClientId(String clientAssertion) {
private boolean validateFederatedClientWT(JWT jwtAssertion, JWTClaimsSet clientClaims, ClientJwtCredential jwtFederation) throws OidcMetadataFetchingException, ParseException {
try {
JWKSet jwkSet = retrieveJwkSet(clientClaims);
String expectedAud = jwtFederation.getAudience() != null ? jwtFederation.getAudience() : keyInfoService.getTokenEndpointUrl();
return validateClientJWToken(jwtAssertion, jwkSet, clientClaims.getSubject(), clientClaims.getIssuer(), expectedAud) != null;
String expectedAud = Optional.ofNullable(jwtFederation.getAudience()).orElse(keyInfoService.getTokenEndpointUrl());
return validateClientJWToken(jwtAssertion, jwkSet, jwtFederation.getSubject(), jwtFederation.getIssuer(), expectedAud) != null;
} catch (MalformedURLException | IllegalArgumentException | URISyntaxException e) {
return false;
}
Expand Down Expand Up @@ -262,11 +264,11 @@ private JWKSet retrieveJwkSet(JWTClaimsSet clientClaims) throws MalformedURLExce
}

private JWTClaimsSet validateClientJWToken(JWT jwtAssertion, JWKSet jwkSet, String expectedSub, String expectIss, String expectedAud) {
if (ObjectUtils.isEmpty(jwkSet) || jwkSet.isEmpty()) {
if (Optional.ofNullable(jwkSet).orElse(new JWKSet()).isEmpty()) {
throw new BadCredentialsException("Bad empty jwk_set");
}
Algorithm algorithm = jwtAssertion.getHeader().getAlgorithm();
if (algorithm == null || NOT_SUPPORTED_ALGORITHMS.contains(algorithm) || !(algorithm instanceof JWSAlgorithm)) {
if (!(algorithm instanceof JWSAlgorithm) || NOT_SUPPORTED_ALGORITHMS.contains(algorithm)) {
throw new BadCredentialsException("Bad client_assertion algorithm");
}
JWKSource<SecurityContext> keySource = new ImmutableJWKSet<>(jwkSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import static org.mockito.Mockito.when;

@WithDatabaseContext
class JdbcUnsuccessfulLoginCountingAuditServiceTests {
class JdbcUnsuccessfulLoginCountingAuditServiceTests {

private JdbcUnsuccessfulLoginCountingAuditService auditService;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
package org.cloudfoundry.identity.uaa.audit;


import org.cloudfoundry.identity.uaa.logging.LogSanitizerUtil;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;

import static org.assertj.core.api.Assertions.assertThat;
import static org.cloudfoundry.identity.uaa.audit.AuditEventType.PasswordChangeFailure;
import static org.cloudfoundry.identity.uaa.audit.AuditEventType.UserAuthenticationSuccess;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

Expand All @@ -37,7 +33,7 @@ void log_format_whenThereIsAnAuthType() {
ArgumentCaptor<String> stringCaptor = ArgumentCaptor.forClass(String.class);
verify(mockLogger).info(stringCaptor.capture());
String logMessage = stringCaptor.getValue();
assertThat(logMessage, is("PasswordChangeFailure ('theData'): principal=thePrincipalId, origin=[theOrigin], identityZoneId=[theZoneId], authenticationType=[theAuthType], detailedDescription=[theDescription]"));
assertThat(logMessage).isEqualTo("PasswordChangeFailure ('theData'): principal=thePrincipalId, origin=[theOrigin], identityZoneId=[theZoneId], authenticationType=[theAuthType], detailedDescription=[theDescription]");
}

@Test
Expand All @@ -49,7 +45,7 @@ void log_format_whenAuthTypeIsNull() {
ArgumentCaptor<String> stringCaptor = ArgumentCaptor.forClass(String.class);
verify(mockLogger).info(stringCaptor.capture());
String logMessage = stringCaptor.getValue();
assertThat(logMessage, is("PasswordChangeFailure ('theData'): principal=thePrincipalId, origin=[theOrigin], identityZoneId=[theZoneId], detailedDescription=[theDescription]"));
assertThat(logMessage).isEqualTo("PasswordChangeFailure ('theData'): principal=thePrincipalId, origin=[theOrigin], identityZoneId=[theZoneId], detailedDescription=[theDescription]");
}

@Test
Expand All @@ -60,10 +56,10 @@ void log_sanitizesMaliciousInput() {

ArgumentCaptor<String> stringCaptor = ArgumentCaptor.forClass(String.class);
verify(mockLogger).info(stringCaptor.capture());
assertFalse(stringCaptor.getValue().contains("\r"));
assertFalse(stringCaptor.getValue().contains("\n"));
assertFalse(stringCaptor.getValue().contains("\t"));
assertTrue(stringCaptor.getValue().contains(LogSanitizerUtil.SANITIZED_FLAG));
assertThat(stringCaptor.getValue()).doesNotContain("\r")
.doesNotContain("\n")
.doesNotContain("\t")
.contains(LogSanitizerUtil.SANITIZED_FLAG);
}

@Test
Expand All @@ -74,6 +70,6 @@ void log_doesNotModifyNonMaliciousInput() {

ArgumentCaptor<String> stringCaptor = ArgumentCaptor.forClass(String.class);
verify(mockLogger).info(stringCaptor.capture());
assertFalse(stringCaptor.getValue().contains(LogSanitizerUtil.SANITIZED_FLAG));
assertThat(stringCaptor.getValue()).doesNotContain(LogSanitizerUtil.SANITIZED_FLAG);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -305,20 +305,6 @@ void simpleAddClientWithClientJwtCredendial() {
assertThat(clientDetails.getClientJwtConfig()).isNotNull();
}

@Test
void simpleAddClientWithClientJwtCredendial() throws Exception {
Map<String, Object> map = new HashMap<>();
map.put("id", "foo-jwks");
map.put("secret", "bar");
map.put("scope", "openid");
map.put("authorized-grant-types", GRANT_TYPE_AUTHORIZATION_CODE);
map.put("authorities", "uaa.none");
map.put("redirect-uri", "http://localhost/callback");
map.put("jwt_creds", "[{\"iss\":\"http://localhost:8080/uaa/oauth/token\",\"sub\":\"foo-jwt\"}]");
UaaClientDetails clientDetails = (UaaClientDetails) doSimpleTest(map, clientAdminBootstrap, multitenantJdbcClientDetailsService, clients);
assertNotNull(clientDetails.getClientJwtConfig());
}

@Test
void clientMetadata_getsBootstrapped() {
Map<String, Object> map = new HashMap<>();
Expand Down
Loading

0 comments on commit 01d3a45

Please sign in to comment.