-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from fumito-ito/feature/token-counting
token counting
- Loading branch information
Showing
10 changed files
with
321 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// | ||
// CountTokens.swift | ||
// AnthropicSwiftSDK | ||
// | ||
// Created by 伊藤史 on 2024/11/13. | ||
// | ||
import Foundation | ||
|
||
public struct CountTokens { | ||
/// The API key used for authentication with the Anthropic API. | ||
private let apiKey: String | ||
/// The URL session used for network requests. | ||
private let session: URLSession | ||
|
||
/// Initializes a new instance of `MessageBatches`. | ||
/// | ||
/// - Parameters: | ||
/// - apiKey: The API key for authentication. | ||
/// - session: The URL session for network requests. | ||
init(apiKey: String, session: URLSession) { | ||
self.apiKey = apiKey | ||
self.session = session | ||
} | ||
|
||
public func countTokens( | ||
_ messages: [Message], | ||
model: Model = .claude_3_Opus, | ||
system: [SystemPrompt] = [], | ||
maxTokens: Int, | ||
metaData: MetaData? = nil, | ||
stopSequence: [String]? = nil, | ||
temperature: Double? = nil, | ||
topP: Double? = nil, | ||
topK: Int? = nil, | ||
tools: [Tool]? = nil, | ||
toolChoice: ToolChoice = .auto | ||
) async throws -> CountTokenResponse { | ||
try await countTokens( | ||
messages, | ||
model: model, | ||
system: system, | ||
maxTokens: maxTokens, | ||
metaData: metaData, | ||
stopSequence: stopSequence, | ||
temperature: temperature, | ||
topP: topP, | ||
topK: topK, | ||
tools: tools, | ||
toolChoice: toolChoice, | ||
anthropicHeaderProvider: DefaultAnthropicHeaderProvider(), | ||
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey) | ||
) | ||
} | ||
|
||
public func countTokens( | ||
_ messages: [Message], | ||
model: Model = .claude_3_Opus, | ||
system: [SystemPrompt] = [], | ||
maxTokens: Int, | ||
metaData: MetaData? = nil, | ||
stopSequence: [String]? = nil, | ||
temperature: Double? = nil, | ||
topP: Double? = nil, | ||
topK: Int? = nil, | ||
tools: [Tool]? = nil, | ||
toolChoice: ToolChoice = .auto, | ||
anthropicHeaderProvider: AnthropicHeaderProvider, | ||
authenticationHeaderProvider: AuthenticationHeaderProvider | ||
) async throws -> CountTokenResponse { | ||
let client = APIClient( | ||
session: session, | ||
anthropicHeaderProvider: anthropicHeaderProvider, | ||
authenticationHeaderProvider: authenticationHeaderProvider | ||
) | ||
|
||
let request = CountTokenRequest( | ||
body: .init( | ||
model: model, | ||
messages: messages, | ||
system: system, | ||
maxTokens: maxTokens, | ||
metaData: metaData, | ||
stopSequences: stopSequence, | ||
stream: false, | ||
temperature: temperature, | ||
topP: topP, | ||
topK: topK, | ||
tools: tools, | ||
toolChoice: toolChoice | ||
) | ||
) | ||
|
||
let (data, response) = try await client.send(request: request) | ||
|
||
guard let httpResponse = response as? HTTPURLResponse else { | ||
throw ClientError.cannotHandleURLResponse(response) | ||
} | ||
|
||
guard httpResponse.statusCode == 200 else { | ||
throw AnthropicAPIError(fromHttpStatusCode: httpResponse.statusCode) | ||
} | ||
|
||
return try anthropicJSONDecoder.decode(CountTokenResponse.self, from: data) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
Sources/AnthropicSwiftSDK/Network/Request/CountTokenRequest.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// | ||
// CountTokenRequest.swift | ||
// AnthropicSwiftSDK | ||
// | ||
// Created by 伊藤史 on 2024/11/13. | ||
// | ||
|
||
import Foundation | ||
|
||
struct CountTokenRequest: Request { | ||
typealias Body = CountTokenRequestBody | ||
|
||
let method: HttpMethod = .post | ||
let path: String = RequestType.countTokens.basePath | ||
let queries: [String: CustomStringConvertible]? = nil | ||
let body: Body? | ||
} | ||
|
||
// MARK: Request Body | ||
|
||
/// Request object for Count Token API | ||
/// | ||
/// a structured list of input messages with text and/or image content, and the model will generate the next message in the conversation. | ||
struct CountTokenRequestBody: Encodable { | ||
/// The model that will complete your prompt. | ||
let model: Model | ||
/// Input messages. | ||
let messages: [Message] | ||
/// System prompt. | ||
/// | ||
/// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role. | ||
let system: [SystemPrompt] | ||
/// The maximum number of tokens to generate before stopping. | ||
/// | ||
/// Note that our models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. | ||
/// Different models have different maximum values for this parameter. | ||
let maxTokens: Int | ||
/// An object describing metadata about the request. | ||
let metaData: MetaData? | ||
/// Custom text sequences that will cause the model to stop generating. | ||
let stopSequences: [String]? | ||
/// Whether to incrementally stream the response using server-sent events. | ||
/// | ||
/// see [streaming](https://docs.anthropic.com/claude/reference/messages-streaming) for more detail. | ||
let stream: Bool | ||
/// Amount of randomness injected into the response. | ||
/// | ||
/// Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. | ||
/// Note that even with temperature of 0.0, the results will not be fully deterministic. | ||
let temperature: Double? | ||
/// Use nucleus sampling. | ||
/// | ||
/// In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. | ||
/// Recommended for advanced use cases only. You usually only need to use temperature. | ||
let topP: Double? | ||
/// Only sample from the top K options for each subsequent token. | ||
/// | ||
/// Used to remove "long tail" low probability responses. | ||
/// Recommended for advanced use cases only. You usually only need to use temperature. | ||
let topK: Int? | ||
/// Definition of tools with names, descriptions, and input schemas in your API request. | ||
let tools: [Tool]? | ||
/// Definition whether or not to force Claude to use the tool. ToolChoice should be set if tools are specified. | ||
let toolChoice: ToolChoice? | ||
|
||
init( | ||
model: Model = .claude_3_Opus, | ||
messages: [Message], | ||
system: [SystemPrompt] = [], | ||
maxTokens: Int, | ||
metaData: MetaData? = nil, | ||
stopSequences: [String]? = nil, | ||
stream: Bool = false, | ||
temperature: Double? = nil, | ||
topP: Double? = nil, | ||
topK: Int? = nil, | ||
tools: [Tool]? = nil, | ||
toolChoice: ToolChoice = .auto | ||
) { | ||
self.model = model | ||
self.messages = messages | ||
self.system = system | ||
self.maxTokens = maxTokens | ||
self.metaData = metaData | ||
self.stopSequences = stopSequences | ||
self.stream = stream | ||
self.temperature = temperature | ||
self.topP = topP | ||
self.topK = topK | ||
self.tools = tools | ||
self.toolChoice = tools == nil ? nil : toolChoice // ToolChoice should be set if tools are specified. | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
Sources/AnthropicSwiftSDK/Network/Response/CountTokenResponse.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// | ||
// CountTokenResponse.swift | ||
// AnthropicSwiftSDK | ||
// | ||
// Created by 伊藤史 on 2024/11/13. | ||
// | ||
|
||
/// Billing and rate-limit usage. | ||
public struct CountTokenResponse: Decodable { | ||
/// The number of input tokens which were used. | ||
public let inputTokens: Int? | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
Tests/AnthropicSwiftSDKTests/Network/Request/CountTokenRequestTests.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// | ||
// CountTokenRequestTests.swift | ||
// AnthropicSwiftSDK | ||
// | ||
// Created by 伊藤史 on 2024/11/13. | ||
// | ||
|
||
import XCTest | ||
@testable import AnthropicSwiftSDK | ||
|
||
final class CountTokenRequestTests: XCTestCase { | ||
func testEncoding() throws { | ||
// Prepare test data | ||
let message = Message(role: .user, content: [.text("Hello")]) | ||
let systemPrompt = SystemPrompt.text("You are a helpful assistant", nil) | ||
|
||
let sut = CountTokenRequestBody( | ||
model: .claude_3_Opus, | ||
messages: [message], | ||
system: [systemPrompt], | ||
maxTokens: 1000, | ||
metaData: .init(userId: "test-user"), | ||
stopSequences: ["STOP"], | ||
stream: true, | ||
temperature: 0.7, | ||
topP: 0.9, | ||
topK: 10, | ||
tools: nil, | ||
toolChoice: .auto | ||
) | ||
|
||
// Encode to JSON | ||
let encoder = JSONEncoder() | ||
encoder.keyEncodingStrategy = .convertToSnakeCase | ||
let data = try encoder.encode(sut) | ||
let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] | ||
|
||
// Verify basic properties | ||
XCTAssertEqual(json["model"] as? String, "claude-3-opus-20240229") | ||
XCTAssertEqual((json["messages"] as? [[String: Any]])?.count, 1) | ||
XCTAssertEqual((json["system"] as? [[String: Any]])?.count, 1) | ||
XCTAssertEqual(json["max_tokens"] as? Int, 1000) | ||
XCTAssertNotNil(json["meta_data"]) | ||
XCTAssertEqual((json["stop_sequences"] as? [String]), ["STOP"]) | ||
XCTAssertEqual(json["stream"] as? Bool, true) | ||
XCTAssertEqual(json["temperature"] as? Double, 0.7) | ||
XCTAssertEqual(json["top_p"] as? Double, 0.9) | ||
XCTAssertEqual(json["top_k"] as? Int, 10) | ||
XCTAssertNil(json["tools"]) // Verify tools is nil | ||
XCTAssertNil(json["tool_choice"]) | ||
} | ||
|
||
func testEncodingWithMinimalParameters() throws { | ||
// Test with only required parameters | ||
let message = Message(role: .user, content: [.text("Hello")]) | ||
let sut = CountTokenRequestBody( | ||
messages: [message], | ||
maxTokens: 1000 | ||
) | ||
|
||
let encoder = JSONEncoder() | ||
encoder.keyEncodingStrategy = .convertToSnakeCase | ||
let data = try encoder.encode(sut) | ||
let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] | ||
|
||
// Verify minimal configuration | ||
XCTAssertEqual(json["model"] as? String, "claude-3-opus-20240229") | ||
XCTAssertEqual((json["messages"] as? [[String: Any]])?.count, 1) | ||
XCTAssertEqual(json["max_tokens"] as? Int, 1000) | ||
XCTAssertNil(json["tool_choice"]) // Verify tool_choice is nil when tools are not specified | ||
} | ||
} |