Skip to content

Commit

Permalink
prep for Runtime tests WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kostapetan committed Dec 4, 2024
1 parent 23142d7 commit 280f213
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 98 deletions.
1 change: 1 addition & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
<PackageVersion Include="Microsoft.Orleans.Server" Version="9.0.1" />
<PackageVersion Include="Microsoft.Orleans.Streaming" Version="9.0.1" />
<PackageVersion Include="Microsoft.Orleans.Streaming.EventHubs" Version="9.0.1" />
<PackageVersion Include="Microsoft.Orleans.TestingHost" Version="9.0.1" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="1.29.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Agents.Core" Version="$(MicrosoftSemanticKernelExperimentalVersion)" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.AzureOpenAI" Version="1.29.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,14 @@ private async ValueTask RegisterAgentTypeAsync(string type, Type agentType, Canc

// TODO: Reimplement registration as RPC call
_logger.LogInformation($"{cancellationToken.ToString}"); // TODO: remove this
//await WriteChannelAsync(new Message
//{
// RegisterAgentTypeRequest = new RegisterAgentTypeRequest
// {
// Type = type,
// RequestId = Guid.NewGuid().ToString(),
// //TopicTypes = { topicTypes },
// //StateType = state?.Name,
// Events = { events }
// }
//}, cancellationToken).ConfigureAwait(false);
var response = await _client.RegisterAgentAsync(new RegisterAgentTypeRequest
{
Type = type,
RequestId = Guid.NewGuid().ToString(),
//TopicTypes = { topicTypes },
//StateType = state?.Name,
Events = { events }
}, null, null, cancellationToken);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public sealed class GrpcGateway : BackgroundService, IGateway
private readonly ConcurrentDictionary<string, AgentState> _agentState = new();
private readonly IRegistryGrain _gatewayRegistry;
private readonly IGateway _reference;

// The agents supported by each worker process.
private readonly ConcurrentDictionary<string, List<GrpcWorkerConnection>> _supportedAgentTypes = [];
internal readonly ConcurrentDictionary<GrpcWorkerConnection, GrpcWorkerConnection> _workers = new();
Expand All @@ -35,33 +36,6 @@ public GrpcGateway(IClusterClient clusterClient, ILogger<GrpcGateway> logger)
_reference = clusterClient.CreateObjectReference<IGateway>(this);
_gatewayRegistry = clusterClient.GetGrain<IRegistryGrain>(0);
}
public async ValueTask BroadcastEvent(CloudEvent evt)
{
var tasks = new List<Task>();
foreach (var (key, connection) in _supportedAgentTypes)
{
if (_agentsToEventsMap.TryGetValue(key, out var events) && events.Contains(evt.Type))
{
tasks.Add(SendMessageAsync(connection[0], evt, default));
}
}
await Task.WhenAll(tasks).ConfigureAwait(false);
}
//intetionally not static so can be called by some methods implemented in base class
internal async Task SendMessageAsync(GrpcWorkerConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default)
{
await connection.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
}
private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse response)
{
if (!_pendingRequests.TryRemove((connection, response.RequestId), out var completion))
{
_logger.LogWarning("Received response for unknown request.");
return;
}
// Complete the request.
completion.SetResult(response);
}
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
while (!stoppingToken.IsCancellationRequested)
Expand All @@ -85,7 +59,32 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
_logger.LogWarning(exception, "Error removing worker from registry.");
}
}
//new is intentional...

internal Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
_logger.LogInformation("Received new connection from {Peer}.", context.Peer);
var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
_workers[workerProcess] = workerProcess;
return workerProcess.Completion;
}

public async ValueTask BroadcastEvent(CloudEvent evt)
{
var tasks = new List<Task>();
foreach (var (key, connection) in _supportedAgentTypes)
{
if (_agentsToEventsMap.TryGetValue(key, out var events) && events.Contains(evt.Type))
{
tasks.Add(SendMessageAsync(connection[0], evt, default));
}
}
await Task.WhenAll(tasks).ConfigureAwait(false);
}
//intetionally not static so can be called by some methods implemented in base class
internal async Task SendMessageAsync(GrpcWorkerConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default)
{
await connection.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
}
internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Message message)
{
_logger.LogInformation("Received message {Message} from connection {Connection}.", message, connection);
Expand All @@ -101,32 +100,20 @@ internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Mess
await DispatchEventAsync(message.CloudEvent);
break;
default:
// if it wasn't recognized return bad request
await RespondBadRequestAsync(connection, $"Unknown message type for message '{message}'.");
break;
throw new RpcException(new Status(StatusCode.InvalidArgument, $"Unknown message type for message '{message}'."));
};
}
private async ValueTask RespondBadRequestAsync(GrpcWorkerConnection connection, string error)
private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse response)
{
throw new RpcException(new Status(StatusCode.InvalidArgument, error));
if (!_pendingRequests.TryRemove((connection, response.RequestId), out var completion))
{
_logger.LogWarning("Received response for unknown request.");
return;
}
// Complete the request.
completion.SetResult(response);
}

