diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index ef0df4945ff9..a778e356dea1 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -93,9 +93,6 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "AgentChat", "AgentChat", "{C7A2D42D-9277-47AC-862B-D86DF9D6AD48}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "dev-team", "dev-team", "{616F30DF-1F41-4047-BAA4-64BA03BF5AEA}" - ProjectSection(SolutionItems) = preProject - samples\dev-team\DevTeam.ServiceDefaults\DevTeam.ServiceDefaults.csproj = samples\dev-team\DevTeam.ServiceDefaults\DevTeam.ServiceDefaults.csproj - EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DevTeam.AgentHost", "samples\dev-team\DevTeam.AgentHost\DevTeam.AgentHost.csproj", "{7228A701-C79D-4E15-BF45-48D11F721A84}" EndProject @@ -127,6 +124,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgents.Web", "samples\ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Hello", "samples\Hello\Hello.csproj", "{6C9135E6-9D15-4D86-B3F4-9666DB87060A}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.ServiceDefaults", "src\Microsoft.AutoGen.ServiceDefaults\Microsoft.AutoGen.ServiceDefaults.csproj", "{F70C6FD7-9615-4EDD-8D55-5460FCC5A46D}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -337,6 +336,10 @@ Global {6C9135E6-9D15-4D86-B3F4-9666DB87060A}.Debug|Any CPU.Build.0 = Debug|Any CPU {6C9135E6-9D15-4D86-B3F4-9666DB87060A}.Release|Any CPU.ActiveCfg = Release|Any CPU {6C9135E6-9D15-4D86-B3F4-9666DB87060A}.Release|Any CPU.Build.0 = Release|Any CPU + {F70C6FD7-9615-4EDD-8D55-5460FCC5A46D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F70C6FD7-9615-4EDD-8D55-5460FCC5A46D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F70C6FD7-9615-4EDD-8D55-5460FCC5A46D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F70C6FD7-9615-4EDD-8D55-5460FCC5A46D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -397,6 +400,7 @@ Global {6B88F4B3-26AB-4034-B0AC-5BA6EEDEB8E5} = {F7AC0FF1-8500-49C6-8CB3-97C6D52C8BEF} {8B56BE22-5CF4-44BB-AFA5-732FEA2AFF0B} = {F7AC0FF1-8500-49C6-8CB3-97C6D52C8BEF} {6C9135E6-9D15-4D86-B3F4-9666DB87060A} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} + {F70C6FD7-9615-4EDD-8D55-5460FCC5A46D} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/samples/Hello/Program.cs b/dotnet/samples/Hello/Program.cs index 220d660ab53f..1e80821a37ae 100644 --- a/dotnet/samples/Hello/Program.cs +++ b/dotnet/samples/Hello/Program.cs @@ -1,23 +1,15 @@ using Microsoft.AutoGen.Agents.Abstractions; using Microsoft.AutoGen.Agents.Client; -using Microsoft.AutoGen.Agents.Runtime; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -var builder = Host.CreateApplicationBuilder(args); -builder.AddAgentService(); -builder.UseOrleans(siloBuilder => -{ - siloBuilder.UseLocalhostClustering(); ; -}); -builder.AddAgentWorker("https://localhost:5000"); -var app = builder.Build(); -await app.StartAsync(); -app.Services.GetRequiredService(); -var evt = new NewMessageReceived +// send a message to the agent +var app = await App.PublishMessageAsync("HelloAgents", new NewMessageReceived { Message = "World" -}.ToCloudEvent("HelloAgents"); +}, local: true); + +await App.RuntimeApp!.WaitForShutdownAsync(); await app.WaitForShutdownAsync(); [TopicSubscription("HelloAgents")] @@ -32,27 +24,33 @@ public class HelloAgent( { public async Task Handle(NewMessageReceived item) { - var response = await SayHello(item.Message); + var response = await SayHello(item.Message).ConfigureAwait(false); var evt = new Output { Message = response }.ToCloudEvent(this.AgentId.Key); - await PublishEvent(evt); + await PublishEvent(evt).ConfigureAwait(false); + var goodbye = new ConversationClosed + { + UserId = this.AgentId.Key, + UserMessage = "Goodbye" + }.ToCloudEvent(this.AgentId.Key); + await PublishEvent(goodbye).ConfigureAwait(false); } - public async Task Handle(ConversationClosed item) { - var goodbye = "Goodbye!"; + var goodbye = $"********************* {item.UserId} said {item.UserMessage} ************************"; var evt = new Output { Message = goodbye }.ToCloudEvent(this.AgentId.Key); - await PublishEvent(evt); + await PublishEvent(evt).ConfigureAwait(false); + await Task.Delay(60000); + await App.ShutdownAsync(); } - public async Task SayHello(string ask) { - var response = $"Hello {ask}"; + var response = $"\n\n\n\n***************Hello {ask}**********************\n\n\n\n"; return response; } } diff --git a/dotnet/samples/dev-team/DevTeam.AgentHost/DevTeam.AgentHost.csproj b/dotnet/samples/dev-team/DevTeam.AgentHost/DevTeam.AgentHost.csproj index 5ed148e79e75..1f6da64e3301 100644 --- a/dotnet/samples/dev-team/DevTeam.AgentHost/DevTeam.AgentHost.csproj +++ b/dotnet/samples/dev-team/DevTeam.AgentHost/DevTeam.AgentHost.csproj @@ -10,7 +10,7 @@ - + diff --git a/dotnet/samples/dev-team/DevTeam.Agents/DevTeam.Agents.csproj b/dotnet/samples/dev-team/DevTeam.Agents/DevTeam.Agents.csproj index b06778401694..065f31b0dd0f 100644 --- a/dotnet/samples/dev-team/DevTeam.Agents/DevTeam.Agents.csproj +++ b/dotnet/samples/dev-team/DevTeam.Agents/DevTeam.Agents.csproj @@ -10,7 +10,7 @@ - + diff --git a/dotnet/samples/dev-team/DevTeam.Backend/DevTeam.Backend.csproj b/dotnet/samples/dev-team/DevTeam.Backend/DevTeam.Backend.csproj index 2227fc82ec04..5200f1f45caf 100644 --- a/dotnet/samples/dev-team/DevTeam.Backend/DevTeam.Backend.csproj +++ b/dotnet/samples/dev-team/DevTeam.Backend/DevTeam.Backend.csproj @@ -29,7 +29,7 @@ - + diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentClient.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentClient.cs index f661eaa58d2e..00518d3ce28f 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentClient.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentClient.cs @@ -1,11 +1,11 @@ using System.Diagnostics; +using Google.Protobuf; using Microsoft.AutoGen.Agents.Abstractions; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Microsoft.AutoGen.Agents.Client; -// TODO: Extract this to be part of the Client public sealed class AgentClient(ILogger logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator, [FromKeyedServices("EventTypes")] EventTypes eventTypes) : AgentBase(new ClientContext(logger, runtime, distributedContextPropagator), eventTypes) @@ -13,6 +13,10 @@ public sealed class AgentClient(ILogger logger, AgentWorkerRuntime public async ValueTask PublishEventAsync(CloudEvent evt) => await PublishEvent(evt); public async ValueTask SendRequestAsync(AgentId target, string method, Dictionary parameters) => await RequestAsync(target, method, parameters); + public async ValueTask PublishEventAsync(string topic, IMessage evt) + { + await PublishEventAsync(evt.ToCloudEvent(topic)).ConfigureAwait(false); + } private sealed class ClientContext(ILogger logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator) : IAgentContext { public AgentId AgentId { get; } = new AgentId("client", Guid.NewGuid().ToString()); diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentWorkerRuntime.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentWorkerRuntime.cs index a3e52adef1ef..5483efd72bc2 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentWorkerRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/AgentWorkerRuntime.cs @@ -170,7 +170,7 @@ private async ValueTask RegisterAgentType(string type, Type agentType) var events = agentType.GetInterfaces() .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>)) .Select(i => i.GetGenericArguments().First().Name); - var state = agentType.BaseType?.GetGenericArguments().First(); + //var state = agentType.BaseType?.GetGenericArguments().First(); var topicTypes = agentType.GetCustomAttributes().Select(t => t.Topic); await WriteChannelAsync(new Message diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/ConsoleAgent/ConsoleAgent.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/ConsoleAgent/ConsoleAgent.cs index 2fda5654f372..5ebefba8b321 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/ConsoleAgent/ConsoleAgent.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/ConsoleAgent/ConsoleAgent.cs @@ -3,7 +3,7 @@ namespace Microsoft.AutoGen.Agents.Client; -public class ConsoleAgent : IOAgent, +public abstract class ConsoleAgent : IOAgent, IUseConsole, IHandle, IHandle @@ -32,6 +32,7 @@ public override async Task Handle(Output item) { // Assuming item has a property `Content` that we want to write to the console Console.WriteLine(item.Message); + await ProcessOutput(item.Message); var evt = new OutputWritten { diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/FileAgent/FileAgent.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/FileAgent/FileAgent.cs index f276b830dd58..12965b432c48 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/FileAgent/FileAgent.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/FileAgent/FileAgent.cs @@ -1,27 +1,26 @@ using Microsoft.AutoGen.Agents.Abstractions; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Microsoft.AutoGen.Agents.Client; [TopicSubscription("FileIO")] -public class FileAgent : IOAgent, +public abstract class FileAgent( + IAgentContext context, + [FromKeyedServices("EventTypes")] EventTypes typeRegistry, + string inputPath = "input.txt", + string outputPath = "output.txt" + ) : IOAgent(context, typeRegistry), IUseFiles, IHandle, IHandle { - public FileAgent(IAgentContext context, EventTypes typeRegistry, string filePath) : base(context, typeRegistry) - { - _filePath = filePath; - } - private readonly string _filePath; - public override async Task Handle(Input item) { - // validate that the file exists - if (!File.Exists(_filePath)) + if (!File.Exists(inputPath)) { - string errorMessage = $"File not found: {_filePath}"; + var errorMessage = $"File not found: {inputPath}"; Logger.LogError(errorMessage); //publish IOError event var err = new IOError @@ -31,36 +30,30 @@ public override async Task Handle(Input item) await PublishEvent(err); return; } - string content; using (var reader = new StreamReader(item.Message)) { content = await reader.ReadToEndAsync(); } - await ProcessInput(content); - var evt = new InputProcessed { Route = _route }.ToCloudEvent(this.AgentId.Key); await PublishEvent(evt); } - public override async Task Handle(Output item) { - using (var writer = new StreamWriter(_filePath, append: true)) + using (var writer = new StreamWriter(outputPath, append: true)) { await writer.WriteLineAsync(item.Message); } - var evt = new OutputWritten { Route = _route }.ToCloudEvent(this.AgentId.Key); await PublishEvent(evt); } - public override async Task ProcessInput(string message) { var evt = new InputProcessed @@ -70,14 +63,12 @@ public override async Task ProcessInput(string message) await PublishEvent(evt); return message; } - public override Task ProcessOutput(string message) { // Implement your output processing logic here return Task.CompletedTask; } } - public interface IUseFiles { } diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/WebAPIAgent/WebAPIAgent.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/WebAPIAgent/WebAPIAgent.cs index b7e3abbc582e..49babfdf5982 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/WebAPIAgent/WebAPIAgent.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/Agents/IOAgent/WebAPIAgent/WebAPIAgent.cs @@ -6,7 +6,7 @@ namespace Microsoft.AutoGen.Agents.Client; -public class WebAPIAgent : IOAgent, +public abstract class WebAPIAgent : IOAgent, IUseWebAPI, IHandle, IHandle @@ -16,8 +16,8 @@ public class WebAPIAgent : IOAgent, public WebAPIAgent( IAgentContext context, [FromKeyedServices("EventTypes")] EventTypes typeRegistry, - string url, - ILogger logger) : base( + ILogger logger, + string url = "/agents/webio") : base( context, typeRegistry) { diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/App.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/App.cs new file mode 100644 index 000000000000..e8b3651fa613 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/App.cs @@ -0,0 +1,53 @@ +using Google.Protobuf; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AutoGen.Agents.Client; + +public static class App +{ + // need a variable to store the runtime instance + public static WebApplication? RuntimeApp { get; set; } + public static WebApplication? ClientApp { get; set; } + public static async ValueTask StartAsync(AgentTypes? agentTypes = null, bool local = false) + { + // start the server runtime + RuntimeApp ??= await Runtime.Host.StartAsync(local); + var clientBuilder = WebApplication.CreateBuilder(); + var appBuilder = clientBuilder.AddAgentWorker(); + agentTypes ??= AgentTypes.GetAgentTypesFromAssembly() + ?? throw new InvalidOperationException("No agent types found in the assembly"); + foreach (var type in agentTypes.Types) + { + appBuilder.AddAgent(type.Key, type.Value); + } + ClientApp = clientBuilder.Build(); + await ClientApp.StartAsync().ConfigureAwait(false); + return ClientApp; + } + + public static async ValueTask PublishMessageAsync( + string topic, + IMessage message, + AgentTypes? agentTypes = null, + bool local = false) + { + if (ClientApp == null) + { + ClientApp = await App.StartAsync(agentTypes, local); + } + var client = ClientApp.Services.GetRequiredService() ?? throw new InvalidOperationException("Client not started"); + await client.PublishEventAsync(topic, message).ConfigureAwait(false); + return ClientApp; + } + + public static async ValueTask ShutdownAsync() + { + if (ClientApp == null) + { + throw new InvalidOperationException("Client not started"); + } + await RuntimeApp!.StopAsync(); + await ClientApp.StopAsync(); + } +} diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/HostBuilderExtensions.cs b/dotnet/src/Microsoft.AutoGen.Agents.Client/HostBuilderExtensions.cs index d0b0a71a2123..76c10b5773f1 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/HostBuilderExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/HostBuilderExtensions.cs @@ -14,7 +14,8 @@ namespace Microsoft.AutoGen.Agents.Client; public static class HostBuilderExtensions { - public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilder builder, string agentServiceAddress) + private const string _defaultAgentServiceAddress = "https://localhost:5001"; + public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilder builder, string agentServiceAddress = _defaultAgentServiceAddress, bool local = false) { builder.Services.AddGrpcClient(options => { @@ -48,6 +49,7 @@ public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilde }); }); builder.Services.TryAddSingleton(DistributedContextPropagator.Current); + builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(sp => sp.GetRequiredService()); builder.Services.AddKeyedSingleton("EventTypes", (sp, key) => @@ -64,7 +66,7 @@ public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilde var eventsMap = AppDomain.CurrentDomain.GetAssemblies() .SelectMany(assembly => assembly.GetTypes()) - .Where(type => IsSubclassOfGeneric(type, typeof(AiAgent<>)) && !type.IsAbstract) + .Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract) .Select(t => (t, t.GetInterfaces() .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>)) .Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet())) @@ -80,8 +82,10 @@ public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilde var property = type.GetProperty("Descriptor", BindingFlags.Static | BindingFlags.Public); return property?.GetValue(null) as MessageDescriptor; } - - private static bool IsSubclassOfGeneric(Type type, Type genericBaseType) +} +public sealed class ReflectionHelper +{ + public static bool IsSubclassOfGeneric(Type type, Type genericBaseType) { while (type != null && type != typeof(object)) { @@ -98,7 +102,21 @@ private static bool IsSubclassOfGeneric(Type type, Type genericBaseType) return false; } } +public sealed class AgentTypes(Dictionary types) +{ + public Dictionary Types { get; } = types; + public static AgentTypes? GetAgentTypesFromAssembly() + { + var agents = AppDomain.CurrentDomain.GetAssemblies() + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) + && !type.IsAbstract + && !type.Name.Equals("AgentClient")) + .ToDictionary(type => type.Name, type => type); + return new AgentTypes(agents); + } +} public sealed class EventTypes(TypeRegistry typeRegistry, Dictionary types, Dictionary> eventsMap) { public TypeRegistry TypeRegistry { get; } = typeRegistry; @@ -114,5 +132,10 @@ public AgentApplicationBuilder AddAgent< builder.Services.AddKeyedSingleton("AgentTypes", (sp, key) => Tuple.Create(typeName, typeof(TAgent))); return this; } + public AgentApplicationBuilder AddAgent(string typeName, Type agentType) + { + builder.Services.AddKeyedSingleton("AgentTypes", (sp, key) => Tuple.Create(typeName, agentType)); + return this; + } } diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Client/Microsoft.AutoGen.Agents.Client.csproj b/dotnet/src/Microsoft.AutoGen.Agents.Client/Microsoft.AutoGen.Agents.Client.csproj index 0af9d0a1f0dd..6d708bb55038 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Client/Microsoft.AutoGen.Agents.Client.csproj +++ b/dotnet/src/Microsoft.AutoGen.Agents.Client/Microsoft.AutoGen.Agents.Client.csproj @@ -13,6 +13,7 @@ + diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Runtime/AgentWorkerHostingExtensions.cs b/dotnet/src/Microsoft.AutoGen.Agents.Runtime/AgentWorkerHostingExtensions.cs index ee646bb5530e..3e92e43d7c61 100644 --- a/dotnet/src/Microsoft.AutoGen.Agents.Runtime/AgentWorkerHostingExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen.Agents.Runtime/AgentWorkerHostingExtensions.cs @@ -1,5 +1,7 @@ using System.Diagnostics; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; @@ -23,6 +25,24 @@ public static IHostApplicationBuilder AddAgentService(this IHostApplicationBuild return builder; } + public static WebApplicationBuilder AddLocalAgentService(this WebApplicationBuilder builder) + { + builder.WebHost.ConfigureKestrel(serverOptions => + { + serverOptions.ListenLocalhost(5001, listenOptions => + { + listenOptions.Protocols = HttpProtocols.Http2; + listenOptions.UseHttps(); + }); + }); + builder.AddAgentService(); + builder.UseOrleans(siloBuilder => + { + siloBuilder.UseLocalhostClustering(); ; + }); + return builder; + } + public static WebApplication MapAgentService(this WebApplication app) { app.MapGrpcService(); diff --git a/dotnet/src/Microsoft.AutoGen.Agents.Runtime/Host.cs b/dotnet/src/Microsoft.AutoGen.Agents.Runtime/Host.cs new file mode 100644 index 000000000000..e3e8107b4d12 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen.Agents.Runtime/Host.cs @@ -0,0 +1,23 @@ +using Microsoft.AspNetCore.Builder; + +namespace Microsoft.AutoGen.Agents.Runtime; + +public static class Host +{ + public static async Task StartAsync(bool local = false) + { + var builder = WebApplication.CreateBuilder(); + if (local) + { + builder.AddLocalAgentService(); + } + else + { + builder.AddAgentService(); + } + var app = builder.Build(); + app.MapAgentService(); + await app.StartAsync().ConfigureAwait(false); + return app; + } +} diff --git a/dotnet/samples/dev-team/DevTeam.ServiceDefaults/Extensions.cs b/dotnet/src/Microsoft.AutoGen.ServiceDefaults/Extensions.cs similarity index 99% rename from dotnet/samples/dev-team/DevTeam.ServiceDefaults/Extensions.cs rename to dotnet/src/Microsoft.AutoGen.ServiceDefaults/Extensions.cs index 6cbae1ff4715..d45caeae6cde 100644 --- a/dotnet/samples/dev-team/DevTeam.ServiceDefaults/Extensions.cs +++ b/dotnet/src/Microsoft.AutoGen.ServiceDefaults/Extensions.cs @@ -36,10 +36,8 @@ public static IHostApplicationBuilder AddServiceDefaults(this IHostApplicationBu // { // options.AllowedSchemes = ["https"]; // }); - return builder; } - public static IHostApplicationBuilder ConfigureOpenTelemetry(this IHostApplicationBuilder builder) { builder.Logging.AddOpenTelemetry(logging => @@ -69,7 +67,6 @@ public static IHostApplicationBuilder ConfigureOpenTelemetry(this IHostApplicati return builder; } - private static IHostApplicationBuilder AddOpenTelemetryExporters(this IHostApplicationBuilder builder) { var useOtlpExporter = !string.IsNullOrWhiteSpace(builder.Configuration["OTEL_EXPORTER_OTLP_ENDPOINT"]); @@ -85,19 +82,15 @@ private static IHostApplicationBuilder AddOpenTelemetryExporters(this IHostAppli // builder.Services.AddOpenTelemetry() // .UseAzureMonitor(); //} - return builder; } - public static IHostApplicationBuilder AddDefaultHealthChecks(this IHostApplicationBuilder builder) { builder.Services.AddHealthChecks() // Add a default liveness check to ensure app is responsive .AddCheck("self", () => HealthCheckResult.Healthy(), ["live"]); - return builder; } - public static WebApplication MapDefaultEndpoints(this WebApplication app) { // Adding health checks endpoints to applications in non-development environments has security implications. @@ -113,7 +106,6 @@ public static WebApplication MapDefaultEndpoints(this WebApplication app) Predicate = r => r.Tags.Contains("live") }); } - return app; } } diff --git a/dotnet/samples/dev-team/DevTeam.ServiceDefaults/DevTeam.ServiceDefaults.csproj b/dotnet/src/Microsoft.AutoGen.ServiceDefaults/Microsoft.AutoGen.ServiceDefaults.csproj similarity index 99% rename from dotnet/samples/dev-team/DevTeam.ServiceDefaults/DevTeam.ServiceDefaults.csproj rename to dotnet/src/Microsoft.AutoGen.ServiceDefaults/Microsoft.AutoGen.ServiceDefaults.csproj index 5060c01aa69a..cf2446f93349 100644 --- a/dotnet/samples/dev-team/DevTeam.ServiceDefaults/DevTeam.ServiceDefaults.csproj +++ b/dotnet/src/Microsoft.AutoGen.ServiceDefaults/Microsoft.AutoGen.ServiceDefaults.csproj @@ -1,15 +1,12 @@ - net8.0 enable enable true - - @@ -18,5 +15,4 @@ - diff --git a/python/packages/autogen-core/docs/src/core-user-guide/core-concepts/faqs.md b/python/packages/autogen-core/docs/src/core-user-guide/core-concepts/faqs.md index 1ecc14ffbfc5..5e12f8603d98 100644 --- a/python/packages/autogen-core/docs/src/core-user-guide/core-concepts/faqs.md +++ b/python/packages/autogen-core/docs/src/core-user-guide/core-concepts/faqs.md @@ -15,3 +15,22 @@ This allows your agent to work in a distributed environment a well as a local on An {py:class}`autogen_core.base.AgentId` is composed of a `type` and a `key`. The type corresponds to the factory that created the agent, and the key is a runtime, data dependent key for this instance. The key can correspond to a user id, a session id, or could just be "default" if you don't need to differentiate between instances. Each unique key will create a new instance of the agent, based on the factory provided. This allows the system to automatically scale to different instances of the same agent, and to manage the lifecycle of each instance independently based on how you choose to handle keys in your application. + +## How do I increase the GRPC message size? + +If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `WorkerAgentRuntimeHost` and `WorkerAgentRuntime` instances. + +```python +# Define custom gRPC options +extra_grpc_config = [ + ("grpc.max_send_message_length", new_max_size), + ("grpc.max_receive_message_length", new_max_size), +] + +# Create instances of WorkerAgentRuntimeHost and WorkerAgentRuntime with the custom gRPC options + +host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config) +worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config) +``` + +**Note**: When `WorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`. diff --git a/python/packages/autogen-core/src/autogen_core/application/__init__.py b/python/packages/autogen-core/src/autogen_core/application/__init__.py index 60bf1bb94a30..0caa89c6da07 100644 --- a/python/packages/autogen-core/src/autogen_core/application/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/application/__init__.py @@ -6,8 +6,4 @@ from ._worker_runtime import WorkerAgentRuntime from ._worker_runtime_host import WorkerAgentRuntimeHost -__all__ = [ - "SingleThreadedAgentRuntime", - "WorkerAgentRuntime", - "WorkerAgentRuntimeHost", -] +__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "WorkerAgentRuntimeHost"] diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index ab38cdfbf30d..ac3e00e4a4a9 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -34,6 +34,7 @@ from autogen_core.base import JSON_DATA_CONTENT_TYPE from autogen_core.base._serialization import MessageSerializer, SerializationRegistry +from autogen_core.base._type_helpers import ChannelArgumentType from ..base import ( Agent, @@ -63,6 +64,7 @@ P = ParamSpec("P") T = TypeVar("T", bound=Agent) + type_func_alias = type @@ -78,20 +80,27 @@ def __aiter__(self) -> AsyncIterator[Any]: class HostConnection: - DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = { - "methodConfig": [ - { - "name": [{}], - "retryPolicy": { - "maxAttempts": 3, - "initialBackoff": "0.01s", - "maxBackoff": "5s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - } - ], - } + DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [ + ( + "grpc.service_config", + json.dumps( + { + "methodConfig": [ + { + "name": [{}], + "retryPolicy": { + "maxAttempts": 3, + "initialBackoff": "0.01s", + "maxBackoff": "5s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + } + ], + } + ), + ) + ] def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore self._channel = channel @@ -100,9 +109,17 @@ def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore self._connection_task: Task[None] | None = None @classmethod - def from_host_address(cls, host_address: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG) -> Self: + def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self: logger.info("Connecting to %s", host_address) - channel = grpc.aio.insecure_channel(host_address, options=[("grpc.service_config", json.dumps(grpc_config))]) + # Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config + merged_options = [ + (k, v) for k, v in {**dict(HostConnection.DEFAULT_GRPC_CONFIG), **dict(extra_grpc_config)}.items() + ] + + channel = grpc.aio.insecure_channel( + host_address, + options=merged_options, + ) instance = cls(channel) instance._connection_task = asyncio.create_task( instance._connect(channel, instance._send_queue, instance._recv_queue) @@ -150,7 +167,12 @@ async def recv(self) -> agent_worker_pb2.Message: class WorkerAgentRuntime(AgentRuntime): - def __init__(self, host_address: str, tracer_provider: TracerProvider | None = None) -> None: + def __init__( + self, + host_address: str, + tracer_provider: TracerProvider | None = None, + extra_grpc_config: ChannelArgumentType | None = None, + ) -> None: self._host_address = host_address self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime")) self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) @@ -168,13 +190,16 @@ def __init__(self, host_address: str, tracer_provider: TracerProvider | None = N self._background_tasks: Set[Task[Any]] = set() self._subscription_manager = SubscriptionManager() self._serialization_registry = SerializationRegistry() + self._extra_grpc_config = extra_grpc_config or [] def start(self) -> None: """Start the runtime in a background task.""" if self._running: raise ValueError("Runtime is already running.") logger.info(f"Connecting to host: {self._host_address}") - self._host_connection = HostConnection.from_host_address(self._host_address) + self._host_connection = HostConnection.from_host_address( + self._host_address, extra_grpc_config=self._extra_grpc_config + ) logger.info("Connection established") if self._read_task is None: self._read_task = asyncio.create_task(self._run_read_loop()) diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py index 7a5590b2109d..e6585098bdbd 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py @@ -1,10 +1,12 @@ import asyncio import logging import signal -from typing import Sequence +from typing import Optional, Sequence import grpc +from autogen_core.base._type_helpers import ChannelArgumentType + from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer from .protos import agent_worker_pb2_grpc @@ -12,8 +14,8 @@ class WorkerAgentRuntimeHost: - def __init__(self, address: str) -> None: - self._server = grpc.aio.server() + def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None: + self._server = grpc.aio.server(options=extra_grpc_config) self._servicer = WorkerAgentRuntimeHostServicer() agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server) self._server.add_insecure_port(address) diff --git a/python/packages/autogen-core/src/autogen_core/base/_type_helpers.py b/python/packages/autogen-core/src/autogen_core/base/_type_helpers.py index 66e52e4b6a37..3dab707feb35 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_type_helpers.py +++ b/python/packages/autogen-core/src/autogen_core/base/_type_helpers.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from types import NoneType, UnionType -from typing import Any, Optional, Type, Union, get_args, get_origin +from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin + +# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors +ChannelArgumentType = Sequence[Tuple[str, Any]] def is_union(t: object) -> bool: diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 762a05a4edc8..d1a1a3299053 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -14,11 +14,17 @@ from autogen_core.components import ( DefaultTopicId, TypeSubscription, - default_subscription, type_subscription, ) from opentelemetry.sdk.trace import TracerProvider -from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent +from test_utils import ( + CascadingAgent, + CascadingMessageType, + LoopbackAgent, + LoopbackAgentWithDefaultSubscription, + MessageType, + NoopAgent, +) from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider test_exporter = TestExporter() @@ -218,9 +224,6 @@ async def test_default_subscription() -> None: runtime = SingleThreadedAgentRuntime() runtime.start() - @default_subscription - class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... - await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription) agent_id = AgentId("name", key="default") @@ -267,9 +270,6 @@ async def test_default_subscription_publish_to_other_source() -> None: runtime = SingleThreadedAgentRuntime() runtime.start() - @default_subscription - class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... - await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription) agent_id = AgentId("name", key="default") diff --git a/python/packages/autogen-core/tests/test_utils/__init__.py b/python/packages/autogen-core/tests/test_utils/__init__.py index eb64be1d8b22..92096550cd98 100644 --- a/python/packages/autogen-core/tests/test_utils/__init__.py +++ b/python/packages/autogen-core/tests/test_utils/__init__.py @@ -14,17 +14,28 @@ class CascadingMessageType: round: int +@dataclass +class ContentMessage: + content: str + + class LoopbackAgent(RoutedAgent): def __init__(self) -> None: super().__init__("A loop back agent.") self.num_calls = 0 @message_handler - async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: + async def on_new_message( + self, message: MessageType | ContentMessage, ctx: MessageContext + ) -> MessageType | ContentMessage: self.num_calls += 1 return message +@default_subscription +class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... + + @default_subscription class CascadingAgent(RoutedAgent): def __init__(self, max_rounds: int) -> None: @@ -46,24 +57,3 @@ def __init__(self) -> None: async def on_message(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError - - -@dataclass -class MyMessage: - content: str - - -@default_subscription -class MyAgent(RoutedAgent): - def __init__(self, name: str) -> None: - super().__init__("My agent") - self._name = name - self._counter = 0 - - @message_handler - async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None: - self._counter += 1 - if self._counter > 5: - return - content = f"{self._name}: Hello x {self._counter}" - await self.publish_message(MyMessage(content=content), DefaultTopicId()) diff --git a/python/packages/autogen-core/tests/test_worker_runtime.py b/python/packages/autogen-core/tests/test_worker_runtime.py index 1f9698c04fb1..0e232243dd2b 100644 --- a/python/packages/autogen-core/tests/test_worker_runtime.py +++ b/python/packages/autogen-core/tests/test_worker_runtime.py @@ -15,10 +15,17 @@ from autogen_core.components import ( DefaultTopicId, TypeSubscription, - default_subscription, type_subscription, ) -from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, MyAgent, MyMessage, NoopAgent +from test_utils import ( + CascadingAgent, + CascadingMessageType, + ContentMessage, + LoopbackAgent, + LoopbackAgentWithDefaultSubscription, + MessageType, + NoopAgent, +) @pytest.mark.asyncio @@ -204,9 +211,6 @@ async def test_default_subscription() -> None: publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) publisher.start() - @default_subscription - class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... - await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription()) await publisher.publish_message(MessageType(), topic_id=DefaultTopicId()) @@ -241,9 +245,6 @@ async def test_default_subscription_other_source() -> None: publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) publisher.start() - @default_subscription - class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... - await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription()) await publisher.publish_message(MessageType(), topic_id=DefaultTopicId(source="other")) @@ -313,20 +314,20 @@ async def test_duplicate_subscription() -> None: host.start() try: worker1.start() - await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1")) + await NoopAgent.register(worker1, "worker1", lambda: NoopAgent()) worker1_2.start() # Note: This passes because worker1 is still running with pytest.raises(RuntimeError, match="Agent type worker1 already registered"): - await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2")) + await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent()) # This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent. # Will keep them both for now as we might replace the way we simulate a disconnect await worker1.stop() with pytest.raises(ValueError): - await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2")) + await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent()) except Exception as ex: raise ex @@ -354,7 +355,9 @@ async def get_subscribed_recipients() -> List[AgentId]: try: worker1.start() - await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1")) + await LoopbackAgentWithDefaultSubscription.register( + worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() + ) subscriptions1 = get_current_subscriptions() assert len(subscriptions1) == 1 @@ -363,7 +366,7 @@ async def get_subscribed_recipients() -> List[AgentId]: first_subscription_id = subscriptions1[0].id - await worker1.publish_message(MyMessage(content="Hello!"), DefaultTopicId()) + await worker1.publish_message(ContentMessage(content="Hello!"), DefaultTopicId()) # This is a simple simulation of worker disconnct if worker1._host_connection is not None: # type: ignore[reportPrivateUsage] try: @@ -380,7 +383,9 @@ async def get_subscribed_recipients() -> List[AgentId]: await asyncio.sleep(1) worker1_2.start() - await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1")) + await LoopbackAgentWithDefaultSubscription.register( + worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription() + ) subscriptions3 = get_current_subscriptions() assert len(subscriptions3) == 1 @@ -394,10 +399,70 @@ async def get_subscribed_recipients() -> List[AgentId]: finally: await worker1.stop() await worker1_2.stop() + + +@pytest.mark.asyncio +async def test_grpc_max_message_size() -> None: + default_max_size = 2**22 + new_max_size = default_max_size * 2 + small_message = ContentMessage(content="small message") + big_message = ContentMessage(content="." * (default_max_size + 1)) + + extra_grpc_config = [ + ("grpc.max_send_message_length", new_max_size), + ("grpc.max_receive_message_length", new_max_size), + ] + host_address = "localhost:50059" + host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config) + worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config) + worker2 = WorkerAgentRuntime(host_address=host_address) + worker3 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config) + + try: + host.start() + worker1.start() + worker2.start() + worker3.start() + await LoopbackAgentWithDefaultSubscription.register( + worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() + ) + await LoopbackAgentWithDefaultSubscription.register( + worker2, "worker2", lambda: LoopbackAgentWithDefaultSubscription() + ) + await LoopbackAgentWithDefaultSubscription.register( + worker3, "worker3", lambda: LoopbackAgentWithDefaultSubscription() + ) + + # with pytest.raises(Exception): + await worker1.publish_message(small_message, DefaultTopicId()) + # This is a simple simulation of worker disconnct + await asyncio.sleep(1) + agent_instance_2 = await worker2.try_get_underlying_agent_instance( + AgentId("worker2", key="default"), type=LoopbackAgent + ) + agent_instance_3 = await worker3.try_get_underlying_agent_instance( + AgentId("worker3", key="default"), type=LoopbackAgent + ) + assert agent_instance_2.num_calls == 1 + assert agent_instance_3.num_calls == 1 + + await worker1.publish_message(big_message, DefaultTopicId()) + await asyncio.sleep(2) + assert agent_instance_2.num_calls == 1 # Worker 2 won't receive the big message + assert agent_instance_3.num_calls == 2 # Worker 3 will receive the big message as has increased message length + except Exception as e: + raise e + finally: + await worker1.stop() + # await worker2.stop() # Worker 2 somehow breaks can can not be stopped. + await worker3.stop() + await host.stop() if __name__ == "__main__": os.environ["GRPC_VERBOSITY"] = "DEBUG" os.environ["GRPC_TRACE"] = "all" + asyncio.run(test_disconnected_agent()) + asyncio.run(test_grpc_max_message_size())