Skip to content

Commit

Permalink
grpc/orleans tests WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kostapetan committed Dec 6, 2024
1 parent 680c3d0 commit 851f8c7
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentStateGrain.cs

using System.Data;
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Runtime.Grpc;

internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain
internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain, IAgentGrain
{
/// <inheritdoc />
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag/*, CancellationToken cancellationToken = default*/)
Expand All @@ -22,7 +23,7 @@ public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag/
else
{
//TODO - this is probably not the correct behavior to just throw - I presume we want to somehow let the caller know that the state has changed and they need to re-read it
throw new ArgumentException(
throw new DBConcurrencyException(
"The provided ETag does not match the current ETag. The state has been modified by another request.");
}
return state.Etag;
Expand All @@ -34,3 +35,9 @@ public ValueTask<AgentState> ReadStateAsync(/*CancellationToken cancellationToke
return ValueTask.FromResult(state.State);
}
}

internal interface IAgentGrain : IGrainWithStringKey
{
ValueTask<AgentState> ReadStateAsync();
ValueTask<string> WriteStateAsync(AgentState state, string eTag);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ public sealed class GrpcGateway : BackgroundService, IGateway
private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(30);
private readonly ILogger<GrpcGateway> _logger;
private readonly IClusterClient _clusterClient;
private readonly ConcurrentDictionary<string, AgentState> _agentState = new();
private readonly IRegistryGrain _gatewayRegistry;
private readonly IGateway _reference;

Expand Down Expand Up @@ -60,12 +59,18 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
}
}

internal Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
internal async Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
_logger.LogInformation("Received new connection from {Peer}.", context.Peer);
var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
_workers[workerProcess] = workerProcess;
return workerProcess.Completion;
var completion = new TaskCompletionSource<Task>();
var _ = Task.Run(() =>
{
completion.SetResult(workerProcess.Connect());
});

await completion.Task;
}

public async ValueTask BroadcastEvent(CloudEvent evt)
Expand Down Expand Up @@ -172,18 +177,16 @@ private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection,

public async ValueTask StoreAsync(AgentState value)
{
var agentId = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId));
_agentState[agentId.Key] = value;
var agentState = _clusterClient.GetGrain<IAgentGrain>($"{value.AgentId.Type}:{value.AgentId.Key}");
await agentState.WriteStateAsync(value, value.ETag);
}

public async ValueTask<AgentState> ReadAsync(AgentId agentId)
{
if (_agentState.TryGetValue(agentId.Key, out var state))
{
return state;
}
return new AgentState { AgentId = agentId };
var agentState = _clusterClient.GetGrain<IAgentGrain>($"{agentId.Type}:{agentId.Key}");
return await agentState.ReadStateAsync();
}

internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
{
_workers.TryRemove(workerProcess, out _);
Expand All @@ -208,7 +211,7 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request, CancellationToken cancellationToken = default)
{
var agentId = (request.Target.Type, request.Target.Key);
if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted)
if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion?.IsCompleted == true)
{
// Activate the agent on a compatible worker process.
if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ namespace Microsoft.AutoGen.Runtime.Grpc;
internal sealed class GrpcWorkerConnection : IAsyncDisposable
{
private static long s_nextConnectionId;
private readonly Task _readTask;
private readonly Task _writeTask;
private Task? _readTask;
private Task? _writeTask;
private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString();
private readonly object _lock = new();
private readonly HashSet<string> _supportedTypes = [];
private readonly GrpcGateway _gateway;
private readonly CancellationTokenSource _shutdownCancellationToken = new();
public Task? Completion { get; private set; }

public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
Expand All @@ -25,7 +26,10 @@ public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader<Message>
ResponseStream = responseStream;
ServerCallContext = context;
_outboundMessages = Channel.CreateUnbounded<Message>(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false });
}

public Task Connect()
{
var didSuppress = false;
if (!ExecutionContext.IsFlowSuppressed())
{
Expand All @@ -46,7 +50,7 @@ public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader<Message>
}
}

Completion = Task.WhenAll(_readTask, _writeTask);
return Completion = Task.WhenAll(_readTask, _writeTask);
}

public IAsyncStreamReader<Message> RequestStream { get; }
Expand Down Expand Up @@ -76,8 +80,6 @@ public async Task SendMessage(Message message)
await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false);
}

public Task Completion { get; }

public async Task RunReadPump()
{
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
Expand Down Expand Up @@ -122,7 +124,10 @@ public async Task RunWritePump()
public async ValueTask DisposeAsync()
{
_shutdownCancellationToken.Cancel();
await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
if (Completion is not null)
{
await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
}
}

public override string ToString() => $"Connection-{_connectionId}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ public async Task Test_OpenChannel()
var requestStream = new TestAsyncStreamReader<Message>(callContext);
var responseStream = new TestServerStreamWriter<Message>(callContext);

await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(PBAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext);
await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(GMAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext);

await service.OpenChannel(requestStream, responseStream, callContext);

requestStream.AddMessage(new Message { });
var bgEvent = new CloudEvent
{
Id = "1",
Source = "gh/repo/1",
Type = "test",

};

requestStream.AddMessage(new Message { CloudEvent = bgEvent });

requestStream.Complete();

Expand Down Expand Up @@ -62,7 +73,7 @@ public async Task Test_GetState()
var service = new GrpcGatewayService(gateway);
var callContext = TestServerCallContext.Create();

var response = await service.GetState(new AgentId { }, callContext);
var response = await service.GetState(new AgentId { Key = "", Type = "" }, callContext);

response.Should().NotBeNull();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public async Task TestBroadcastEvent()

// 1. Register Agent
// 2. Broadcast Event
// 3.
// 3.

await gateway.BroadcastEvent(evt);
//var registry = fixture.Registry;
//var subscriptions = fixture.Subscriptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ public void Configure(ISiloBuilder siloBuilder)
{
siloBuilder.ConfigureServices(services =>
{
services.AddSerializer(a=> a.AddProtobufSerializer());
services.AddSerializer(a => a.AddProtobufSerializer());
});
siloBuilder.AddMemoryStreams("StreamProvider")
.AddMemoryGrainStorage("PubSubStore")
.AddMemoryGrainStorage("AgentStateStore");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// AgentStateSurrogate.cs

using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans.Surrogates;
Expand All @@ -21,7 +20,7 @@ public struct AgentStateSurrogate
[Id(4)]
public string Etag;
[Id(5)]
public Any ProtoData;
public ByteString ProtoData;
}

[RegisterConverter]
Expand All @@ -35,7 +34,7 @@ public AgentState ConvertFromSurrogate(
TextData = surrogate.TextData,
BinaryData = surrogate.BinaryData,
AgentId = surrogate.AgentId,
ProtoData = surrogate.ProtoData,
// ProtoData = surrogate.ProtoData,
ETag = surrogate.Etag
};

Expand All @@ -47,7 +46,7 @@ public AgentStateSurrogate ConvertToSurrogate(
BinaryData = value.BinaryData,
TextData = value.TextData,
Etag = value.ETag,
ProtoData = value.ProtoData
ProtoData = value.ProtoData.Value
};
}

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Microsoft.AutoGen.Core\Microsoft.AutoGen.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Microsoft.AutoGen.Runtime.Grpc\Microsoft.AutoGen.Runtime.Grpc.csproj" />
<ProjectReference Include="..\Microsoft.Autogen.Tests.Shared\Microsoft.Autogen.Tests.Shared.csproj" />
</ItemGroup>
Expand Down
44 changes: 44 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TestAgent.cs

using System.Collections.Concurrent;
using Microsoft.AutoGen.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Tests.Events;

namespace Microsoft.AutoGen.Runtime.Grpc.Tests;

public class PBAgent([FromKeyedServices("EventTypes")] EventTypes eventTypes, ILogger<Agent>? logger = null)
: Agent(eventTypes, logger)
, IHandle<NewMessageReceived>
, IHandle<GoodBye>
{
public async Task Handle(NewMessageReceived item, CancellationToken cancellationToken = default)
{
ReceivedMessages[AgentId.Key] = item.Message;
var hello = new Hello { Message = item.Message };
await PublishEventAsync(hello);
}
public Task Handle(GoodBye item, CancellationToken cancellationToken)
{
_logger.LogInformation($"Received GoodBye message {item.Message}");
return Task.CompletedTask;
}

public static ConcurrentDictionary<string, object> ReceivedMessages { get; private set; } = new();
}

public class GMAgent([FromKeyedServices("EventTypes")] EventTypes eventTypes, ILogger<Agent>? logger = null)
: Agent(eventTypes, logger)
, IHandle<Hello>
{
public async Task Handle(Hello item, CancellationToken cancellationToken)
{
_logger.LogInformation($"Received Hello message {item.Message}");
ReceivedMessages[AgentId.Key] = item.Message;
await PublishEventAsync(new GoodBye { Message = "" });
}

public static ConcurrentDictionary<string, object> ReceivedMessages { get; private set; } = new();
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ message TextMessage {
string message = 1;
string source = 2;
}
message Input {
message Hello {
string message = 1;
}
message InputProcessed {
Expand Down

0 comments on commit 851f8c7

Please sign in to comment.