Skip to content

Commit

Permalink
Redesigned Audiences to use IList
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshLozensky committed May 19, 2024
1 parent 1cead86 commit 17fb1a9
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class BenchmarkUtils

public const string Audience = "http://www.contoso.com/protected";

public readonly static IEnumerable<string> Audiences = new string[] {
public readonly static IList<string> Audiences = new string[] {
"http://www.contoso.com/protected",
"http://www.contoso.com/protected1",
"http://www.contoso.com/protected2",
Expand Down
23 changes: 12 additions & 11 deletions benchmark/Microsoft.IdentityModel.Benchmarks/CreateTokenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,30 @@ public void Setup()
_tokenDescriptor = new SecurityTokenDescriptor
{
Claims = BenchmarkUtils.Claims,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256
};

_tokenDescriptorSingleAudienceUsingAudiencesMember = new SecurityTokenDescriptor
{
Claims = BenchmarkUtils.ClaimsNoAudience,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256
};

_tokenDescriptorMultipleAudiencesMemberOnly = new SecurityTokenDescriptor
{
Claims = BenchmarkUtils.ClaimsNoAudience,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256,
Audiences = BenchmarkUtils.Audiences,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256
};

_tokenDescriptorMultipleAudiencesMemberAndClaims = new SecurityTokenDescriptor
{
Claims = BenchmarkUtils.ClaimsMultipleAudiences,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256,
Audiences = BenchmarkUtils.Audiences,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256
};

_tokenDescriptorSingleAudienceUsingAudiencesMember = new SecurityTokenDescriptor
{
Claims = BenchmarkUtils.ClaimsNoAudience,
SigningCredentials = BenchmarkUtils.SigningCredentialsRsaSha256,
Audiences = new string[] { BenchmarkUtils.Audience }
};
_tokenDescriptorSingleAudienceUsingAudiencesMember.Audiences.Add(BenchmarkUtils.Audience);
_tokenDescriptorMultipleAudiencesMemberOnly.AddAudiences(BenchmarkUtils.Audiences);
_tokenDescriptorMultipleAudiencesMemberAndClaims.AddAudiences(BenchmarkUtils.Audiences);
}

[Benchmark]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,12 +706,15 @@ internal static void WriteJwsPayload(

writer.WriteStartObject();

// TODO at next major version (8.0) use only Audiences as SecurityTokenDescriptor.Audience will be removed.
if (!tokenDescriptor.Audiences.IsNullOrEmpty())
{
writer.WritePropertyName(JwtPayloadUtf8Bytes.Aud);
writer.WriteStartArray();
foreach (string audience in tokenDescriptor.Audiences){ writer.WriteStringValue(audience);}
foreach (string audience in tokenDescriptor.Audiences) { writer.WriteStringValue(audience); }

if (!string.IsNullOrEmpty(tokenDescriptor.Audience))
writer.WriteStringValue(tokenDescriptor.Audience);

writer.WriteEndArray();
audienceSet = true;
}
Expand Down Expand Up @@ -763,7 +766,6 @@ internal static void WriteJwsPayload(
audienceChecked = true;
if (audienceSet)
{
// TODO at next major version Audience will be removed at that time remove this local variable.
string descriptorMemberName = null;
if (!tokenDescriptor.Audiences.IsNullOrEmpty())
descriptorMemberName = nameof(tokenDescriptor.Audiences);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,6 @@ public SamlAudienceRestrictionCondition(IEnumerable<Uri> audiences)
Audiences = (audiences == null) ? throw LogArgumentNullException(nameof(audiences)) : new List<Uri>(audiences);
}

/// <summary>
/// Creates an instance of <see cref="SamlAudienceRestrictionCondition"/>.
/// </summary>
/// <param name="audiences">An <see cref="IEnumerable{String}"/> containing the audiences for a <see cref="SamlAssertion"/>.</param>
internal SamlAudienceRestrictionCondition(IEnumerable<string> audiences)
{
if (audiences == null)
throw LogArgumentNullException(nameof(audiences));

List<Uri> audienceUris = new();
foreach (var aud in audiences) { audienceUris.Add(new Uri(aud)); }
Audiences = audienceUris;
}

/// <summary>
/// Gets the <see cref="ICollection{stringT}"/> of audiences for a <see cref="SamlAssertion"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ protected virtual IEnumerable<ClaimsIdentity> CreateClaimsIdentities(SamlSecurit
/// <exception cref="ArgumentNullException">if <paramref name="tokenDescriptor"/> is null.</exception>
protected virtual SamlConditions CreateConditions(SecurityTokenDescriptor tokenDescriptor)
{
if (null == tokenDescriptor)
if (tokenDescriptor == null)
throw LogArgumentNullException(nameof(tokenDescriptor));

var conditions = new SamlConditions();
Expand All @@ -368,15 +368,30 @@ protected virtual SamlConditions CreateConditions(SecurityTokenDescriptor tokenD
else if (SetDefaultTimesOnTokenCreation)
conditions.NotOnOrAfter = DateTime.UtcNow + TimeSpan.FromMinutes(TokenLifetimeInMinutes);

// TODO at next major version (8.0) use only Audiences as SecurityTokenDescriptor.Audience will be removed.
if (!tokenDescriptor.Audiences.IsNullOrEmpty())
conditions.Conditions.Add(new SamlAudienceRestrictionCondition(tokenDescriptor.Audiences));
else if (!string.IsNullOrEmpty(tokenDescriptor.Audience))
conditions.Conditions.Add(new SamlAudienceRestrictionCondition(new Uri(tokenDescriptor.Audience)));
var uriList = createUriList(tokenDescriptor);
if (!uriList.IsNullOrEmpty())
conditions.Conditions.Add(new SamlAudienceRestrictionCondition(uriList));

return conditions;
}

private static List<Uri> createUriList(SecurityTokenDescriptor tokenDescriptor)
{
var uriList = new List<Uri>();
if (!tokenDescriptor.Audiences.IsNullOrEmpty())
{
foreach (var audience in tokenDescriptor.Audiences.Where(aud => !string.IsNullOrWhiteSpace(aud)))
uriList.Add(new Uri(audience));

if(!string.IsNullOrWhiteSpace(tokenDescriptor.Audience) && !tokenDescriptor.Audiences.Contains(tokenDescriptor.Audience))
uriList.Add(new Uri(tokenDescriptor.Audience));
}
else if (!string.IsNullOrWhiteSpace(tokenDescriptor.Audience))
uriList.Add(new Uri(tokenDescriptor.Audience));

return uriList;
}

/// <summary>
/// Generates an enumeration of SamlStatements from a SecurityTokenDescriptor.
/// Only SamlAttributeStatements and SamlAuthenticationStatements are generated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,12 @@ protected virtual Saml2Conditions CreateConditions(SecurityTokenDescriptor token
else if (SetDefaultTimesOnTokenCreation)
conditions.NotOnOrAfter = DateTime.UtcNow + TimeSpan.FromMinutes(TokenLifetimeInMinutes);

// TODO: At next major version remove use of the Audience property.
if (!tokenDescriptor.Audiences.IsNullOrEmpty())
conditions.AudienceRestrictions.Add(new Saml2AudienceRestriction(tokenDescriptor.Audiences));
else if (!string.IsNullOrEmpty(tokenDescriptor.Audience))
conditions.AudienceRestrictions.Add(new Saml2AudienceRestriction(tokenDescriptor.Audience));
var audienceRestriction = new Saml2AudienceRestriction(tokenDescriptor.Audiences.Where(aud => !string.IsNullOrWhiteSpace(aud)));

if (!string.IsNullOrWhiteSpace(tokenDescriptor.Audience) && !tokenDescriptor.Audiences.Contains(tokenDescriptor.Audience))
audienceRestriction.Audiences.Add(tokenDescriptor.Audience);

conditions.AudienceRestrictions.Add(audienceRestriction);

return conditions;
}
Expand Down
52 changes: 14 additions & 38 deletions src/Microsoft.IdentityModel.Tokens/SecurityTokenDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Threading;

namespace Microsoft.IdentityModel.Tokens
{
Expand All @@ -13,55 +14,30 @@ namespace Microsoft.IdentityModel.Tokens
/// </summary>
public class SecurityTokenDescriptor
{
// TODO: At next major version (8.0), remove Audience and logic for combining with Audiences.
private HashSet<string> _audiences;
private List<string> _audiences;

/// <summary>
/// Gets or sets the value of the 'audience' claim. Will be deprecated in favor of <see cref="Audiences"/> in the next
/// major version (8.x).
/// Gets or sets the value of the {"": audience} claim. Will be combined with <see cref="Audiences"/> and any "Aud" claims in
/// <see cref="Claims"/> or <see cref="Subject"/> when creating a token.
/// </summary>
public string Audience { get; set; }

/// <summary>
/// Gets or sets one or more audiences to include in the token's 'Aud' claim. Automatically removes duplicates and empty,
/// null, or whitespace-only strings. Does not use a threadsafe collection.
/// Gets the list audiences to include in the token's 'Aud' claim. Will be combined with <see cref="Audiences"/> and any
/// "Aud" claims in <see cref="Claims"/> or <see cref="Subject"/> when creating a token.
/// </summary>
public IEnumerable<string> Audiences {
get
{
// If Audiences isn't set, return null since this will be the behavior once Audience is removed.
if (_audiences.IsNullOrEmpty())
return null;

// If both Audience and Audiences are set, return the union of the two.
else if (!string.IsNullOrEmpty(Audience))
return _audiences.Union([Audience]);

// If only Audiences is set, return it
else
return _audiences;
}
set
{
_audiences = new HashSet<string>(value);
_audiences.RemoveWhere(string.IsNullOrWhiteSpace);
}
}
public IList<string> Audiences => _audiences ?? Interlocked.CompareExchange(ref _audiences, [], null) ?? _audiences;

/// <summary>
/// Adds an audience to the <see cref="Audiences"/> collection. Won't add duplicate, null, empty, or whitespace-only strings.
/// Enables adding multiple audiences to the Audiences member at once.
/// </summary>
/// <param name="audience">An audience to be added to the Aud claim</param>
public void AddAudience(string audience)
/// <param name="auds">List of strings with each representing an audience to add to the 'Aud' claim</param>
public void AddAudiences(IList<string> auds)
{
if (string.IsNullOrWhiteSpace(audience))
return;

if (_audiences == null)
_audiences = new HashSet<string>();

_audiences.Add(audience);
}
_ = Audiences;
if (auds != null)
_audiences.AddRange(auds);
}

/// <summary>
/// Defines the compression algorithm that will be used to compress the JWT token payload.
Expand Down
Loading

0 comments on commit 17fb1a9

Please sign in to comment.