Skip to content

Commit

Permalink
Test same server reroute
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 27, 2024
1 parent a789992 commit 58d4bf9
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

<ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.Http.Connections.Client" Version="8.0.8" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Client.Core" Version="8.0.8" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.TestHost" Version="$(MicrosoftAspNetCoreTestHostPackageVersion)" />
<PackageReference Include="Moq" Version="$(MoqPackageVersion)" />
Expand Down
125 changes: 125 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/SameServerRerouteTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Azure.SignalR.Tests;

#nullable enable

public class SameServerRerouteTests : VerifiableLoggedTest
{
public SameServerRerouteTests(ITestOutputHelper output) : base(output)
{
}

[Fact]
public async Task Test()
{
var hub = "foo";
var connectionString = "Endpoint=http://localhost:8080;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGH;Version=1.0;";

using var provider = StartVerifiableLog(out var loggerFactory);

var nameProvider = new DefaultServerNameProvider();
var logger = loggerFactory.CreateLogger<DefaultHubProtocolResolver>();
var hubProtocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() }, logger);

var handler = new EndlessConnectionHandler<FooHub>(loggerFactory);
ConnectionDelegate connectionDelegate = handler.OnConnectedAsync;

var factory = new ServiceConnectionFactory(new ServiceProtocol(),
new ClientConnectionManager(),
new ConnectionFactory(nameProvider, loggerFactory),
loggerFactory,
connectionDelegate,
new ClientConnectionFactory(loggerFactory),
nameProvider,
new DefaultServiceEventHandler(loggerFactory),
new DummyClientInvocationManager(),
hubProtocolResolver);

var serviceEndpoint = new ServiceEndpoint(connectionString);
var serviceEndpointProvider = new ServiceEndpointProvider(serviceEndpoint, new ServiceOptions());
var hubServiceEndpoint = new HubServiceEndpoint(hub, serviceEndpointProvider, serviceEndpoint);

var messageHandler = new StrongServiceConnectionContainer(factory, 1, 1, hubServiceEndpoint, logger);
var ackHandler = new AckHandler();

var connection1 = factory.Create(hubServiceEndpoint, messageHandler, ackHandler, ServiceConnectionType.Default);
var connection2 = factory.Create(hubServiceEndpoint, messageHandler, ackHandler, ServiceConnectionType.Default);

_ = connection1.StartAsync();
_ = connection2.StartAsync();

await Task.Delay(1000);

var audience = $"http://localhost/client/?hub={hub}";
var clientPath = $"http://localhost:8080/client/?hub={hub}";
var accessKey = ConnectionStringParser.Parse(connectionString).AccessKey;
var accessToken = await accessKey.GenerateAccessTokenAsync(audience, Array.Empty<Claim>(), TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);

var connectionOptions = new HttpConnectionOptions();
connectionOptions.Headers.Add("Authorization", $"Bearer {accessToken}");
var connectionFactory = new HttpConnectionFactory(Options.Create(connectionOptions), loggerFactory);
var endpoint = new UriEndPoint(new Uri(clientPath));
var serviceProvider = new ServiceCollection().BuildServiceProvider();

var hubConnection = new HubConnection(connectionFactory, new JsonHubProtocol(), endpoint, serviceProvider, loggerFactory);
_ = hubConnection.StartAsync();

while (true)
{
await hubConnection.SendAsync("foo");
await Task.Delay(1000);
}
}

private class FooHub : Hub
{
}

private sealed class EndlessConnectionHandler<THub> : ConnectionHandler where THub : Hub
{
private readonly HubLifetimeManager<THub> _hubLifetimeManager;

private readonly ILoggerFactory _loggerFactory;

public EndlessConnectionHandler(ILoggerFactory loggerFactory)
{
_loggerFactory = loggerFactory;
var logger = loggerFactory.CreateLogger<DefaultHubLifetimeManager<THub>>();
_hubLifetimeManager = new DefaultHubLifetimeManager<THub>(logger);
}

public override async Task OnConnectedAsync(ConnectionContext connection)
{
var hubConnectionContext = new EndlessHubConnectionContext(connection, new HubConnectionContextOptions(), _loggerFactory);
await _hubLifetimeManager.OnConnectedAsync(hubConnectionContext);
}
}

private sealed class EndlessHubConnectionContext : HubConnectionContext
{
public EndlessHubConnectionContext(ConnectionContext connectionContext,
HubConnectionContextOptions contextOptions,
ILoggerFactory loggerFactory) : base(connectionContext, contextOptions, loggerFactory)
{
}
}
}

0 comments on commit 58d4bf9

Please sign in to comment.