From f261c7ec9376e1f086a0027df392a4b0e4aa2321 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Mon, 2 Dec 2024 12:50:50 -0500 Subject: [PATCH] feat: Implement IChatClient wrapper for Anthropic --- dotnet/AutoGen.sln | 7 + .../AnthropicChatCompletionClient.cs | 280 ++++++++++++++++++ .../Extensions/Anthropic/AnthropicClient.cs | 215 ++++++++++++++ .../Converters/ContentBaseConverter.cs | 39 +++ .../JsonPropertyNameEnumCoverter.cs | 43 +++ .../Converters/SystemMessageConverter.cs | 41 +++ .../Anthropic/DTO/ChatCompletionRequest.cs | 93 ++++++ .../Anthropic/DTO/ChatCompletionResponse.cs | 96 ++++++ .../Extensions/Anthropic/DTO/Content.cs | 280 ++++++++++++++++++ .../Extensions/Anthropic/DTO/ErrorResponse.cs | 21 ++ .../Extensions/Anthropic/DTO/Tool.cs | 76 +++++ .../Extensions/Anthropic/DTO/ToolChoice.cs | 39 +++ ...rosoft.AutoGen.Extensions.Anthropic.csproj | 31 ++ .../Anthropic/Utils/AnthropicConstants.cs | 15 + 14 files changed, 1276 insertions(+) create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicClient.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/ContentBaseConverter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/JsonPropertyNameEnumCoverter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/SystemMessageConverter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionRequest.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionResponse.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Content.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ErrorResponse.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Tool.cs create mode 100644 
dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ToolChoice.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Microsoft.AutoGen.Extensions.Anthropic.csproj create mode 100644 dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Utils/AnthropicConstants.cs diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 5b26e27165b3..113988a23cab 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -130,6 +130,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AutoGen.Extensions.Aspire", "src\Microsoft.AutoGen\Extensions\Aspire\Microsoft.AutoGen.Extensions.Aspire.csproj", "{65059914-5527-4A00-9308-9FAF23D5E85A}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AutoGen.Extensions.Anthropic", "src\Microsoft.AutoGen\Extensions\Anthropic\Microsoft.AutoGen.Extensions.Anthropic.csproj", "{5ED47D4C-19D7-4684-B6F8-A4AA77F37E21}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AutoGen.Agents.Tests", "test\Microsoft.AutoGen.Agents.Tests\Microsoft.AutoGen.Agents.Tests.csproj", "{394FDAF8-74F9-4977-94A5-3371737EB774}" EndProject Global @@ -338,6 +340,10 @@ Global {65059914-5527-4A00-9308-9FAF23D5E85A}.Debug|Any CPU.Build.0 = Debug|Any CPU {65059914-5527-4A00-9308-9FAF23D5E85A}.Release|Any CPU.ActiveCfg = Release|Any CPU {65059914-5527-4A00-9308-9FAF23D5E85A}.Release|Any CPU.Build.0 = Release|Any CPU + {5ED47D4C-19D7-4684-B6F8-A4AA77F37E21}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5ED47D4C-19D7-4684-B6F8-A4AA77F37E21}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5ED47D4C-19D7-4684-B6F8-A4AA77F37E21}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5ED47D4C-19D7-4684-B6F8-A4AA77F37E21}.Release|Any CPU.Build.0 = Release|Any CPU {394FDAF8-74F9-4977-94A5-3371737EB774}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {394FDAF8-74F9-4977-94A5-3371737EB774}.Debug|Any CPU.Build.0 = Debug|Any CPU 
{394FDAF8-74F9-4977-94A5-3371737EB774}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -401,6 +407,7 @@ Global
 		{97550E87-48C6-4EBF-85E1-413ABAE9DBFD} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
 		{64EF61E7-00A6-4E5E-9808-62E10993A0E5} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
 		{65059914-5527-4A00-9308-9FAF23D5E85A} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+		{5ED47D4C-19D7-4684-B6F8-A4AA77F37E21} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
 		{394FDAF8-74F9-4977-94A5-3371737EB774} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
 	EndGlobalSection
 	GlobalSection(ExtensibilityGlobals) = postSolution
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs
new file mode 100644
index 000000000000..6248bd4462e0
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs
@@ -0,0 +1,280 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicChatCompletionClient.cs
+
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using Microsoft.AutoGen.Extensions.Anthropic.DTO;
+using Microsoft.Extensions.AI;
+using MEAI = Microsoft.Extensions.AI;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic;
+
+/// <summary>
+/// Default request values for Anthropic chat completions.
+/// </summary>
+public static class AnthropicChatCompletionDefaults
+{
+    //public const string DefaultSystemMessage = "You are a helpful AI assistant";
+    public const decimal DefaultTemperature = 0.7m;
+    public const int DefaultMaxTokens = 1024;
+}
+
+/// <summary>
+/// Exposes an <see cref="AnthropicClient"/> through the Microsoft.Extensions.AI <see cref="IChatClient"/>
+/// abstraction, translating requests and responses between the two object models.
+/// </summary>
+public sealed class AnthropicChatCompletionClient : IChatClient, IDisposable
+{
+    // Set to null by Dispose(); every use goes through EnsureClient().
+    private AnthropicClient? _anthropicClient;
+    private readonly string _modelId;
+
+    /// <summary>
+    /// Creates a chat client on top of a new <see cref="AnthropicClient"/> built from the given HTTP client.
+    /// </summary>
+    public AnthropicChatCompletionClient(HttpClient httpClient, string modelId, string baseUrl, string apiKey)
+        : this(new AnthropicClient(httpClient, baseUrl, apiKey), modelId)
+    {
+    }
+
+    /// <summary>
+    /// Creates a chat client that wraps (and takes ownership of) an existing <see cref="AnthropicClient"/>.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="client"/> is null.</exception>
+    /// <exception cref="ArgumentException">The client's base URL is not an absolute URI.</exception>
+    public AnthropicChatCompletionClient([NotNull] AnthropicClient client, string modelId)
+    {
+        ArgumentNullException.ThrowIfNull(client);
+
+        _anthropicClient = client;
+        _modelId = modelId;
+
+        if (!Uri.TryCreate(client.BaseUrl, UriKind.Absolute, out Uri? uri))
+        {
+            // technically we should never be able to get this far, in this case
+            throw new ArgumentException($"Invalid base URL in provided client: {client.BaseUrl}", nameof(client));
+        }
+
+        this.Metadata = new ChatClientMetadata("Anthropic", uri, modelId);
+    }
+
+    /// <inheritdoc />
+    public ChatClientMetadata Metadata { get; private set; }
+
+    /// <summary>
+    /// Converts one Microsoft.Extensions.AI message into the Anthropic DTO. When the message is a system
+    /// message and <paramref name="systemMessagesSink"/> is supplied, its text is also appended to the sink.
+    /// </summary>
+    /// <exception cref="ArgumentException">A system message does not consist of exactly one text content.</exception>
+    private DTO.ChatMessage Translate(MEAI.ChatMessage message, List<SystemMessage>? systemMessagesSink = null)
+    {
+        if (message.Role == ChatRole.System && systemMessagesSink != null)
+        {
+            if (message.Contents.Count != 1 || message.Text == null)
+            {
+                string actualTypes = string.Join(',', message.Contents.Select(content => content.GetType()));
+                throw new ArgumentException(
+                    $"Invalid SystemMessage: May only contain a single Text AIContent. Actual: {actualTypes}",
+                    nameof(message));
+            }
+
+            systemMessagesSink.Add(SystemMessage.CreateSystemMessage(message.Text));
+        }
+
+        List<DTO.ContentBase> contents = new(message.Contents.Select(rawContent => (DTO.ContentBase)rawContent));
+        return new DTO.ChatMessage(message.Role.ToString().ToLowerInvariant(), contents);
+    }
+
+    /// <summary>
+    /// Builds the Anthropic request for the given conversation and options.
+    /// </summary>
+    /// <exception cref="ArgumentException"><paramref name="options"/> does not specify MaxOutputTokens.</exception>
+    private ChatCompletionRequest CreateRequest(IList<MEAI.ChatMessage> chatMessages, ChatOptions? options, bool requestStream)
+    {
+        ToolChoice? toolChoice = null;
+        ChatToolMode? mode = options?.ToolMode;
+
+        if (mode is AutoChatToolMode)
+        {
+            toolChoice = ToolChoice.Auto;
+        }
+        else if (mode is RequiredChatToolMode requiredToolMode)
+        {
+            if (requiredToolMode.RequiredFunctionName == null)
+            {
+                toolChoice = ToolChoice.Any;
+            }
+            else
+            {
+                toolChoice = ToolChoice.ToolUse(requiredToolMode.RequiredFunctionName!);
+            }
+        }
+
+        List<SystemMessage> systemMessages = new();
+        List<DTO.ChatMessage> translatedMessages = new();
+
+        foreach (MEAI.ChatMessage message in chatMessages)
+        {
+            if (message.Role == ChatRole.System)
+            {
+                // System messages travel in the request's "system" field, not in "messages".
+                Translate(message, systemMessages);
+
+                // TODO: Should the system messages be included in the translatedMessages list?
+            }
+            else
+            {
+                translatedMessages.Add(Translate(message));
+            }
+        }
+
+        return new ChatCompletionRequest
+        {
+            Model = _modelId,
+
+            // TODO: We should consider coming up with a reasonable default for MaxTokens, since the MAAi APIs do not require
+            // it, while our wrapper for the Anthropic API does.
+            MaxTokens = options?.MaxOutputTokens ?? throw new ArgumentException("Must specify number of tokens in request for Anthropic", nameof(options)),
+            StopSequences = options?.StopSequences?.ToArray(),
+            Stream = requestStream,
+            Temperature = (decimal?)options?.Temperature, // TODO: why `decimal`?!
+            ToolChoice = toolChoice,
+
+            // Tools is optional on ChatOptions; a LINQ query over a null source would throw, so leave the
+            // request's tool list null (and therefore unserialized) when no tools were supplied.
+            Tools = options?.Tools?.OfType<AIFunction>().Select(function => (Tool)function).ToList(),
+            TopK = options?.TopK,
+            TopP = (decimal?)options?.TopP,
+            SystemMessage = systemMessages.ToArray(),
+            Messages = translatedMessages,
+
+            // TODO: put these somewhere? .Metadata?
+            //ModelId = _modelId,
+            //Options = options
+        };
+    }
+
+    /// <summary>
+    /// Remembers the first completion id, model id, role and finish reason seen across one or more responses.
+    /// </summary>
+    private sealed class ChatCompletionAccumulator
+    {
+        public string? CompletionId { get; set; }
+        public string? ModelId { get; set; }
+        public MEAI.ChatRole? StreamingRole { get; set; }
+        public MEAI.ChatFinishReason? FinishReason { get; set; }
+        // public DateTimeOffset CreatedAt { get; set; }
+
+        public ChatCompletionAccumulator() { }
+
+        /// <summary>
+        /// Folds <paramref name="response"/> into the accumulated state and returns its content converted
+        /// to Microsoft.Extensions.AI content, or null when the response carries no content list.
+        /// </summary>
+        /// <exception cref="InvalidOperationException">The response reports a role other than 'assistant'.</exception>
+        public List<MEAI.AIContent>? AccumulateAndExtractContent(ChatCompletionResponse response)
+        {
+            // Non-streaming responses carry id/model/role/stop_reason at the top level. Streaming events map
+            // them to the nested "message" (StreamingMessage) and "delta" (Delta) DTO members instead, so
+            // fall back to those; otherwise every streaming update would fail the role check below.
+            this.CompletionId ??= response.Id ?? response.StreamingMessage?.Id;
+            this.ModelId ??= response.Model ?? response.StreamingMessage?.Model;
+
+            this.FinishReason ??= (response.StopReason ?? response.Delta?.StopReason) switch
+            {
+                "end_turn" => MEAI.ChatFinishReason.Stop,
+                "stop_sequence" => MEAI.ChatFinishReason.Stop,
+                "tool_use" => MEAI.ChatFinishReason.ToolCalls,
+                "max_tokens" => MEAI.ChatFinishReason.Length,
+                _ => null
+            };
+
+            this.StreamingRole ??= (response.Role ?? response.StreamingMessage?.Role) switch
+            {
+                "assistant" => MEAI.ChatRole.Assistant,
+                // Events that carry no role at all leave the role undecided until a later event supplies it.
+                null => null,
+                _ => throw new InvalidOperationException("Anthropic API is defined to only reply with 'assistant'.")
+            };
+
+            if (response.Content == null)
+            {
+                return null;
+            }
+
+            return new(response.Content.Select(rawContent => (MEAI.AIContent)rawContent));
+        }
+    }
+
+    /// <summary>
+    /// Converts a complete (non-streaming) Anthropic response into a Microsoft.Extensions.AI completion.
+    /// </summary>
+    private MEAI.ChatCompletion TranslateCompletion(ChatCompletionResponse response)
+    {
+        ChatCompletionAccumulator accumulator = new();
+        List<MEAI.AIContent>? messageContents = accumulator.AccumulateAndExtractContent(response);
+
+        // According to the Anthropic API docs, the response will contain a single "option" in the MEAI
+        // parlance, but may contain multiple message? (I suspect for the purposes of tool use)
+        if (messageContents == null)
+        {
+            throw new ArgumentNullException(nameof(response.Content));
+        }
+        else if (messageContents.Count == 0)
+        {
+            throw new ArgumentException("Response did not contain any content", nameof(response));
+        }
+
+        MEAI.ChatMessage message = new(ChatRole.Assistant, messageContents);
+
+        return new MEAI.ChatCompletion(message)
+        {
+            CompletionId = accumulator.CompletionId,
+            ModelId = accumulator.ModelId,
+            //CreatedAt = TODO:
+            FinishReason = accumulator.FinishReason,
+            // Usage = TODO: extract this from the DTO
+            RawRepresentation = response
+            // WIP
+        };
+    }
+
+    /// <summary>
+    /// Converts one streaming event into a Microsoft.Extensions.AI streaming update, using
+    /// <paramref name="accumulator"/> to carry ids, role and finish reason across events.
+    /// </summary>
+    private MEAI.StreamingChatCompletionUpdate TranslateStreamingUpdate(ChatCompletionAccumulator accumulator, ChatCompletionResponse response)
+    {
+        List<MEAI.AIContent>? messageContents = accumulator.AccumulateAndExtractContent(response);
+
+        // messageContents will be non-null only on the final "tool_call" stop message update, which will contain
+        // all of the accumulated ToolUseContent objects.
+        if (messageContents == null && response.Delta is { Type: "text_delta" } textDelta)
+        {
+            messageContents = new List<MEAI.AIContent> { new MEAI.TextContent(textDelta.Text) };
+        }
+
+        return new MEAI.StreamingChatCompletionUpdate
+        {
+            Role = accumulator.StreamingRole,
+            CompletionId = accumulator.CompletionId,
+            ModelId = accumulator.ModelId,
+            //CreatedAt = TODO:
+            FinishReason = accumulator.FinishReason,
+            //ChoiceIndex = 0,
+            Contents = messageContents,
+            RawRepresentation = response
+        };
+    }
+
+    /// <inheritdoc />
+    public async Task<MEAI.ChatCompletion> CompleteAsync(IList<MEAI.ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+    {
+        ChatCompletionRequest request = CreateRequest(chatMessages, options, requestStream: false);
+        ChatCompletionResponse response = await this.EnsureClient().CreateChatCompletionsAsync(request, cancellationToken);
+
+        return TranslateCompletion(response);
+    }
+
+    /// <summary>Returns the wrapped client, or throws if this instance has been disposed.</summary>
+    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
+    private AnthropicClient EnsureClient()
+    {
+        return this._anthropicClient ?? throw new ObjectDisposedException(nameof(AnthropicChatCompletionClient));
+    }
+
+    /// <inheritdoc />
+    public async IAsyncEnumerable<MEAI.StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<MEAI.ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        ChatCompletionRequest request = CreateRequest(chatMessages, options, requestStream: true);
+        IAsyncEnumerable<ChatCompletionResponse> responseStream = this.EnsureClient().StreamingChatCompletionsAsync(request, cancellationToken);
+
+        // TODO: There is likely a better way to do this
+        ChatCompletionAccumulator accumulator = new();
+        await foreach (ChatCompletionResponse update in responseStream)
+        {
+            yield return TranslateStreamingUpdate(accumulator, update);
+        }
+    }
+
+    /// <summary>Disposes the wrapped <see cref="AnthropicClient"/>; safe to call more than once.</summary>
+    public void Dispose()
+    {
+        // Exchange-then-dispose guarantees the wrapped client is disposed at most once.
+        Interlocked.Exchange(ref this._anthropicClient, null)?.Dispose();
+    }
+
+    /// <inheritdoc />
+    public TService? GetService<TService>(object? key = null) where TService : class
+    {
+        // Implement this based on the example in the M.E.AI.OpenAI implementation
+        // see: https://github.com/dotnet/extensions/blob/main/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs#L95-L105
+
+        if (key != null)
+        {
+            return null;
+        }
+
+        if (this is TService result)
+        {
+            return result;
+        }
+
+        if (typeof(TService).IsAssignableFrom(typeof(AnthropicClient)))
+        {
+            // The field is null after Dispose(); returning null is within this method's contract.
+            return (TService?)(object?)this._anthropicClient;
+        }
+
+        return null;
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicClient.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicClient.cs
new file mode 100644
index 000000000000..07388d6820eb
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicClient.cs
@@ -0,0 +1,215 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClient.cs
+
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+using Microsoft.AutoGen.Extensions.Anthropic.Converters;
+using Microsoft.AutoGen.Extensions.Anthropic.DTO;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic;
+
+/// <summary>
+/// Minimal HTTP client for the Anthropic messages endpoint, supporting both complete and
+/// server-sent-event (streaming) responses.
+/// </summary>
+public sealed class AnthropicClient : IDisposable
+{
+    private readonly HttpClient _httpClient;
+    private readonly string _baseUrl;
+
+    private static readonly JsonSerializerOptions JsonSerializerOptions = new()
+    {
+        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+        Converters =
+        {
+            new ContentBaseConverter(),
+            // NOTE(review): the enum type arguments of these two converters were lost from the patch text;
+            // ToolChoiceType / CacheControlType are the presumed originals - confirm against Tool*.cs.
+            new JsonPropertyNameEnumConverter<ToolChoiceType>(),
+            new JsonPropertyNameEnumConverter<CacheControlType>(),
+            new SystemMessageConverter(),
+        }
+    };
+
+    /// <summary>
+    /// Creates a client that posts to <paramref name="baseUrl"/> using <paramref name="httpClient"/>.
+    /// The API key and API version are installed as default headers on the supplied HTTP client.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="httpClient"/> is null.</exception>
+    public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
+    {
+        ArgumentNullException.ThrowIfNull(httpClient);
+
+        _httpClient = httpClient;
+        _baseUrl = baseUrl;
+
+        // Replace instead of append: adding the same header twice (e.g. when one HttpClient is handed to
+        // several AnthropicClient instances) would send duplicated header values.
+        SetDefaultHeader("x-api-key", apiKey);
+        SetDefaultHeader("anthropic-version", "2023-06-01");
+    }
+
+    internal string BaseUrl => _baseUrl;
+
+    /// <summary>
+    /// Sends a non-streaming request and returns the deserialized completion.
+    /// </summary>
+    /// <exception cref="HttpRequestException">The API answered with a non-success status code.</exception>
+    public async Task<ChatCompletionResponse> CreateChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest,
+        CancellationToken cancellationToken)
+    {
+        using HttpResponseMessage httpResponseMessage = await SendRequestAsync(
+            chatCompletionRequest, HttpCompletionOption.ResponseContentRead, cancellationToken);
+        Stream responseStream = await httpResponseMessage.Content.ReadAsStreamAsync(cancellationToken);
+
+        if (httpResponseMessage.IsSuccessStatusCode)
+        {
+            return await DeserializeResponseAsync<ChatCompletionResponse>(responseStream, cancellationToken);
+        }
+
+        ErrorResponse res = await DeserializeResponseAsync<ErrorResponse>(responseStream, cancellationToken);
+        throw new HttpRequestException(res.Error?.Message);
+    }
+
+    /// <summary>
+    /// Sends a streaming request and yields one <see cref="ChatCompletionResponse"/> per
+    /// message_start, content_block_delta and message_delta event. Tool-use blocks are accumulated
+    /// internally and attached to the event whose delta reports the "tool_use" stop reason.
+    /// </summary>
+    /// <exception cref="HttpRequestException">
+    /// The API answered with a non-success status code, or sent an "error" event.
+    /// </exception>
+    public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAsync(
+        ChatCompletionRequest chatCompletionRequest, [EnumeratorCancellation] CancellationToken cancellationToken)
+    {
+        // ResponseHeadersRead lets events be consumed as they arrive; the default would buffer the entire
+        // response body before SendAsync completes.
+        using HttpResponseMessage httpResponseMessage = await SendRequestAsync(
+            chatCompletionRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
+        Stream responseStream = await httpResponseMessage.Content.ReadAsStreamAsync(cancellationToken);
+
+        if (!httpResponseMessage.IsSuccessStatusCode)
+        {
+            // A failed request carries a JSON error document rather than an event stream; surface it
+            // instead of silently yielding nothing.
+            ErrorResponse failure = await DeserializeResponseAsync<ErrorResponse>(responseStream, cancellationToken);
+            throw new HttpRequestException(failure.Error?.Message);
+        }
+
+        using var reader = new StreamReader(responseStream);
+
+        string? eventType = null;
+        string? data = null;
+
+        // tool_use blocks opened so far in this message; the last entry is the one receiving deltas.
+        List<ContentBlock> pendingToolBlocks = new();
+
+        while (await reader.ReadLineAsync() is { } line)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
+            if (line.Length > 0)
+            {
+                if (line.StartsWith("event:", StringComparison.Ordinal))
+                {
+                    eventType = line.Substring("event:".Length).Trim();
+                }
+                else if (line.StartsWith("data:", StringComparison.Ordinal))
+                {
+                    data = line.Substring("data:".Length).Trim();
+                }
+
+                continue;
+            }
+
+            // an empty line indicates the end of an event
+            if (eventType == "content_block_start" && !string.IsNullOrEmpty(data))
+            {
+                // Event payloads are deserialized without the shared JsonSerializerOptions, as before.
+                // NOTE(review): those options register SystemMessageConverter for `object`, which
+                // presumably would reject the object-valued "input" of a tool_use block - confirm
+                // before switching to the shared options.
+                DataBlock? dataBlock = JsonSerializer.Deserialize<DataBlock>(data);
+                if (dataBlock?.ContentBlock is { Type: "tool_use" } toolBlock)
+                {
+                    pendingToolBlocks.Add(toolBlock);
+                }
+
+                // Text blocks are not surfaced from content_block_start; their text arrives through
+                // content_block_delta events.
+                // TODO: verify we never get a non-empty text start content block
+            }
+            else if ((eventType is "message_start" or "content_block_delta" or "message_delta") && data != null)
+            {
+                ChatCompletionResponse res = JsonSerializer.Deserialize<ChatCompletionResponse>(data)
+                    ?? throw new JsonException("Failed to deserialize response");
+
+                if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
+                    pendingToolBlocks.Count > 0)
+                {
+                    pendingToolBlocks[^1].AppendDeltaParameters(res.Delta.PartialJson!);
+                }
+                else if (res.Delta is { StopReason: "tool_use" } && pendingToolBlocks.Count > 0)
+                {
+                    // Hand over every accumulated tool call, not just the most recent one.
+                    res.Content ??= new List<ContentBase>();
+                    foreach (ContentBlock block in pendingToolBlocks)
+                    {
+                        res.Content.Add(block.CreateToolUseContent());
+                    }
+
+                    pendingToolBlocks.Clear();
+                }
+
+                yield return res;
+            }
+            else if (eventType == "error" && data != null)
+            {
+                ErrorResponse? error = JsonSerializer.Deserialize<ErrorResponse>(data);
+                throw new HttpRequestException(error?.Error?.Message);
+            }
+
+            // Start the next event from a clean slate so stale fields are never dispatched twice.
+            eventType = null;
+            data = null;
+        }
+    }
+
+    /// <summary>Sets a default request header on the HTTP client, replacing any existing value.</summary>
+    private void SetDefaultHeader(string name, string value)
+    {
+        _httpClient.DefaultRequestHeaders.Remove(name);
+        _httpClient.DefaultRequestHeaders.Add(name, value);
+    }
+
+    /// <summary>
+    /// Serializes <paramref name="requestObject"/> as JSON and posts it to the base URL.
+    /// The caller owns (and must dispose) the returned response.
+    /// </summary>
+    private async Task<HttpResponseMessage> SendRequestAsync<T>(
+        T requestObject, HttpCompletionOption completionOption, CancellationToken cancellationToken)
+    {
+        using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _baseUrl);
+        string jsonRequest = JsonSerializer.Serialize(requestObject, JsonSerializerOptions);
+        httpRequestMessage.Content = new StringContent(jsonRequest, Encoding.UTF8, "application/json");
+        httpRequestMessage.Headers.Add("anthropic-beta", "prompt-caching-2024-07-31");
+        return await _httpClient.SendAsync(httpRequestMessage, completionOption, cancellationToken);
+    }
+
+    /// <summary>Deserializes a response body with the shared serializer options.</summary>
+    /// <exception cref="JsonException">The body deserialized to null.</exception>
+    private async Task<T> DeserializeResponseAsync<T>(Stream responseStream, CancellationToken cancellationToken)
+    {
+        return await JsonSerializer.DeserializeAsync<T>(responseStream, JsonSerializerOptions, cancellationToken)
+            ?? throw new JsonException("Failed to deserialize response");
+    }
+
+    /// <summary>Disposes the underlying <see cref="HttpClient"/>.</summary>
+    public void Dispose()
+    {
+        _httpClient.Dispose();
+    }
+
+    /// <summary>
+    /// The "content_block" payload of a content_block_start event, plus the tool-call JSON that is
+    /// accumulated from subsequent input_json_delta events.
+    /// </summary>
+    private sealed class ContentBlock
+    {
+        [JsonPropertyName("type")]
+        public string? Type { get; set; }
+
+        [JsonPropertyName("id")]
+        public string? Id { get; set; }
+
+        [JsonPropertyName("name")]
+        public string? Name { get; set; }
+
+        [JsonPropertyName("input")]
+        public object? Input { get; set; }
+
+        [JsonPropertyName("parameters")]
+        public string? Parameters { get; set; }
+
+        [JsonPropertyName("text")]
+        public string? Text { get; set; }
+
+        /// <summary>Appends one partial-JSON fragment to <see cref="Parameters"/>.</summary>
+        public void AppendDeltaParameters(string deltaParams)
+        {
+            // Concatenating onto a null string yields just the fragment, matching the first-call case.
+            Parameters += deltaParams;
+        }
+
+        /// <summary>Creates the DTO for this tool call from the id, name and accumulated parameters.</summary>
+        public ToolUseContent CreateToolUseContent()
+        {
+            // NOTE(review): Parameters is raw JSON text assigned to ToolUseContent.Input as-is, not parsed;
+            // confirm that consumers of Input expect that.
+            return new ToolUseContent { Id = Id, Name = Name, Input = Parameters };
+        }
+    }
+
+    /// <summary>Envelope of a content_block_start event.</summary>
+    private sealed class DataBlock
+    {
+        [JsonPropertyName("content_block")]
+        public ContentBlock? ContentBlock { get; set; }
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/ContentBaseConverter.cs
new file mode 100644
index 000000000000..024c0f5fbbe4
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/ContentBaseConverter.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ContentBaseConverter.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Microsoft.AutoGen.Extensions.Anthropic.DTO;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic.Converters;
+
+/// <summary>
+/// Polymorphic converter for <see cref="ContentBase"/>: picks the concrete content class from the
+/// JSON "type" discriminator when reading, and writes values using their runtime type.
+/// </summary>
+public sealed class ContentBaseConverter : JsonConverter<ContentBase>
+{
+    /// <summary>Reads one content object, dispatching on its "type" property.</summary>
+    /// <exception cref="JsonException">
+    /// The value is not an object, has no string "type" property, or the type is not recognized.
+    /// </exception>
+    public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    {
+        using var doc = JsonDocument.ParseValue(ref reader);
+        JsonElement root = doc.RootElement;
+
+        // Check the value kinds first: TryGetProperty and GetString throw InvalidOperationException on
+        // the wrong kind, which would escape as a non-JSON error for malformed input.
+        if (root.ValueKind == JsonValueKind.Object &&
+            root.TryGetProperty("type", out JsonElement typeProperty) &&
+            typeProperty.ValueKind == JsonValueKind.String)
+        {
+            // Deserialize straight from the parsed element; the concrete types have no converter of
+            // their own, so this does not recurse back into this converter.
+            ContentBase? content = typeProperty.GetString() switch
+            {
+                "text" => root.Deserialize<TextContent>(options),
+                "image" => root.Deserialize<ImageContent>(options),
+                "tool_use" => root.Deserialize<ToolUseContent>(options),
+                "tool_result" => root.Deserialize<ToolResultContent>(options),
+                _ => throw new JsonException("Unknown content type"),
+            };
+
+            return content ?? throw new InvalidOperationException();
+        }
+
+        throw new JsonException("Unknown content type");
+    }
+
+    /// <summary>Writes <paramref name="value"/> using its runtime type so derived members are included.</summary>
+    public override void Write(Utf8JsonWriter writer, ContentBase value, JsonSerializerOptions options)
+    {
+        JsonSerializer.Serialize(writer, value, value.GetType(), options);
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/JsonPropertyNameEnumCoverter.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
new file mode 100644
index 000000000000..f3412643ffc0
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// JsonPropertyNameEnumCoverter.cs + +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.AutoGen.Extensions.Anthropic.Converters; + +internal sealed class JsonPropertyNameEnumConverter : JsonConverter where T : struct, Enum +{ + public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + string value = reader.GetString() ?? throw new JsonException("Value was null."); + + foreach (var field in typeToConvert.GetFields()) + { + var attribute = field.GetCustomAttribute(); + if (attribute?.Name == value) + { + return (T)Enum.Parse(typeToConvert, field.Name); + } + } + + throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}."); + } + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + var field = value.GetType().GetField(value.ToString()); + var attribute = field?.GetCustomAttribute(); + + if (attribute != null) + { + writer.WriteStringValue(attribute.Name); + } + else + { + writer.WriteStringValue(value.ToString()); + } + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/SystemMessageConverter.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/SystemMessageConverter.cs new file mode 100644 index 000000000000..1e7414055d17 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Converters/SystemMessageConverter.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SystemMessageConverter.cs + +using System.Text.Json; +using System.Text.Json.Serialization; +using AutoGen.Anthropic.DTO; + +namespace Microsoft.AutoGen.Extensions.Anthropic.Converters; + +public class SystemMessageConverter : JsonConverter +{ + public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.String) + { + return reader.GetString() ?? 
string.Empty; + } + if (reader.TokenType == JsonTokenType.StartArray) + { + return JsonSerializer.Deserialize(ref reader, options) ?? throw new InvalidOperationException(); + } + + throw new JsonException(); + } + + public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) + { + if (value is string stringValue) + { + writer.WriteStringValue(stringValue); + } + else if (value is SystemMessage[] arrayValue) + { + JsonSerializer.Serialize(writer, arrayValue, options); + } + else + { + throw new JsonException(); + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionRequest.cs new file mode 100644 index 000000000000..82ade7def8fa --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionRequest.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionRequest.cs + +using System.Text.Json.Serialization; + +namespace Microsoft.AutoGen.Extensions.Anthropic.DTO; + +public class ChatCompletionRequest +{ + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("messages")] + public List Messages { get; set; } + + [JsonPropertyName("system")] + public SystemMessage[]? SystemMessage { get; set; } + + [JsonPropertyName("max_tokens")] + public int MaxTokens { get; set; } + + [JsonPropertyName("metadata")] + public object? Metadata { get; set; } + + [JsonPropertyName("stop_sequences")] + public string[]? StopSequences { get; set; } + + [JsonPropertyName("stream")] + public bool? Stream { get; set; } + + [JsonPropertyName("temperature")] + public decimal? Temperature { get; set; } + + [JsonPropertyName("top_k")] + public int? TopK { get; set; } + + [JsonPropertyName("top_p")] + public decimal? TopP { get; set; } + + [JsonPropertyName("tools")] + public List? 
Tools { get; set; } + + [JsonPropertyName("tool_choice")] + public ToolChoice? ToolChoice { get; set; } + + public ChatCompletionRequest() + { + Messages = new List(); + } +} + +public class SystemMessage +{ + [JsonPropertyName("text")] + public string? Text { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; private set; } = "text"; + + [JsonPropertyName("cache_control")] + public CacheControl? CacheControl { get; set; } + + public static SystemMessage CreateSystemMessage(string systemMessage) => new() { Text = systemMessage }; + + public static SystemMessage CreateSystemMessageWithCacheControl(string systemMessage) => new() + { + Text = systemMessage, + CacheControl = new CacheControl { Type = CacheControlType.Ephemeral } + }; +} + +public class ChatMessage +{ + [JsonPropertyName("role")] + public string Role { get; set; } + + [JsonPropertyName("content")] + public List Content { get; set; } + + public ChatMessage(string role, string content) + { + Role = role; + Content = new List() { new TextContent { Text = content } }; + } + + public ChatMessage(string role, List content) + { + Role = role; + Content = content; + } + + public void AddContent(ContentBase content) => Content.Add(content); +} diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionResponse.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionResponse.cs new file mode 100644 index 000000000000..5ca0212fd692 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ChatCompletionResponse.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionResponse.cs + +using System.Text.Json.Serialization; + +namespace Microsoft.AutoGen.Extensions.Anthropic.DTO; + +public class ChatCompletionResponse +{ + [JsonPropertyName("content")] + public List? Content { get; set; } + + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("model")] + public string? 
Model { get; set; } + + [JsonPropertyName("role")] + public string? Role { get; set; } + + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } + + [JsonPropertyName("stop_sequence")] + public object? StopSequence { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("usage")] + public Usage? Usage { get; set; } + + [JsonPropertyName("delta")] + public Delta? Delta { get; set; } + + [JsonPropertyName("message")] + public StreamingMessage? StreamingMessage { get; set; } +} + +public class StreamingMessage +{ + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("role")] + public string? Role { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("stop_reason")] + public object? StopReason { get; set; } + + [JsonPropertyName("stop_sequence")] + public object? StopSequence { get; set; } + + [JsonPropertyName("usage")] + public Usage? Usage { get; set; } +} + +public class Usage +{ + [JsonPropertyName("input_tokens")] + public int InputTokens { get; set; } + + [JsonPropertyName("output_tokens")] + public int OutputTokens { get; set; } + + [JsonPropertyName("cache_creation_input_tokens")] + public int CacheCreationInputTokens { get; set; } + + [JsonPropertyName("cache_read_input_tokens")] + public int CacheReadInputTokens { get; set; } +} + +public class Delta +{ + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("text")] + public string? Text { get; set; } + + [JsonPropertyName("partial_json")] + public string? PartialJson { get; set; } + + [JsonPropertyName("usage")] + public Usage? 
Usage { get; set; } +} diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Content.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Content.cs new file mode 100644 index 000000000000..6e3403cb2c9f --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Content.cs @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Content.cs + +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using Microsoft.AutoGen.Extensions.Anthropic.Converters; + +namespace Microsoft.AutoGen.Extensions.Anthropic.DTO; + +public static class AIContentExtensions +{ + public static string? ToBase64String(this ReadOnlyMemory? data) + { + if (data == null) + { + return null; + } + + // TODO: reduce the numbers of copies here ( .ToArray() + .ToBase64String() ) + return Convert.ToBase64String(data.Value.ToArray()); + } + + public static byte[]? FromBase64String(this string? data) + { + if (data == null) + { + return null; + } + + return Convert.FromBase64String(data); + } +} + +public abstract class ContentBase +{ + [JsonPropertyName("type")] + public abstract string Type { get; } + + [JsonPropertyName("cache_control")] + public CacheControl? 
CacheControl { get; set; }
+
+    /// <summary>
+    /// Converts a Microsoft.Extensions.AI content item into the matching DTO by
+    /// dispatching on its runtime type (text, image, function call, function result).
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="content"/> is null.</exception>
+    /// <exception cref="NotSupportedException">The runtime type has no DTO counterpart.</exception>
+    public static implicit operator ContentBase(Microsoft.Extensions.AI.AIContent content)
+    {
+        // Without this guard a null operand reaches content.GetType() in the fallback arm
+        // and surfaces as a NullReferenceException.
+        ArgumentNullException.ThrowIfNull(content);
+
+        return content switch
+        {
+            Microsoft.Extensions.AI.TextContent textContent => (TextContent)textContent,
+            Microsoft.Extensions.AI.ImageContent imageContent => (ImageContent)imageContent,
+            Microsoft.Extensions.AI.FunctionCallContent functionCallContent => (ToolUseContent)functionCallContent,
+            Microsoft.Extensions.AI.FunctionResultContent functionResultContent => (ToolResultContent)functionResultContent,
+            _ => throw new NotSupportedException($"Unsupported content type: {content.GetType()}")
+        };
+    }
+
+    /// <summary>
+    /// Converts a DTO back into the matching Microsoft.Extensions.AI content item by
+    /// dispatching on its runtime type.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="content"/> is null.</exception>
+    /// <exception cref="NotSupportedException">The runtime type has no counterpart.</exception>
+    public static implicit operator Microsoft.Extensions.AI.AIContent(ContentBase content)
+    {
+        ArgumentNullException.ThrowIfNull(content);
+
+        return content switch
+        {
+            TextContent textContent => (Microsoft.Extensions.AI.TextContent)textContent,
+            ImageContent imageContent => (Microsoft.Extensions.AI.ImageContent)imageContent,
+            ToolUseContent toolUseContent => (Microsoft.Extensions.AI.FunctionCallContent)toolUseContent,
+            ToolResultContent toolResultContent => (Microsoft.Extensions.AI.FunctionResultContent)toolResultContent,
+            _ => throw new NotSupportedException($"Unsupported content type: {content.GetType()}")
+        };
+    }
+}
+
+/// <summary>
+/// Content block whose <c>type</c> discriminator is <c>"text"</c>.
+/// </summary>
+public class TextContent : ContentBase
+{
+    [JsonPropertyName("type")]
+    public override string Type => "text";
+
+    [JsonPropertyName("text")]
+    public string?
Text { get; set; }
+
+    /// <summary>
+    /// Creates a text block carrying <paramref name="text"/> with an ephemeral cache-control marker.
+    /// </summary>
+    public static TextContent CreateTextWithCacheControl(string text) => new()
+    {
+        Text = text,
+        // Reuse the shared factory rather than spelling out the ephemeral marker again.
+        CacheControl = CacheControl.Create()
+    };
+
+    /// <summary>Copies the text of a Microsoft.Extensions.AI text item into a new DTO.</summary>
+    public static implicit operator TextContent(Microsoft.Extensions.AI.TextContent textContent)
+    {
+        return new TextContent { Text = textContent.Text };
+    }
+
+    /// <summary>
+    /// Wraps the DTO text in a Microsoft.Extensions.AI text item and keeps the DTO
+    /// reachable through <c>RawRepresentation</c>.
+    /// </summary>
+    public static implicit operator Microsoft.Extensions.AI.TextContent(TextContent textContent)
+    {
+        return new Microsoft.Extensions.AI.TextContent(textContent.Text)
+        {
+            RawRepresentation = textContent
+        };
+    }
+}
+
+/// <summary>
+/// Content block whose <c>type</c> discriminator is <c>"image"</c>.
+/// </summary>
+public class ImageContent : ContentBase
+{
+    [JsonPropertyName("type")]
+    public override string Type => "image";
+
+    [JsonPropertyName("source")]
+    public ImageSource? Source { get; set; }
+
+    /// <summary>
+    /// Builds an image block from a Microsoft.Extensions.AI image item, base64-encoding
+    /// its bytes when it holds in-memory data.
+    /// </summary>
+    public static implicit operator ImageContent(Microsoft.Extensions.AI.ImageContent imageContent)
+    {
+        ImageSource source = new ImageSource
+        {
+            MediaType = imageContent.MediaType,
+        };
+
+        // NOTE(review): when ContainsData is false, Data stays null although ImageSource.Type
+        // still reports "base64"; confirm how such images should be sent.
+        if (imageContent.ContainsData)
+        {
+            source.Data = imageContent.Data.ToBase64String();
+        }
+
+        return new ImageContent
+        {
+            Source = source,
+        };
+    }
+
+    /// <summary>
+    /// Decodes the base64 payload (an empty buffer when there is none) into a
+    /// Microsoft.Extensions.AI image item and keeps the DTO in <c>RawRepresentation</c>.
+    /// </summary>
+    /// <exception cref="FormatException">The stored data is not valid base64.</exception>
+    public static implicit operator Microsoft.Extensions.AI.ImageContent(ImageContent imageContent)
+    {
+        ReadOnlyMemory<byte> imageData = imageContent.Source?.Data.FromBase64String() ?? [];
+
+        return new Microsoft.Extensions.AI.ImageContent(imageData, mediaType: imageContent.Source?.MediaType)
+        {
+            RawRepresentation = imageContent
+        };
+    }
+}
+
+/// <summary>
+/// Image payload descriptor. <see cref="Type"/> is fixed to <c>"base64"</c>.
+/// </summary>
+public class ImageSource
+{
+    [JsonPropertyName("type")]
+    public string Type => "base64";
+
+    [JsonPropertyName("media_type")]
+    public string? MediaType { get; set; }
+
+    [JsonPropertyName("data")]
+    public string? Data { get; set; }
+}
+
+/// <summary>
+/// Content block whose <c>type</c> discriminator is <c>"tool_use"</c>.
+/// </summary>
+public class ToolUseContent : ContentBase
+{
+    [JsonPropertyName("type")]
+    public override string Type => "tool_use";
+
+    [JsonPropertyName("id")]
+    public string? Id { get; set; }
+
+    [JsonPropertyName("name")]
+    public string? Name { get; set; }
+
+    [JsonPropertyName("input")]
+    public JsonNode?
Input { get; set; }
+
+    /// <summary>
+    /// Builds a tool-use block from a function call, serializing its arguments (if any) to JSON.
+    /// </summary>
+    public static implicit operator ToolUseContent(Microsoft.Extensions.AI.FunctionCallContent functionCallContent)
+    {
+        JsonNode? input = functionCallContent.Arguments != null ? JsonSerializer.SerializeToNode(functionCallContent.Arguments) : null;
+
+        return new ToolUseContent
+        {
+            Id = functionCallContent.CallId,
+            Name = functionCallContent.Name,
+            Input = input
+        };
+    }
+
+    /// <summary>
+    /// Converts a tool-use block into a function call and keeps the block in <c>RawRepresentation</c>.
+    /// </summary>
+    /// <exception cref="ArgumentNullException">The block has no <c>Id</c> or no <c>Name</c>.</exception>
+    public static implicit operator Microsoft.Extensions.AI.FunctionCallContent(ToolUseContent toolUseContent)
+    {
+        // These are an unfortunate incompatibility between the two libraries (for now); later we can work to
+        // parse the JSON directly into the M.E.AI types
+        if (toolUseContent.Id == null)
+        {
+            throw new ArgumentNullException(nameof(toolUseContent.Id));
+        }
+
+        if (toolUseContent.Name == null)
+        {
+            throw new ArgumentNullException(nameof(toolUseContent.Name));
+        }
+
+        // The JSON input is materialized as a name/value dictionary for the function call.
+        IDictionary<string, object?>? arguments = null;
+        if (toolUseContent.Input != null)
+        {
+            arguments = JsonSerializer.Deserialize<IDictionary<string, object?>>(toolUseContent.Input);
+        }
+
+        return new Microsoft.Extensions.AI.FunctionCallContent(toolUseContent.Id, toolUseContent.Name, arguments)
+        {
+            RawRepresentation = toolUseContent
+        };
+    }
+}
+
+/// <summary>
+/// Content block whose <c>type</c> discriminator is <c>"tool_result"</c>.
+/// </summary>
+public class ToolResultContent : ContentBase
+{
+    [JsonPropertyName("type")]
+    public override string Type => "tool_result";
+
+    [JsonPropertyName("tool_use_id")]
+    public string? Id { get; set; }
+
+    [JsonPropertyName("content")]
+    public string? Content { get; set; }
+
+    [JsonPropertyName("is_error")]
+    public bool IsError { get; set; }
+
+    /// <summary>
+    /// Builds a tool-result block from a function result: the exception message when the
+    /// call failed, otherwise the JSON-serialized return value.
+    /// </summary>
+    public static implicit operator ToolResultContent(Microsoft.Extensions.AI.FunctionResultContent functionResultContent)
+    {
+        // If the result is successful, convert the return object (if any) to the content string
+        // Otherwise, convert the error message to the content string
+        string? content = null;
+        if (functionResultContent.Exception != null)
+        {
+            // TODO: Technically, .Result should also contain the error message?
+            content = functionResultContent.Exception.Message;
+        }
+        else if (functionResultContent.Result != null)
+        {
+            // If the result is a string, it should just be a passthrough (with enquotation)
+            content = JsonSerializer.Serialize(functionResultContent.Result);
+        }
+
+        return new ToolResultContent
+        {
+            Id = functionResultContent.CallId,
+            Content = content,
+            IsError = functionResultContent.Exception != null
+        };
+    }
+
+    /// <summary>
+    /// Converts a tool-result block back into a function result. Successful content is parsed
+    /// as JSON; error content is passed through as raw text and also wrapped in an
+    /// <see cref="Exception"/>.
+    /// </summary>
+    /// <exception cref="ArgumentNullException">The block has no <c>Id</c>.</exception>
+    public static implicit operator Microsoft.Extensions.AI.FunctionResultContent(ToolResultContent toolResultContent)
+    {
+        if (toolResultContent.Id == null)
+        {
+            throw new ArgumentNullException(nameof(toolResultContent.Id));
+        }
+
+        object? result = null;
+        if (toolResultContent.Content != null)
+        {
+            if (toolResultContent.IsError)
+            {
+                // The forward conversion stores Exception.Message verbatim for errors, so this
+                // text is not JSON and must not go through the deserializer.
+                result = toolResultContent.Content;
+            }
+            else
+            {
+                try
+                {
+                    result = JsonSerializer.Deserialize<object>(toolResultContent.Content);
+                }
+                catch (JsonException)
+                {
+                    // Content that was not produced by the forward conversion may be plain text;
+                    // hand it back unchanged instead of failing the whole conversion.
+                    result = toolResultContent.Content;
+                }
+            }
+        }
+
+        // The original exception type cannot be recovered from the content string, so a plain
+        // Exception carrying the message is attached when the block is flagged as an error.
+        Exception? error = null;
+        if (toolResultContent.IsError)
+        {
+            error = new Exception(toolResultContent.Content);
+        }
+
+        // TODO: Should we model the name on this object?
+        return new Microsoft.Extensions.AI.FunctionResultContent(toolResultContent.Id, "", result)
+        {
+            Exception = error,
+            RawRepresentation = toolResultContent
+        };
+    }
+}
+
+/// <summary>
+/// Value of the <c>cache_control</c> property on content blocks and tools.
+/// </summary>
+public class CacheControl
+{
+    [JsonPropertyName("type")]
+    public CacheControlType Type { get; set; }
+
+    /// <summary>Creates a cache-control value of type <see cref="CacheControlType.Ephemeral"/>.</summary>
+    public static CacheControl Create() => new CacheControl { Type = CacheControlType.Ephemeral };
+}
+
+/// <summary>
+/// Cache-control kinds; each member is written using the name in its <c>JsonPropertyName</c> attribute.
+/// </summary>
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<CacheControlType>))]
+public enum CacheControlType
+{
+    [JsonPropertyName("ephemeral")]
+    Ephemeral
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ErrorResponse.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ErrorResponse.cs
new file mode 100644
index 000000000000..38dd7411d481
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ErrorResponse.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ErrorResponse.cs
+
+using System.Text.Json.Serialization;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic.DTO;
+
+/// <summary>
+/// Envelope whose <c>error</c> property carries the error details.
+/// </summary>
+public sealed class ErrorResponse
+{
+    [JsonPropertyName("error")]
+    public Error? Error { get; set; }
+}
+
+/// <summary>
+/// Error details: a <c>type</c> identifier and a <c>message</c>.
+/// </summary>
+public sealed class Error
+{
+    // Lower-case to match every other wire name in these DTOs; System.Text.Json matches
+    // property names case-sensitively by default, so "Type" would never bind.
+    [JsonPropertyName("type")]
+    public string? Type { get; set; }
+
+    [JsonPropertyName("message")]
+    public string? Message { get; set; }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Tool.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Tool.cs
new file mode 100644
index 000000000000..bd2ca6ba46b5
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/Tool.cs
@@ -0,0 +1,76 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool.cs
+
+using System.Text.Json.Serialization;
+using Microsoft.Extensions.AI;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic.DTO;
+
+/// <summary>
+/// Tool definition DTO: name, description, input schema and optional cache control.
+/// </summary>
+public class Tool
+{
+    [JsonPropertyName("name")]
+    public string? Name { get; set; }
+
+    [JsonPropertyName("description")]
+    public string?
Description { get; set; }
+
+    [JsonPropertyName("input_schema")]
+    public InputSchema? InputSchema { get; set; }
+
+    [JsonPropertyName("cache_control")]
+    public CacheControl? CacheControl { get; set; }
+
+    /// <summary>
+    /// Builds a tool definition from a Microsoft.Extensions.AI function using its metadata
+    /// (name, description, parameters). <see cref="CacheControl"/> is left unset.
+    /// </summary>
+    public static implicit operator Tool(Microsoft.Extensions.AI.AIFunction tool)
+    {
+        return new Tool
+        {
+            Name = tool.Metadata.Name,
+            Description = tool.Metadata.Description,
+            InputSchema = InputSchema.ExtractSchema(tool),
+        };
+    }
+}
+
+/// <summary>
+/// Schema of a tool's input: an object type with named properties and a list of required names.
+/// </summary>
+public class InputSchema
+{
+    [JsonPropertyName("type")]
+    public string? Type { get; set; }
+
+    [JsonPropertyName("properties")]
+    public Dictionary<string, SchemaProperty>? Properties { get; set; }
+
+    [JsonPropertyName("required")]
+    public List<string>? Required { get; set; }
+
+    /// <summary>Builds the input schema from the parameter metadata of <paramref name="tool"/>.</summary>
+    public static InputSchema ExtractSchema(AIFunction tool) => ExtractSchema(tool.Metadata.Parameters);
+
+    // Emits one schema property per parameter and records the names of the required ones.
+    private static InputSchema ExtractSchema(IReadOnlyList<AIFunctionParameterMetadata> parameterMetadata)
+    {
+        List<string> required = new List<string>();
+        Dictionary<string, SchemaProperty> properties = new Dictionary<string, SchemaProperty>();
+
+        foreach (AIFunctionParameterMetadata parameter in parameterMetadata)
+        {
+            properties.Add(parameter.Name, new SchemaProperty
+            {
+                // A CLR type name such as "String" or "Int32" is not a JSON Schema type keyword.
+                Type = ToJsonSchemaType(parameter.ParameterType),
+                Description = parameter.Description
+            });
+
+            if (parameter.IsRequired)
+            {
+                required.Add(parameter.Name);
+            }
+        }
+
+        return new InputSchema { Type = "object", Properties = properties, Required = required };
+    }
+
+    // Maps a CLR parameter type to a JSON Schema "type" keyword; null stays null.
+    // NOTE(review): types outside the primitive set fall back to "array" or "object";
+    // types that serialize as strings (e.g. Guid) are not special-cased. Confirm whether that is enough.
+    private static string? ToJsonSchemaType(Type? parameterType)
+    {
+        if (parameterType is null)
+        {
+            return null;
+        }
+
+        // Nullable<T> describes the same JSON value as T.
+        Type type = Nullable.GetUnderlyingType(parameterType) ?? parameterType;
+
+        return System.Type.GetTypeCode(type) switch
+        {
+            TypeCode.Boolean => "boolean",
+            TypeCode.Char or TypeCode.String or TypeCode.DateTime => "string",
+            TypeCode.SByte or TypeCode.Byte or TypeCode.Int16 or TypeCode.UInt16
+                or TypeCode.Int32 or TypeCode.UInt32 or TypeCode.Int64 or TypeCode.UInt64 => "integer",
+            TypeCode.Single or TypeCode.Double or TypeCode.Decimal => "number",
+            _ when typeof(System.Collections.IDictionary).IsAssignableFrom(type) => "object",
+            _ when typeof(System.Collections.IEnumerable).IsAssignableFrom(type) => "array",
+            _ => "object",
+        };
+    }
+}
+
+/// <summary>
+/// One entry of <see cref="InputSchema.Properties"/>: the property's type and description.
+/// </summary>
+public class SchemaProperty
+{
+    [JsonPropertyName("type")]
+    public string? Type { get; set; }
+
+    [JsonPropertyName("description")]
+    public string? Description { get; set; }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ToolChoice.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ToolChoice.cs
new file mode 100644
index 000000000000..2ec788a348c6
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/DTO/ToolChoice.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ToolChoice.cs
+
+using System.Text.Json.Serialization;
+using Microsoft.AutoGen.Extensions.Anthropic.Converters;
+
+namespace Microsoft.AutoGen.Extensions.Anthropic.DTO;
+
+/// <summary>
+/// Tool-selection modes; each member is written using the name in its <c>JsonPropertyName</c> attribute.
+/// </summary>
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<ToolChoiceType>))]
+public enum ToolChoiceType
+{
+    [JsonPropertyName("auto")]
+    Auto, // Default behavior
+
+    [JsonPropertyName("any")]
+    Any, // Use any provided tool
+
+    [JsonPropertyName("tool")]
+    Tool // Force a specific tool
+}
+
+/// <summary>
+/// Tool-selection directive. Instances are obtained through <see cref="Auto"/>,
+/// <see cref="Any"/> or <see cref="ToolUse"/>; the constructor is private.
+/// </summary>
+public class ToolChoice
+{
+    [JsonPropertyName("type")]
+    public ToolChoiceType Type { get; set; }
+
+    [JsonPropertyName("name")]
+    public string? Name { get; set; }
+
+    private ToolChoice(ToolChoiceType type, string? name = null)
+    {
+        Type = type;
+        Name = name;
+    }
+
+    /// <summary>A new directive of type <see cref="ToolChoiceType.Auto"/> with no tool name.</summary>
+    public static ToolChoice Auto => new(ToolChoiceType.Auto);
+
+    /// <summary>A new directive of type <see cref="ToolChoiceType.Any"/> with no tool name.</summary>
+    public static ToolChoice Any => new(ToolChoiceType.Any);
+
+    /// <summary>Creates a directive of type <see cref="ToolChoiceType.Tool"/> naming one tool.</summary>
+    /// <param name="name">Name of the tool to force; must not be null or empty.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
+    /// <exception cref="ArgumentException"><paramref name="name"/> is empty.</exception>
+    public static ToolChoice ToolUse(string name)
+    {
+        // A forced-tool directive without a tool name cannot identify anything to call.
+        ArgumentException.ThrowIfNullOrEmpty(name);
+        return new(ToolChoiceType.Tool, name);
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Microsoft.AutoGen.Extensions.Anthropic.csproj b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Microsoft.AutoGen.Extensions.Anthropic.csproj
new file mode 100644
index 000000000000..846d18f9fca4
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Microsoft.AutoGen.Extensions.Anthropic.csproj
@@ -0,0 +1,31 @@
+ + + net8.0 + enable + enable + + + + + + + Microsoft.Autogen.Extensions.Anthropic + + Provide support for consuming Anthropic models in AutoGen + + + + + + + + + + + + + + +
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Utils/AnthropicConstants.cs b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Utils/AnthropicConstants.cs
new file mode 100644
index 000000000000..739aa680f9a3
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/Utils/AnthropicConstants.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicConstants.cs
+
+namespace Microsoft.AutoGen.Extensions.Anthropic.Utils;
+
+/// <summary>
+/// Well-known string values for the Anthropic integration: the messages endpoint URL
+/// and a set of model identifiers.
+/// </summary>
+public static class AnthropicConstants
+{
+    // static readonly rather than a mutable static: these are shared fixed values and no
+    // caller should be able to reassign them for the whole process.
+    public static readonly string Endpoint = "https://api.anthropic.com/v1/messages";
+
+    // Models
+    public static readonly string Claude3Opus = "claude-3-opus-20240229";
+    public static readonly string Claude3Sonnet = "claude-3-sonnet-20240229";
+    public static readonly string Claude3Haiku = "claude-3-haiku-20240307";
+    public static readonly string Claude35Sonnet = "claude-3-5-sonnet-20240620";
+}