Skip to content

Commit

Permalink
Add event metadata to agent registration, filter events sent from run…
Browse files Browse the repository at this point in the history
…time
  • Loading branch information
kostapetan committed Nov 28, 2024
1 parent f70869f commit 2e1ee71
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 30 deletions.
56 changes: 32 additions & 24 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ protected AgentBase(
runtime.AgentInstance = this;
this.EventTypes = eventTypes;
_logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger<AgentBase>();
// TODO: Remove from constructor
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Expand Down Expand Up @@ -125,6 +126,15 @@ await this.InvokeWithActivityAsync(
case Message.MessageOneofCase.Response:
OnResponseCore(msg.Response);
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
_logger.LogInformation($"Got {msg.MessageCase} with payload {msg.RegisterAgentTypeResponse}"); // TODO: Refactor registration to be part of the AgentRpc service
break;
case Message.MessageOneofCase.AddSubscriptionResponse:
_logger.LogInformation($"Got {msg.MessageCase} with payload {msg.AddSubscriptionResponse}");// TODO: Refactor add subscription to be part of the AgentRpc service
break;
default:
_logger.LogInformation($"Got {msg.MessageCase}");
break;
}
}
public List<string> Subscribe(string topic)
Expand Down Expand Up @@ -258,37 +268,35 @@ static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =
public Task CallHandler(CloudEvent item)
{
// Only send the event to the handler if the agent type is handling that type
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
foreach (var key in EventTypes.EventsMap.Keys)
// and belongs to the same AgentId
if (EventTypes.EventsMap[GetType()].Contains(item.Type) &&
item.Source == AgentId.Key)
{
if (EventTypes.EventsMap[key].Contains(item.Type))
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);

MethodInfo methodInfo;
try
MethodInfo methodInfo;
try
{
// check that our target actually implements this interface, otherwise call the default static
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
{
// check that our target actually implements this interface, otherwise call the default static
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
{
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
}
else
{
// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
}
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
}
catch (Exception ex)
else
{
_logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
throw; // TODO: ?
// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
}
}
catch (Exception ex)
{
_logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
throw; // TODO: ?
}
}

return Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System.Diagnostics;
using System.Reflection;
using System.Threading.Channels;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Grpc.Core;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -212,7 +214,7 @@ private async ValueTask RegisterAgentTypeAsync(string type, Type agentType, Canc
{
var events = agentType.GetInterfaces()
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => i.GetGenericArguments().First().Name);
.Select(i => GetDescriptorName(i.GetGenericArguments().First()));
//var state = agentType.BaseType?.GetGenericArguments().First();
var topicTypes = agentType.GetCustomAttributes<TopicSubscriptionAttribute>().Select(t => t.Topic);

Expand All @@ -224,12 +226,26 @@ await WriteChannelAsync(new Message
RequestId = Guid.NewGuid().ToString(),
//TopicTypes = { topicTypes },
//StateType = state?.Name,
//Events = { events }
Events = { events }
}
}, cancellationToken).ConfigureAwait(false);
}
}

public static string GetDescriptorName(Type messageType)
{
if (typeof(IMessage).IsAssignableFrom(messageType))
{
var descriptorProperty = messageType.GetProperty("Descriptor", BindingFlags.Public | BindingFlags.Static);
if (descriptorProperty != null)
{
var descriptor = descriptorProperty.GetValue(null) as MessageDescriptor;
return descriptor?.FullName ?? messageType.Name;
}
}
return messageType.Name;
}

// new is intentional
public new async ValueTask SendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default)
{
Expand Down
11 changes: 7 additions & 4 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public sealed class GrpcGateway : BackgroundService, IGateway
public readonly ConcurrentDictionary<IConnection, IConnection> _workers = new();
private readonly ConcurrentDictionary<string, Subscription> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();
private readonly ConcurrentDictionary<string, HashSet<string>> _agentsToEventsMap = new();

// The mapping from agent id to worker process.
private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new();
Expand All @@ -40,12 +41,13 @@ public GrpcGateway(IClusterClient clusterClient, ILogger<GrpcGateway> logger)
}
public async ValueTask BroadcastEvent(CloudEvent evt)
{
// TODO: filter the workers that receive the event
var tasks = new List<Task>(_workers.Count);
foreach (var (_, connection) in _supportedAgentTypes)
foreach (var (key, connection) in _supportedAgentTypes)
{

tasks.Add(this.SendMessageAsync((IConnection)connection[0], evt, default));
if (_agentsToEventsMap.TryGetValue(key, out var events) && events.Contains(evt.Type))
{
tasks.Add(SendMessageAsync(connection[0], evt, default));
}
}
await Task.WhenAll(tasks).ConfigureAwait(false);
}
Expand Down Expand Up @@ -142,6 +144,7 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
{
connection.AddSupportedType(msg.Type);
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(msg.Type, new HashSet<string>(msg.Events));

await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true);
Message response = new()
Expand Down
1 change: 1 addition & 0 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ message Event {
message RegisterAgentTypeRequest {
string request_id = 1;
string type = 2;
repeated string events = 3;
}

message RegisterAgentTypeResponse {
Expand Down

0 comments on commit 2e1ee71

Please sign in to comment.