Skip to content

Commit

Permalink
grpc testing WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kostapetan committed Dec 6, 2024
1 parent 851f8c7 commit e47ad00
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,14 @@ public bool CheckIfTypeHandles(Type type, string eventName)
}
return null;
}

public HashSet<string>? GetEventsForAgent(Type agent)
{
if (_eventsMap.TryGetValue(agent, out var events))
{
return events;
}
return null;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public static EventTypes GetAgentsMetadata(params Assembly[] assemblies)

var eventsMap = assemblies
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(Agent)) && !type.IsAbstract)
.Where(type => IsSubclassOfGeneric(type, typeof(Agent)) && !type.IsAbstract)
.Select(t => (t, t.GetInterfaces()
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "").ToHashSet()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ public sealed class GrpcGateway : BackgroundService, IGateway
// The agents supported by each worker process.
private readonly ConcurrentDictionary<string, List<GrpcWorkerConnection>> _supportedAgentTypes = [];
internal readonly ConcurrentDictionary<GrpcWorkerConnection, GrpcWorkerConnection> _workers = new();
internal readonly ConcurrentDictionary<string, GrpcWorkerConnection> _workersByConnection = new();
private readonly ConcurrentDictionary<string, HashSet<string>> _agentsToEventsMap = new();

// The mapping from agent id to worker process.
private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new();
// RPC
private readonly ConcurrentDictionary<(GrpcWorkerConnection, string), TaskCompletionSource<RpcResponse>> _pendingRequests = new();

public int WorkersCount => _workers.Count;

// InMemory Message Queue

public GrpcGateway(IClusterClient clusterClient, ILogger<GrpcGateway> logger)
Expand Down Expand Up @@ -59,18 +63,22 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
}
}

internal async Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
internal async Task<string> ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
_logger.LogInformation("Received new connection from {Peer}.", context.Peer);
var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
var connectionId = Guid.NewGuid().ToString();
_workers[workerProcess] = workerProcess;
_workersByConnection[connectionId] = workerProcess;

var completion = new TaskCompletionSource<Task>();
var _ = Task.Run(() =>
{
completion.SetResult(workerProcess.Connect());
});

await completion.Task;
return connectionId;
}

public async ValueTask BroadcastEvent(CloudEvent evt)
Expand Down Expand Up @@ -152,15 +160,24 @@ private async ValueTask DispatchEventAsync(CloudEvent evt)
*/
}

private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg)
public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request)
{
connection.AddSupportedType(msg.Type);
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(msg.Type, new HashSet<string>(msg.Events));
var connection = _workersByConnection[request.RequestId];
connection.AddSupportedType(request.Type);
_supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(request.Type, new HashSet<string>(request.Events));

await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true);
await _gatewayRegistry.RegisterAgentType(request.Type, _reference).ConfigureAwait(true);
var response = new RegisterAgentTypeResponse {
};
return response;
}

//private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg)
//{

//}

private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, RpcRequest request, Func<RpcRequest, Task<RpcResponse>> func)
{
try
Expand Down Expand Up @@ -236,10 +253,7 @@ public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request, Cancellati
return response;
}

public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request)
{
return new RegisterAgentTypeResponse();
}