private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg)
{
connection.AddSupportedType(msg.Type);
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(msg.Type, new HashSet<string>(msg.Events));

await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true);
}
private async ValueTask DispatchEventAsync(CloudEvent evt)
{
await BroadcastEvent(evt).ConfigureAwait(false);
/*
var topic = _clusterClient.GetStreamProvider("agents").GetStream<Event>(StreamId.Create(evt.Namespace, evt.Type));
await topic.OnNextAsync(evt.ToEvent());
*/
}
private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, RpcRequest request)
{
var requestId = request.RequestId;
Expand All @@ -151,6 +138,24 @@ await InvokeRequestDelegate(connection, request, async request =>
//}
}

private async ValueTask DispatchEventAsync(CloudEvent evt)
{
await BroadcastEvent(evt).ConfigureAwait(false);
/*
var topic = _clusterClient.GetStreamProvider("agents").GetStream<Event>(StreamId.Create(evt.Namespace, evt.Type));
await topic.OnNextAsync(evt.ToEvent());
*/
}

private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg)
{
connection.AddSupportedType(msg.Type);
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(msg.Type, new HashSet<string>(msg.Events));

await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true);
}

private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, RpcRequest request, Func<RpcRequest, Task<RpcResponse>> func)
{
try
Expand All @@ -164,13 +169,7 @@ private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection,
await connection.ResponseStream.WriteAsync(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = ex.Message } }).ConfigureAwait(false);
}
}
internal Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
_logger.LogInformation("Received new connection from {Peer}.", context.Peer);
var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
_workers[workerProcess] = workerProcess;
return workerProcess.Completion;
}

public async ValueTask StoreAsync(AgentState value)
{
var agentId = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId));
Expand Down Expand Up @@ -205,6 +204,7 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
}
}
}

public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request, CancellationToken cancellationToken = default)
{
var agentId = (request.Target.Type, request.Target.Key);
Expand Down Expand Up @@ -233,8 +233,13 @@ public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request, Cancellati
return response;
}

async ValueTask<RpcResponse> IGateway.InvokeRequest(RpcRequest request)
public ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request)
{
throw new NotImplementedException();
}

public ValueTask<RpcResponse> InvokeRequest(RpcRequest request)
{
return await InvokeRequest(request).ConfigureAwait(false);
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,16 @@ public override async Task<SaveStateResponse> SaveState(AgentState request, Serv
Success = true // TODO: Implement error handling
};
}

public override Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context)
{
// TODO: This should map to Orleans Streaming explicit api
return base.AddSubscription(request, context);
}

public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context)
{
// TODO: This should add the agent to registry
return await Gateway.RegisterAgentTypeAsync(request);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentWorkerHostingExtensions.cs
// GrpcRuntimeHostingExtensions.cs

using System.Diagnostics;
using Microsoft.AspNetCore.Builder;
Expand All @@ -11,30 +11,28 @@

namespace Microsoft.AutoGen.Runtime.Grpc;

public static class AgentWorkerHostingExtensions
public static class GrpcRuntimeHostingExtensions
{
public static WebApplicationBuilder AddAgentService(this WebApplicationBuilder builder, bool inMemoryOrleans = false, bool useGrpc = true)
public static WebApplicationBuilder AddGrpcRuntime(this WebApplicationBuilder builder, bool inMemoryOrleans = false)
{
// TODO: allow for configuration of Orleans, for state and streams
builder.AddOrleans(inMemoryOrleans);

builder.Services.TryAddSingleton(DistributedContextPropagator.Current);

if (useGrpc)
builder.WebHost.ConfigureKestrel(serverOptions =>
{
builder.WebHost.ConfigureKestrel(serverOptions =>
// TODO: make port configurable
serverOptions.ListenAnyIP(5001, listenOptions =>
{
// TODO: make port configurable
serverOptions.ListenAnyIP(5001, listenOptions =>
{
listenOptions.Protocols = HttpProtocols.Http2;
listenOptions.UseHttps();
});
listenOptions.Protocols = HttpProtocols.Http2;
listenOptions.UseHttps();
});
});

builder.Services.AddGrpc();
builder.Services.AddSingleton<GrpcGateway>();
builder.Services.AddSingleton(sp => (IHostedService)sp.GetRequiredService<GrpcGateway>());
}
builder.Services.AddGrpc();
builder.Services.AddSingleton<GrpcGateway>();
builder.Services.AddSingleton(sp => (IHostedService)sp.GetRequiredService<GrpcGateway>());

return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ public interface IGateway : IGrainObserver
ValueTask BroadcastEvent(CloudEvent evt);
ValueTask StoreAsync(AgentState value);
ValueTask<AgentState> ReadAsync(AgentId agentId);
ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

var builder = WebApplication.CreateBuilder(args);

builder.AddAgentService(inMemoryOrleans: true, useGrpc: true);
builder.AddGrpcRuntime(inMemoryOrleans: true);

var app = builder.Build();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// HandleInterfaceTest.cs
// HandleInterfaceTests.cs

using FluentAssertions;
using Tests.Events;
using Xunit;

namespace Microsoft.AutoGen.Core.Tests;

public class HandleInterfaceTest
public class HandleInterfaceTests
{
[Fact]
public void Can_Get_Handler_Methods_From_Class()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using Xunit;

namespace Microsoft.AutoGen.Core.Tests;
namespace Microsoft.AutoGen.Core.Tests.Helpers;

[CollectionDefinition(Name)]
public sealed class ClusterFixtureCollection : ICollectionFixture<InMemoryAgentRuntimeFixture>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using Microsoft.Extensions.Logging;
using Moq;

namespace Microsoft.AutoGen.Core.Tests;
namespace Microsoft.AutoGen.Core.Tests.Helpers;

public sealed class InMemoryAgentRuntimeFixture : IDisposable
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using FluentAssertions;
using Google.Protobuf.Reflection;
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Core.Tests.Helpers;
using Microsoft.Extensions.DependencyInjection;
using Tests.Events;
using Xunit;
Expand Down
9 changes: 0 additions & 9 deletions dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Class1.cs

This file was deleted.

Loading

0 comments on commit 280f213

Please sign in to comment.