diff --git a/src/Microsoft.IdentityModel.Tokens/Validation/Validators.Audience.cs b/src/Microsoft.IdentityModel.Tokens/Validation/Validators.Audience.cs index 17d2ec5ba2..7b0be6598b 100644 --- a/src/Microsoft.IdentityModel.Tokens/Validation/Validators.Audience.cs +++ b/src/Microsoft.IdentityModel.Tokens/Validation/Validators.Audience.cs @@ -5,8 +5,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Threading.Tasks; -using System.Threading; using Microsoft.IdentityModel.Abstractions; using Microsoft.IdentityModel.Logging; @@ -81,17 +79,7 @@ public static void ValidateAudience(IEnumerable audiences, SecurityToken new SecurityTokenInvalidAudienceException(LogHelper.FormatInvariant(LogMessages.IDX10206)) { InvalidAudience = Utility.SerializeAsSingleCommaDelimitedString(audiences) }); - // create enumeration of all valid audiences from validationParameters - IEnumerable validationParametersAudiences; - - if (validationParameters.ValidAudiences == null) - validationParametersAudiences = new[] { validationParameters.ValidAudience }; - else if (string.IsNullOrWhiteSpace(validationParameters.ValidAudience)) - validationParametersAudiences = validationParameters.ValidAudiences; - else - validationParametersAudiences = validationParameters.ValidAudiences.Concat(new[] { validationParameters.ValidAudience }); - - if (AudienceIsValid(audiences, validationParameters, validationParametersAudiences)) + if (AudienceIsValid(audiences, validationParameters)) return; SecurityTokenInvalidAudienceException ex = new SecurityTokenInvalidAudienceException( @@ -174,17 +162,7 @@ internal static AudienceValidationResult ValidateAudience(IEnumerable au typeof(SecurityTokenInvalidAudienceException), new StackFrame(true))); - // create enumeration of all valid audiences from validationParameters - IEnumerable validationParametersAudiences; - - if (validationParameters.ValidAudiences == null) - validationParametersAudiences = new[] { validationParameters.ValidAudience }; - else if (string.IsNullOrWhiteSpace(validationParameters.ValidAudience)) - validationParametersAudiences = validationParameters.ValidAudiences; - else - validationParametersAudiences = validationParameters.ValidAudiences.Concat(new[] { validationParameters.ValidAudience }); - - string? validAudience = AudienceIsValidReturning(audiences, validationParameters, validationParametersAudiences); + string? validAudience = AudienceIsValidReturning(audiences, validationParameters); if (validAudience != null) { return new AudienceValidationResult(validAudience); @@ -203,45 +181,72 @@ internal static AudienceValidationResult ValidateAudience(IEnumerable au new StackFrame(true))); } - private static bool AudienceIsValid(IEnumerable audiences, TokenValidationParameters validationParameters, IEnumerable validationParametersAudiences) + private static bool AudienceIsValid(IEnumerable audiences, TokenValidationParameters validationParameters) + { + return AudienceIsValidReturning(audiences, validationParameters) != null; + } + + private static string? AudienceIsValidReturning(IEnumerable audiences, TokenValidationParameters validationParameters) { - return AudienceIsValidReturning(audiences, validationParameters, validationParametersAudiences) != null; + string? validAudience = null; + if (!string.IsNullOrWhiteSpace(validationParameters.ValidAudience)) + validAudience = AudiencesMatchSingle(audiences, validationParameters.ValidAudience, validationParameters.IgnoreTrailingSlashWhenValidatingAudience); + + if (validAudience == null && validationParameters.ValidAudiences != null) + validAudience = AudiencesMatchList(audiences, validationParameters.ValidAudiences, validationParameters.IgnoreTrailingSlashWhenValidatingAudience); + + return validAudience; + } + + private static string? AudiencesMatchSingle(IEnumerable audiences, string validAudience, bool ignoreTrailingSlashWhenValidatingAudience) + { + foreach (string tokenAudience in audiences) + { + if (string.IsNullOrWhiteSpace(tokenAudience)) + continue; + + if (AudiencesMatch(ignoreTrailingSlashWhenValidatingAudience, tokenAudience, validAudience)) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogMessages.IDX10234, LogHelper.MarkAsNonPII(tokenAudience)); + + return tokenAudience; + } + } + + return null; } - private static string? AudienceIsValidReturning(IEnumerable audiences, TokenValidationParameters validationParameters, IEnumerable validationParametersAudiences) + private static string? AudiencesMatchList(IEnumerable audiences, IEnumerable validAudiences, bool ignoreTrailingSlashWhenValidatingAudience) { foreach (string tokenAudience in audiences) { if (string.IsNullOrWhiteSpace(tokenAudience)) continue; - foreach (string validAudience in validationParametersAudiences) + foreach (string validAudience in validAudiences) { - if (string.IsNullOrWhiteSpace(validAudience)) + if (string.IsNullOrEmpty(validAudience)) continue; - if (AudiencesMatch(validationParameters, tokenAudience, validAudience)) + if (AudiencesMatch(ignoreTrailingSlashWhenValidatingAudience, tokenAudience, validAudience)) { if (LogHelper.IsEnabled(EventLogLevel.Informational)) LogHelper.LogInformation(LogMessages.IDX10234, LogHelper.MarkAsNonPII(tokenAudience)); - return validAudience; + return tokenAudience; } } } return null; } -#nullable disable - private static bool AudiencesMatch(TokenValidationParameters validationParameters, string tokenAudience, string validAudience) + private static bool AudiencesMatch(bool ignoreTrailingSlashWhenValidatingAudience, string tokenAudience, string validAudience) { if (validAudience.Length == tokenAudience.Length) - { - if (string.Equals(validAudience, tokenAudience)) - return true; - } - else if (validationParameters.IgnoreTrailingSlashWhenValidatingAudience && AudiencesMatchIgnoringTrailingSlash(tokenAudience, validAudience)) + return string.Equals(validAudience, tokenAudience); + else if (ignoreTrailingSlashWhenValidatingAudience && AudiencesMatchIgnoringTrailingSlash(tokenAudience, validAudience)) return true; return false; diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/Validation/AudienceValidationResultTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/Validation/AudienceValidationResultTests.cs index 20615327e2..aa6a45b784 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/Validation/AudienceValidationResultTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/Validation/AudienceValidationResultTests.cs @@ -381,7 +381,7 @@ public static TheoryData ValidateAudienceTheoryDat Audiences = audiences1, TestId = "ValidAudienceWithSlashTVPTrue", ValidationParameters = new TokenValidationParameters{ ValidAudience = audience1 + "/" }, - AudienceValidationResult = new AudienceValidationResult(audience1Slash) + AudienceValidationResult = new AudienceValidationResult(audience1) }, new AudienceValidationTheoryData { @@ -407,7 +407,7 @@ public static TheoryData ValidateAudienceTheoryDat Audiences = audiences1, TestId = "ValidAudiencesWithSlashTVPTrue", ValidationParameters = new TokenValidationParameters{ ValidAudiences = audiences1WithSlash }, - AudienceValidationResult = new AudienceValidationResult(audience1Slash) + AudienceValidationResult = new AudienceValidationResult(audience1) }, new AudienceValidationTheoryData { @@ -490,7 +490,7 @@ public static TheoryData ValidateAudienceTheoryDat Audiences = audiences1WithSlash, TestId = "TokenAudienceWithSlashTVPTrue", ValidationParameters = new TokenValidationParameters{ ValidAudience = audience1 }, - AudienceValidationResult = new AudienceValidationResult(audience1) + AudienceValidationResult = new AudienceValidationResult(audience1Slash) }, new AudienceValidationTheoryData { @@ -535,7 +535,7 @@ public static TheoryData ValidateAudienceTheoryDat Audiences = audiences1WithSlash, TestId = "TokenAudiencesWithSlashTVPTrue", ValidationParameters = new TokenValidationParameters{ ValidAudience = audience1 }, - AudienceValidationResult = new AudienceValidationResult(audience1) + AudienceValidationResult = new AudienceValidationResult(audience1Slash) }, new AudienceValidationTheoryData {