diff --git a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs index 5a3f71ecde0d..427eb4cf0c67 100644 --- a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs @@ -162,20 +162,35 @@ private async ValueTask DispatchEventAsync(CloudEvent evt) public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request) { - var connection = _workersByConnection[request.RequestId]; - connection.AddSupportedType(request.Type); - _supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection); - _agentsToEventsMap.TryAdd(request.Type, new HashSet(request.Events)); + try + { + var connection = _workersByConnection[request.RequestId]; + connection.AddSupportedType(request.Type); + _supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection); + _agentsToEventsMap.TryAdd(request.Type, new HashSet(request.Events)); - await _gatewayRegistry.RegisterAgentType(request.Type, _reference).ConfigureAwait(true); - var response = new RegisterAgentTypeResponse { - }; - return response; + await _gatewayRegistry.RegisterAgentType(request.Type, _reference).ConfigureAwait(true); + return new RegisterAgentTypeResponse + { + Success = true, + RequestId = request.RequestId + }; + } + catch (Exception ex) + { + return new RegisterAgentTypeResponse + { + Success = false, + RequestId = request.RequestId, + Error = ex.Message + }; + } } + // TODO: consider adding this back for backwards compatibility //private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg) //{ - + //} private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, RpcRequest request, Func> func) diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/HandleInterfaceTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/HandleInterfaceTests.cs index 8cf74a74bf3b..862a42e44ed9 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/HandleInterfaceTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/HandleInterfaceTests.cs @@ -30,7 +30,7 @@ public void Handlers_Supports_Cancellation() } [Fact] - public void Can_build_handlers_lookup() + public void Can_Build_Handlers_Lookup() { var source = typeof(TestAgent); var handlers = source.GetHandlersLookupTable(); diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/Helpers/ClusterFixtureCollection.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/Helpers/ClusterFixtureCollection.cs index 5f643410d8be..5d80c25a11be 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/Helpers/ClusterFixtureCollection.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/Helpers/ClusterFixtureCollection.cs @@ -6,7 +6,7 @@ namespace Microsoft.AutoGen.Core.Tests.Helpers; [CollectionDefinition(Name)] -public sealed class ClusterFixtureCollection : ICollectionFixture +public sealed class CoreFixtureCollection : ICollectionFixture { - public const string Name = nameof(ClusterFixtureCollection); + public const string Name = nameof(CoreFixtureCollection); } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/InMemoryAgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/InMemoryAgentTests.cs index 7e71342fc71d..564b60b3ec05 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/InMemoryAgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/InMemoryAgentTests.cs @@ -11,7 +11,7 @@ namespace Microsoft.AutoGen.Core.Tests; -[Collection(ClusterFixtureCollection.Name)] +[Collection(CoreFixtureCollection.Name)] public partial class InMemoryAgentTests(InMemoryAgentRuntimeFixture fixture) { private readonly InMemoryAgentRuntimeFixture _fixture = fixture; diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs index a303d175ef7a..76110d1c29a6 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs @@ -90,21 +90,39 @@ public async Task Test_Message_Goes_To_Right_Worker() } - private RegisterAgentTypeRequest CreateRegistrationRequest(EventTypes eventTypes, Type type, string requestId) + [Fact] + public async Task Test_RegisterAgent_Should_Succeed() { - var registration = new RegisterAgentTypeRequest - { - Type = type.Name, - RequestId = requestId - }; - registration.Events.AddRange(eventTypes.GetEventsForAgent(type)?.ToList()); + var logger = Mock.Of>(); + var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); + var service = new GrpcGatewayService(gateway); + using var client = new TestGrpcClient(); - return registration; + var assembly = typeof(PBAgent).Assembly; + var eventTypes = ReflectionHelper.GetAgentsMetadata(assembly); + + await service.OpenChannel(client.RequestStream, client.ResponseStream, client.CallContext); + var responseMessage = await client.ReadNext(); + + var connectionId = responseMessage!.Response.RequestId; + + var response = await service.RegisterAgent(CreateRegistrationRequest(eventTypes, typeof(PBAgent), connectionId), client.CallContext); + response.Success.Should().BeTrue(); } - private string GetFullName(Type type) + [Fact] + public async Task Test_RegisterAgent_Should_Fail_For_Wrong_ConnectionId() { - return ReflectionHelper.GetMessageDescriptor(type)!.FullName; + var logger = Mock.Of>(); + var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); + var service = new GrpcGatewayService(gateway); + using var client = new TestGrpcClient(); + + var assembly = typeof(PBAgent).Assembly; + var eventTypes = ReflectionHelper.GetAgentsMetadata(assembly); + + var response = await service.RegisterAgent(CreateRegistrationRequest(eventTypes, typeof(PBAgent), "faulty_connection_id"), client.CallContext); + response.Success.Should().BeFalse(); } [Fact] @@ -132,4 +150,21 @@ public async Task Test_GetState() response.Should().NotBeNull(); } + + private RegisterAgentTypeRequest CreateRegistrationRequest(EventTypes eventTypes, Type type, string requestId) + { + var registration = new RegisterAgentTypeRequest + { + Type = type.Name, + RequestId = requestId + }; + registration.Events.AddRange(eventTypes.GetEventsForAgent(type)?.ToList()); + + return registration; + } + + private string GetFullName(Type type) + { + return ReflectionHelper.GetMessageDescriptor(type)!.FullName; + } } diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs index f839b318b85f..09017c79af17 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs @@ -4,7 +4,7 @@ using Microsoft.AutoGen.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; -internal class TestGrpcClient: IDisposable +internal sealed class TestGrpcClient: IDisposable { public TestAsyncStreamReader RequestStream { get; } public TestServerStreamWriter ResponseStream { get; }