public ValueTask<RpcResponse> InvokeRequest(RpcRequest request)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public override async Task OpenChannel(IAsyncStreamReader<Message> requestStream
{
try
{
await Gateway.ConnectToWorkerProcess(requestStream, responseStream, context).ConfigureAwait(true);
var connectionId = await Gateway.ConnectToWorkerProcess(requestStream, responseStream, context).ConfigureAwait(true);
// Generate a response with the connectionId, to be used in subsequent RPC requests to the service
await responseStream.WriteAsync(new Message { Response = new RpcResponse { RequestId = connectionId } });
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

using FluentAssertions;
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Core;
using Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc;
using Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans;
using Microsoft.Extensions.Logging;
using Moq;
using Tests.Events;

namespace Microsoft.AutoGen.Runtime.Grpc.Tests;
[Collection(ClusterCollection.Name)]
Expand All @@ -26,30 +28,63 @@ public async Task Test_OpenChannel()
var gateway = new GrpcGateway(_fixture.Cluster.Client, logger);
var service = new GrpcGatewayService(gateway);
var callContext = TestServerCallContext.Create();
var requestStream = new TestAsyncStreamReader<Message>(callContext);
var responseStream = new TestServerStreamWriter<Message>(callContext);

await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(PBAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext);
await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(GMAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext);
using var requestStream = new TestAsyncStreamReader<Message>(callContext);
using var responseStream = new TestServerStreamWriter<Message>(callContext);

gateway.WorkersCount.Should().Be(0);
await service.OpenChannel(requestStream, responseStream, callContext);
gateway.WorkersCount.Should().Be(1);
}

var bgEvent = new CloudEvent
{
Id = "1",
Source = "gh/repo/1",
Type = "test",
[Fact]
public async Task Test_Message_Exchange_Through_Gateway()
{
var logger = Mock.Of<ILogger<GrpcGateway>>();
var gateway = new GrpcGateway(_fixture.Cluster.Client, logger);
var service = new GrpcGatewayService(gateway);
var callContext = TestServerCallContext.Create();

};
using var requestStream = new TestAsyncStreamReader<Message>(callContext);
using var responseStream = new TestServerStreamWriter<Message>(callContext);

var assembly = typeof(PBAgent).Assembly;
var eventTypes = ReflectionHelper.GetAgentsMetadata(assembly);

requestStream.AddMessage(new Message { CloudEvent = bgEvent });
await service.OpenChannel(requestStream, responseStream, callContext);
var responseMessage = await responseStream.ReadNextAsync();

requestStream.Complete();
await service.RegisterAgent(CreateRegistrationRequest(eventTypes, typeof(PBAgent), responseMessage!.Response.RequestId), callContext);
await service.RegisterAgent(CreateRegistrationRequest(eventTypes, typeof(GMAgent), responseMessage!.Response.RequestId), callContext);

responseStream.Complete();
var inputEvent = new NewMessageReceived { Message = "Hello" }.ToCloudEvent("gh-gh-gh");

var responseMessage = await responseStream.ReadNextAsync();
responseMessage.Should().NotBeNull();
requestStream.AddMessage(new Message { CloudEvent = inputEvent });
var newMessageReceived = await responseStream.ReadNextAsync();
newMessageReceived!.CloudEvent.Type.Should().Be(GetFullName(typeof(NewMessageReceived)));

// Simulate an agent, by publishing a new message in the request stream

var outputEvent = await responseStream.ReadNextAsync();
outputEvent!.CloudEvent.Type.Should().Be(GetFullName(typeof(Hello)));

}

private RegisterAgentTypeRequest CreateRegistrationRequest(EventTypes eventTypes, Type type, string requestId)
{
var registration = new RegisterAgentTypeRequest
{
Type = type.Name,
RequestId = requestId
};
registration.Events.AddRange(eventTypes.GetEventsForAgent(type)?.ToList());

return registration;
}

private string GetFullName(Type type)
{
return ReflectionHelper.GetMessageDescriptor(type)!.FullName;
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc;

public class TestAsyncStreamReader<T> : IAsyncStreamReader<T> where T : class
public class TestAsyncStreamReader<T> : IDisposable, IAsyncStreamReader<T>
where T : class
{
private readonly Channel<T> _channel;
private readonly ServerCallContext _serverCallContext;
Expand Down Expand Up @@ -60,4 +61,9 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
return false;
}
}

public void Dispose()
{
Complete();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc;

public class TestServerStreamWriter<T> : IServerStreamWriter<T> where T : class
public class TestServerStreamWriter<T> : IDisposable, IServerStreamWriter<T> where T : class
{
private readonly ServerCallContext _serverCallContext;
private readonly Channel<T> _channel;
Expand Down Expand Up @@ -78,4 +78,9 @@ public Task WriteAsync(T message)
{
return WriteAsync(message, CancellationToken.None);
}

public void Dispose()
{
Complete();
}
}

0 comments on commit e47ad00

Please sign in to comment.