Skip to content

Commit

Permalink
Refactor code into ClaimsIdentityFactory.
Browse files Browse the repository at this point in the history
  • Loading branch information
pmaytak committed Jul 12, 2024
1 parent b19d572 commit 30b17fe
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ private ClaimsIdentity CreateClaimsIdentityWithMapping(JsonWebToken jwtToken, To
{
_ = validationParameters ?? throw LogHelper.LogArgumentNullException(nameof(validationParameters));

ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
bool wasMapped = _inboundClaimTypeMap.TryGetValue(jwtClaim.Type, out string claimType);
Expand Down Expand Up @@ -281,7 +281,7 @@ private ClaimsIdentity CreateClaimsIdentityPrivate(JsonWebToken jwtToken, TokenV
{
_ = validationParameters ?? throw LogHelper.LogArgumentNullException(nameof(validationParameters));

ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
string claimType = jwtClaim.Type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ protected virtual IEnumerable<ClaimsIdentity> ProcessStatements(SamlSecurityToke

if (!identityDict.TryGetValue(statement.Subject, out ClaimsIdentity identity))
{
identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(samlToken, validationParameters, issuer);
identity = ClaimsIdentityFactory.Create(samlToken, validationParameters, issuer);
ProcessSubject(statement.Subject, identity, issuer);
identityDict.Add(statement.Subject, identity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ protected virtual ClaimsIdentity CreateClaimsIdentity(Saml2SecurityToken samlTok
actualIssuer = ClaimsIdentity.DefaultIssuer;
}

var identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(samlToken, validationParameters, issuer);
var identity = ClaimsIdentityFactory.Create(samlToken, validationParameters, issuer);

ProcessSubject(samlToken.Assertion.Subject, identity, actualIssuer);
ProcessStatements(samlToken.Assertion.Statements, identity, actualIssuer);
Expand Down
36 changes: 36 additions & 0 deletions src/Microsoft.IdentityModel.Tokens/ClaimsIdentityFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

namespace Microsoft.IdentityModel.Tokens
{
/// <summary>
/// Facilitates the creation of <see cref="ClaimsIdentity"/> and <see cref="CaseSensitiveClaimsIdentity"/> instances based on the <see cref="AppContextSwitches.UseClaimsIdentityTypeSwitch"/>.
/// </summary>
internal static class ClaimsIdentityFactory
{
internal static ClaimsIdentity Create(IEnumerable<Claim> claims)
Expand All @@ -23,5 +26,38 @@ internal static ClaimsIdentity Create(IEnumerable<Claim> claims, string authenti

return new CaseSensitiveClaimsIdentity(claims, authenticationType);
}

internal static ClaimsIdentity Create(string authenticationType, string nameType, string roleType, SecurityToken securityToken)
{
if (AppContextSwitches.UseClaimsIdentityType())
return new ClaimsIdentity(authenticationType: authenticationType, nameType: nameType, roleType: roleType);

return new CaseSensitiveClaimsIdentity(authenticationType: authenticationType, nameType: nameType, roleType: roleType)
{
SecurityToken = securityToken,
};
}

internal static ClaimsIdentity Create(SecurityToken securityToken, TokenValidationParameters validationParameters, string issuer)
{
ClaimsIdentity claimsIdentity = validationParameters.CreateClaimsIdentity(securityToken, issuer);

if (claimsIdentity is not CaseSensitiveClaimsIdentity && !AppContextSwitches.UseClaimsIdentityType())
{
claimsIdentity = new CaseSensitiveClaimsIdentity(claimsIdentity);
}

return claimsIdentity;
}

internal static ClaimsIdentity Create(TokenHandler tokenHandler, SecurityToken securityToken, TokenValidationParameters validationParameters, string issuer)
{
ClaimsIdentity claimsIdentity = tokenHandler.CreateClaimsIdentityInternal(securityToken, validationParameters, issuer);

if (claimsIdentity is not CaseSensitiveClaimsIdentity && !AppContextSwitches.UseClaimsIdentityType())
claimsIdentity = new CaseSensitiveClaimsIdentity(claimsIdentity);

return claimsIdentity;
}
}
}
12 changes: 0 additions & 12 deletions src/Microsoft.IdentityModel.Tokens/TokenHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,6 @@ internal virtual ClaimsIdentity CreateClaimsIdentityInternal(SecurityToken secur
MarkAsNonPII("internal virtual ClaimsIdentity CreateClaimsIdentityInternal(SecurityToken securityToken, TokenValidationParameters tokenValidationParameters, string issuer)"),
MarkAsNonPII(GetType().FullName))));
}

internal static ClaimsIdentity CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(SecurityToken securityToken, TokenValidationParameters validationParameters, string issuer)
{
ClaimsIdentity identity = validationParameters.CreateClaimsIdentity(securityToken, issuer);

if (identity is not CaseSensitiveClaimsIdentity && !AppContextSwitches.UseClaimsIdentityType())
{
identity = new CaseSensitiveClaimsIdentity(identity);
}

return identity;
}
#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,7 @@ public virtual ClaimsIdentity CreateClaimsIdentity(SecurityToken securityToken,
if (LogHelper.IsEnabled(EventLogLevel.Informational))
LogHelper.LogInformation(LogMessages.IDX10245, securityToken);

if (AppContextSwitches.UseClaimsIdentityType())
return new ClaimsIdentity(authenticationType: AuthenticationType ?? DefaultAuthenticationType, nameType: nameClaimType ?? ClaimsIdentity.DefaultNameClaimType, roleType: roleClaimType ?? ClaimsIdentity.DefaultRoleClaimType);
else
return new CaseSensitiveClaimsIdentity(authenticationType: AuthenticationType ?? DefaultAuthenticationType, nameType: nameClaimType ?? ClaimsIdentity.DefaultNameClaimType, roleType: roleClaimType ?? ClaimsIdentity.DefaultRoleClaimType)
{
SecurityToken = securityToken,
};
return ClaimsIdentityFactory.Create(authenticationType: AuthenticationType ?? DefaultAuthenticationType, nameType: nameClaimType ?? ClaimsIdentity.DefaultNameClaimType, roleType: roleClaimType ?? ClaimsIdentity.DefaultRoleClaimType, securityToken);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ internal ClaimsIdentity ClaimsIdentityNoLocking

if (_validationParameters != null && SecurityToken != null && _tokenHandler != null && Issuer != null)
{
_claimsIdentity = _tokenHandler.CreateClaimsIdentityInternal(SecurityToken, _validationParameters, Issuer);
if (_claimsIdentity is not CaseSensitiveClaimsIdentity && !AppContextSwitches.UseClaimsIdentityType())
_claimsIdentity = new CaseSensitiveClaimsIdentity(_claimsIdentity);
_claimsIdentity = ClaimsIdentityFactory.Create(_tokenHandler, SecurityToken, _validationParameters, Issuer);
}

_claimsIdentityInitialized = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ private ClaimsPrincipal ValidateTokenPayload(JwtSecurityToken jwtToken, TokenVal

Validators.ValidateTokenType(jwtToken.Header.Typ, jwtToken, validationParameters);

var identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
var identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
if (validationParameters.SaveSigninToken)
identity.BootstrapContext = jwtToken.RawData;

Expand All @@ -1200,7 +1200,7 @@ private ClaimsPrincipal ValidateTokenPayload(JwtSecurityToken jwtToken, TokenVal

private static ClaimsPrincipal CreateClaimsPrincipalFromToken(JwtSecurityToken jwtToken, string issuer, TokenValidationParameters validationParameters)
{
var identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
var identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
if (validationParameters.SaveSigninToken)
identity.BootstrapContext = jwtToken.RawData;

Expand Down Expand Up @@ -1527,7 +1527,7 @@ protected virtual ClaimsIdentity CreateClaimsIdentity(JwtSecurityToken jwtToken,

private ClaimsIdentity CreateClaimsIdentityWithMapping(JwtSecurityToken jwtToken, string actualIssuer, TokenValidationParameters validationParameters)
{
ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, actualIssuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, actualIssuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
if (_inboundClaimFilter.Contains(jwtClaim.Type))
Expand Down Expand Up @@ -1573,7 +1573,7 @@ private ClaimsIdentity CreateClaimsIdentityWithMapping(JwtSecurityToken jwtToken

private ClaimsIdentity CreateClaimsIdentityWithoutMapping(JwtSecurityToken jwtToken, string actualIssuer, TokenValidationParameters validationParameters)
{
ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, actualIssuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, actualIssuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
if (_inboundClaimFilter.Contains(jwtClaim.Type))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void CreateCaseSensitveClaimsIdentity_FromTokenValidationParameters_Retur
tokenValidationParameters.NameClaimType = "custom-name";
tokenValidationParameters.RoleClaimType = "custom-role";

var actualClaimsIdentity = TokenHandler.CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jsonWebToken, tokenValidationParameters, Default.Issuer);
var actualClaimsIdentity = ClaimsIdentityFactory.Create(jsonWebToken, tokenValidationParameters, Default.Issuer);

Assert.IsType<CaseSensitiveClaimsIdentity>(actualClaimsIdentity);
Assert.NotNull(((CaseSensitiveClaimsIdentity)actualClaimsIdentity).SecurityToken);
Expand All @@ -86,7 +86,7 @@ public void CreateCaseSensitveClaimsIdentity_FromDerivedTokenValidationParameter
tokenValidationParameters.NameClaimType = "custom-name";
tokenValidationParameters.RoleClaimType = "custom-role";

var actualClaimsIdentity = TokenHandler.CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jsonWebToken, tokenValidationParameters, Default.Issuer);
var actualClaimsIdentity = ClaimsIdentityFactory.Create(jsonWebToken, tokenValidationParameters, Default.Issuer);

Assert.IsType<CaseSensitiveClaimsIdentity>(actualClaimsIdentity);
Assert.Equal(tokenValidationParameters.AuthenticationType, actualClaimsIdentity.AuthenticationType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ private ClaimsIdentity CreateClaimsIdentityWithMapping(JsonWebToken jwtToken, To
{
_ = validationParameters ?? throw LogHelper.LogArgumentNullException(nameof(validationParameters));

ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
bool wasMapped = _inboundClaimTypeMap.TryGetValue(jwtClaim.Type, out string claimType);
Expand Down Expand Up @@ -875,7 +875,7 @@ private ClaimsIdentity CreateClaimsIdentityPrivate(JsonWebToken jwtToken, TokenV
{
_ = validationParameters ?? throw LogHelper.LogArgumentNullException(nameof(validationParameters));

ClaimsIdentity identity = CreateCaseSensitiveClaimsIdentityFromTokenValidationParameters(jwtToken, validationParameters, issuer);
ClaimsIdentity identity = ClaimsIdentityFactory.Create(jwtToken, validationParameters, issuer);
foreach (Claim jwtClaim in jwtToken.Claims)
{
string claimType = jwtClaim.Type;
Expand Down

0 comments on commit 30b17fe

Please sign in to comment.