Skip to content

Commit

Permalink
Test same server reroute
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 27, 2024
1 parent a789992 commit f65f6c8
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 12 deletions.
4 changes: 2 additions & 2 deletions samples/ChatSample/ChatSample.Net70/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
builder.Services.AddSignalR(o =>
{
o.MaximumParallelInvocationsPerClient = 2;
}).AddAzureSignalR("<connection-string>");
}).AddAzureSignalR("Endpoint=http://localhost:8080;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGH;Version=1.0;");

var app = builder.Build();

Expand All @@ -20,7 +20,7 @@
app.UseHsts();
}

app.UseHttpsRedirection();
// app.UseHttpsRedirection();
app.UseStaticFiles();

app.UseRouting();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

namespace Microsoft.Azure.SignalR;

#nullable enable

internal interface IServiceMessageHandler
{
Task HandlePingAsync(PingMessage pingMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
</PropertyGroup>

<PropertyGroup>
<TargetFrameworks>netstandard2.0;net6.0;net8.0;net462;</TargetFrameworks>
<TargetFrameworks>netstandard2.0;net6.0;net8.0;</TargetFrameworks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

Expand All @@ -18,7 +18,7 @@
<FrameworkReference Include="Microsoft.AspNetCore.App" />
</ItemGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.0'">
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.SignalR" Version="$(MicrosoftAspNetCoreSignalRPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Localization" Version="$(MicrosoftAspNetCoreLocalizationPackageVersion)" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ internal class StrongServiceConnectionContainer : ServiceConnectionContainerBase
{
private readonly int? _maxConnectionCount;

public StrongServiceConnectionContainer(
IServiceConnectionFactory serviceConnectionFactory,
int fixedConnectionCount,
int? maxConnectionCount,
HubServiceEndpoint endpoint,
ILogger logger) : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger)
public StrongServiceConnectionContainer(IServiceConnectionFactory serviceConnectionFactory,
int fixedConnectionCount,
int? maxConnectionCount,
HubServiceEndpoint endpoint,
ILogger logger) : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger)
{
_maxConnectionCount = maxConnectionCount.HasValue ? (maxConnectionCount.Value > fixedConnectionCount ? maxConnectionCount.Value : fixedConnectionCount) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ internal class WeakServiceConnectionContainer : ServiceConnectionContainerBase
protected override ServiceConnectionType InitialConnectionType => ServiceConnectionType.Weak;

public WeakServiceConnectionContainer(IServiceConnectionFactory serviceConnectionFactory,
int fixedConnectionCount, HubServiceEndpoint endpoint, ILogger logger)
int fixedConnectionCount,
HubServiceEndpoint endpoint,
ILogger logger)
: base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger)
{
}
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public bool TryParseMessage(ref ReadOnlySequence<byte> input, out ServiceMessage

var messageType = reader.ReadInt32();

if (messageType == ServiceProtocolConstants.ConnectionFlowControlMessageType)
{
Console.WriteLine("123");
}

switch (messageType)
{
case ServiceProtocolConstants.HandshakeRequestType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ public ServiceConnectionFactory(
_hubProtocolResolver = hubProtocolResolver;
}

public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type)
public virtual IServiceConnection Create(HubServiceEndpoint endpoint,
IServiceMessageHandler serviceMessageHandler,
AckHandler ackHandler,
ServiceConnectionType type)
{
return new ServiceConnection(
_serviceProtocol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

<ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.Http.Connections.Client" Version="8.0.8" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Client.Core" Version="8.0.8" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.TestHost" Version="$(MicrosoftAspNetCoreTestHostPackageVersion)" />
<PackageReference Include="Moq" Version="$(MoqPackageVersion)" />
Expand Down
125 changes: 125 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/SameServerRerouteTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Azure.SignalR.Tests;

#nullable enable

public class SameServerRerouteTests : VerifiableLoggedTest
{
public SameServerRerouteTests(ITestOutputHelper output) : base(output)
{
}

[Fact]
public async Task Test()
{
var hub = "foo";
var connectionString = "Endpoint=http://localhost:8080;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGH;Version=1.0;";

using var provider = StartVerifiableLog(out var loggerFactory);

var nameProvider = new DefaultServerNameProvider();
var logger = loggerFactory.CreateLogger<DefaultHubProtocolResolver>();
var hubProtocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() }, logger);

var handler = new EndlessConnectionHandler<FooHub>(loggerFactory);
ConnectionDelegate connectionDelegate = handler.OnConnectedAsync;

var factory = new ServiceConnectionFactory(new ServiceProtocol(),
new ClientConnectionManager(),
new ConnectionFactory(nameProvider, loggerFactory),
loggerFactory,
connectionDelegate,
new ClientConnectionFactory(loggerFactory),
nameProvider,
new DefaultServiceEventHandler(loggerFactory),
new DummyClientInvocationManager(),
hubProtocolResolver);

var serviceEndpoint = new ServiceEndpoint(connectionString);
var serviceEndpointProvider = new ServiceEndpointProvider(serviceEndpoint, new ServiceOptions());
var hubServiceEndpoint = new HubServiceEndpoint(hub, serviceEndpointProvider, serviceEndpoint);

var messageHandler = new StrongServiceConnectionContainer(factory, 1, 1, hubServiceEndpoint, logger);
var ackHandler = new AckHandler();

var connection1 = factory.Create(hubServiceEndpoint, messageHandler, ackHandler, ServiceConnectionType.Default);
var connection2 = factory.Create(hubServiceEndpoint, messageHandler, ackHandler, ServiceConnectionType.Default);

_ = connection1.StartAsync();
_ = connection2.StartAsync();

await Task.Delay(1000);

var audience = $"http://localhost/client/?hub={hub}";
var clientPath = $"http://localhost:8080/client/?hub={hub}";
var accessKey = ConnectionStringParser.Parse(connectionString).AccessKey;
var accessToken = await accessKey.GenerateAccessTokenAsync(audience, Array.Empty<Claim>(), TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);

var connectionOptions = new HttpConnectionOptions();
connectionOptions.Headers.Add("Authorization", $"Bearer {accessToken}");
var connectionFactory = new HttpConnectionFactory(Options.Create(connectionOptions), loggerFactory);
var endpoint = new UriEndPoint(new Uri(clientPath));
var serviceProvider = new ServiceCollection().BuildServiceProvider();

var hubConnection = new HubConnection(connectionFactory, new JsonHubProtocol(), endpoint, serviceProvider, loggerFactory);
_ = hubConnection.StartAsync();

while (true)
{
await hubConnection.SendAsync("foo");
await Task.Delay(1000);
}
}

private class FooHub : Hub
{
}

private sealed class EndlessConnectionHandler<THub> : ConnectionHandler where THub : Hub
{
private readonly HubLifetimeManager<THub> _hubLifetimeManager;

private readonly ILoggerFactory _loggerFactory;

public EndlessConnectionHandler(ILoggerFactory loggerFactory)
{
_loggerFactory = loggerFactory;
var logger = loggerFactory.CreateLogger<DefaultHubLifetimeManager<THub>>();
_hubLifetimeManager = new DefaultHubLifetimeManager<THub>(logger);
}

public override async Task OnConnectedAsync(ConnectionContext connection)
{
var hubConnectionContext = new EndlessHubConnectionContext(connection, new HubConnectionContextOptions(), _loggerFactory);
await _hubLifetimeManager.OnConnectedAsync(hubConnectionContext);
}
}

private sealed class EndlessHubConnectionContext : HubConnectionContext
{
public EndlessHubConnectionContext(ConnectionContext connectionContext,
HubConnectionContextOptions contextOptions,
ILoggerFactory loggerFactory) : base(connectionContext, contextOptions, loggerFactory)
{
}
}
}

0 comments on commit f65f6c8

Please sign in to comment.