diff --git a/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/HeaderPreservingHeaderTests.cs b/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/HeaderPreservingHeaderTests.cs index 757797e..cc60905 100644 --- a/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/HeaderPreservingHeaderTests.cs +++ b/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/HeaderPreservingHeaderTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. Please refer to the LICENSE file at the root of this project for details using System; +using System.Net; using DragonFruit.Common.Data.Handlers; using DragonFruit.Common.Data.Tests.Handlers.AuthPreservingHandler.Objects; using NUnit.Framework; @@ -33,7 +34,7 @@ public void TestHeaderPreservation() redirectClient.Authorization = $"{auth.Type} {auth.AccessToken}"; // user lookups by username = 301. without our HeaderPreservingHandler we'd get a 401 - redirectClient.Perform(new OrbitTestUserRequest()); + Assert.AreEqual(redirectClient.Perform(new OrbitTestUserRequest()).StatusCode, HttpStatusCode.OK); } } } diff --git a/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/Objects/AuthRequest.cs b/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/Objects/AuthRequest.cs index 4977dd9..af399ce 100644 --- a/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/Objects/AuthRequest.cs +++ b/DragonFruit.Common.Data.Tests/Handlers/AuthPreservingHandler/Objects/AuthRequest.cs @@ -22,6 +22,9 @@ public class AuthRequest : ApiRequest [FormParameter("client_secret")] public string ClientSecret => GetEnvironmentVar("orbit_client_secret"); + [FormParameter("scope")] + public string Scopes => "public"; + private static string GetEnvironmentVar(string var) { var envVar = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) diff --git a/DragonFruit.Common.Data/Handlers/HeaderPreservingRedirectHandler.cs b/DragonFruit.Common.Data/Handlers/HeaderPreservingRedirectHandler.cs index 318a22d..262b27c 100644 --- a/DragonFruit.Common.Data/Handlers/HeaderPreservingRedirectHandler.cs +++ b/DragonFruit.Common.Data/Handlers/HeaderPreservingRedirectHandler.cs @@ -4,6 +4,7 @@ using System; using System.Net; using System.Net.Http; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -35,6 +36,20 @@ public HeaderPreservingRedirectHandler(HttpMessageHandler innerHandler) } } +#if NET5_0 + protected override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken) + { + var response = base.Send(request, cancellationToken); + + if (IsRedirect(response.StatusCode)) + { + response = base.Send(CopyRequest(response), cancellationToken); + } + + return response; + } +#endif + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { var tcs = new TaskCompletionSource(); @@ -56,28 +71,9 @@ protected override Task SendAsync(HttpRequestMessage reques }; } - if (response.StatusCode == HttpStatusCode.MovedPermanently - || response.StatusCode == HttpStatusCode.Moved - || response.StatusCode == HttpStatusCode.Redirect - || response.StatusCode == HttpStatusCode.Found - || response.StatusCode == HttpStatusCode.SeeOther - || response.StatusCode == HttpStatusCode.RedirectKeepVerb - || response.StatusCode == HttpStatusCode.TemporaryRedirect - || (int)response.StatusCode == 308) + if (IsRedirect(response.StatusCode)) { - var newRequest = CopyRequest(response); - - if (response.StatusCode == HttpStatusCode.Redirect - || response.StatusCode == HttpStatusCode.Found - || response.StatusCode == HttpStatusCode.SeeOther) - { - newRequest.Content = null; - newRequest.Method = HttpMethod.Get; - } - - newRequest.RequestUri = response.Headers.Location; - - base.SendAsync(newRequest, cancellationToken) + base.SendAsync(CopyRequest(response), cancellationToken) .ContinueWith(t2 => tcs.SetResult(t2.Result), cancellationToken); } else @@ -91,7 +87,7 @@ protected override Task SendAsync(HttpRequestMessage reques private static HttpRequestMessage CopyRequest(HttpResponseMessage response) { - var oldRequest = response.RequestMessage; + var oldRequest = response.RequestMessage ?? throw new NullReferenceException("Request Message not found"); var newRequest = new HttpRequestMessage(oldRequest.Method, oldRequest.RequestUri); @@ -116,7 +112,7 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response) #if NET5_0 foreach (var (key, value) in oldRequest.Options) { - if (value == null || !(value is string s)) + if (!(value is string s)) { continue; } @@ -130,9 +126,7 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response) } #endif - if (response.StatusCode == HttpStatusCode.Redirect - || response.StatusCode == HttpStatusCode.Found - || response.StatusCode == HttpStatusCode.SeeOther) + if (AlterMethod(response.StatusCode)) { newRequest.Content = null; newRequest.Method = HttpMethod.Get; @@ -144,5 +138,29 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response) return newRequest; } + + #region Switches + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsRedirect(HttpStatusCode code) => code switch + { + HttpStatusCode.Moved => true, + HttpStatusCode.Redirect => true, + HttpStatusCode.SeeOther => true, + HttpStatusCode.RedirectKeepVerb => true, + + _ => false + }; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AlterMethod(HttpStatusCode code) => code switch + { + HttpStatusCode.Redirect => true, + HttpStatusCode.SeeOther => true, + + _ => false + }; + + #endregion } }