macOS / visionOS support and ability to run Google Gemma LLMs (#47)
# macOS / visionOS support and ability to run Google Gemma LLMs

## ♻️ Current situation & Problem
As of now, Google's Gemma LLMs are not yet supported in SpeziLLM.
Furthermore, SpeziLLM is currently constrained to the iOS platform.


## ⚙️ Release Notes 
- macOS / visionOS support
- Ability to run Google Gemma LLMs (see the usage sketch below)
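
For orientation, the sketch below shows how the new pieces could be wired together by a consumer. Only `LLMLocalSchema.PromptFormattingDefaults.gemma` is introduced by this PR; the `modelPath:` / `formatChat:` initializer labels and the on-disk model location are assumptions for illustration and may differ from the released API.

```swift
import Foundation
import SpeziLLMLocal

// Hypothetical wiring of the new Gemma support; the `modelPath:` and
// `formatChat:` labels as well as the file location are assumed here.
// `PromptFormattingDefaults.gemma` is the formatter added by this PR.
let gemmaSchema = LLMLocalSchema(
    modelPath: URL.cachesDirectory.appending(path: "gemma-2b-Q4_K_M.gguf"),
    formatChat: LLMLocalSchema.PromptFormattingDefaults.gemma
)
```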


## 📚 Documentation
Added documentation for the new platform support and the Gemma models.


## ✅ Testing
Manual testing

## 📝 Code of Conduct & Contributing Guidelines 

By creating this pull request, you agree to follow our [Code of Conduct](https://github.com/StanfordSpezi/.github/blob/main/CODE_OF_CONDUCT.md) and [Contributing Guidelines](https://github.com/StanfordSpezi/.github/blob/main/CONTRIBUTING.md):
- [x] I agree to follow the [Code of Conduct](https://github.com/StanfordSpezi/.github/blob/main/CODE_OF_CONDUCT.md) and [Contributing Guidelines](https://github.com/StanfordSpezi/.github/blob/main/CONTRIBUTING.md).
philippzagar authored Mar 7, 2024
1 parent 6892c5d commit ca37910
Showing 22 changed files with 340 additions and 90 deletions.
24 changes: 12 additions & 12 deletions Package.swift
@@ -15,7 +15,9 @@ let package = Package(
name: "SpeziLLM",
defaultLocalization: "en",
platforms: [
.iOS(.v17)
.iOS(.v17),
.visionOS(.v1),
.macOS(.v14)
],
products: [
.library(name: "SpeziLLM", targets: ["SpeziLLM"]),
@@ -25,14 +27,13 @@
],
dependencies: [
.package(url: "https://github.com/MacPaw/OpenAI", .upToNextMinor(from: "0.2.6")),
.package(url: "https://github.com/StanfordBDHG/llama.cpp", .upToNextMinor(from: "0.1.8")),
.package(url: "https://github.com/StanfordSpezi/Spezi", from: "1.1.0"),
.package(url: "https://github.com/StanfordSpezi/SpeziStorage", from: "1.0.0"),
.package(url: "https://github.com/StanfordSpezi/SpeziOnboarding", from: "1.0.0"),
.package(url: "https://github.com/StanfordSpezi/SpeziSpeech", from: "1.0.0"),
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.1.8")),
.package(url: "https://github.com/StanfordSpezi/SpeziViews", from: "1.0.0"),
.package(url: "https://github.com/groue/Semaphore.git", exact: "0.0.8")
.package(url: "https://github.com/StanfordBDHG/llama.cpp", .upToNextMinor(from: "0.2.1")),
.package(url: "https://github.com/StanfordSpezi/Spezi", from: "1.2.1"),
.package(url: "https://github.com/StanfordSpezi/SpeziFoundation", from: "1.0.4"),
.package(url: "https://github.com/StanfordSpezi/SpeziStorage", from: "1.0.2"),
.package(url: "https://github.com/StanfordSpezi/SpeziOnboarding", from: "1.1.1"),
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.1.9")),
.package(url: "https://github.com/StanfordSpezi/SpeziViews", from: "1.3.1")
],
targets: [
.target(
@@ -48,7 +49,7 @@
dependencies: [
.target(name: "SpeziLLM"),
.product(name: "llama", package: "llama.cpp"),
.product(name: "Semaphore", package: "Semaphore"),
.product(name: "SpeziFoundation", package: "SpeziFoundation"),
.product(name: "Spezi", package: "Spezi")
],
swiftSettings: [
@@ -67,11 +68,10 @@
dependencies: [
.target(name: "SpeziLLM"),
.product(name: "OpenAI", package: "OpenAI"),
.product(name: "Semaphore", package: "Semaphore"),
.product(name: "SpeziFoundation", package: "SpeziFoundation"),
.product(name: "Spezi", package: "Spezi"),
.product(name: "SpeziChat", package: "SpeziChat"),
.product(name: "SpeziSecureStorage", package: "SpeziStorage"),
.product(name: "SpeziSpeechRecognizer", package: "SpeziSpeech"),
.product(name: "SpeziOnboarding", package: "SpeziOnboarding")
]
),
@@ -7,14 +7,31 @@
//

import Foundation
import llama


/// Represents the configuration of the Spezi ``LLMLocalPlatform``.
public struct LLMLocalPlatformConfiguration: Sendable {
/// Wrapper around the `ggml_numa_strategy` type of llama.cpp, indicating the non-unified memory access configuration of the device.
public enum NonUniformMemoryAccess: UInt32, Sendable {
case disabled
case distributed
case isolated
case numaCtl
case mirror
case count


var wrappedValue: ggml_numa_strategy {
.init(rawValue: self.rawValue)
}
}


/// The task priority of the initiated LLM inference tasks.
let taskPriority: TaskPriority
/// Indicates if this is a device with non-unified memory access.
let nonUniformMemoryAccess: Bool
/// Indicates the non-unified memory access configuration of the device.
let nonUniformMemoryAccess: NonUniformMemoryAccess


/// Creates the ``LLMLocalPlatformConfiguration`` which configures the Spezi ``LLMLocalPlatform``.
@@ -24,7 +41,7 @@ public struct LLMLocalPlatformConfiguration: Sendable {
/// - nonUniformMemoryAccess: Indicates the non-unified memory access configuration of the device.
public init(
taskPriority: TaskPriority = .userInitiated,
nonUniformMemoryAccess: Bool = false
nonUniformMemoryAccess: NonUniformMemoryAccess = .disabled
) {
self.taskPriority = taskPriority
self.nonUniformMemoryAccess = nonUniformMemoryAccess
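As a usage note for the hunk above, here is a minimal sketch of constructing the reworked configuration; it uses only the initializer and enum cases visible in this diff, with the NUMA strategy left at its `.disabled` default.

```swift
import SpeziLLMLocal

// Minimal sketch: the NUMA option is now an enum wrapping llama.cpp's
// `ggml_numa_strategy` by raw value instead of the previous `Bool` flag.
let platformConfiguration = LLMLocalPlatformConfiguration(
    taskPriority: .userInitiated,
    nonUniformMemoryAccess: .disabled
)
```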
@@ -236,7 +236,9 @@ public struct LLMLocalSamplingParameters: Sendable {
)
}
set {
wrapped.cfg_negative_prompt = std.string(newValue.negativePrompt)
if let negativePrompt = newValue.negativePrompt {
wrapped.cfg_negative_prompt = std.string(negativePrompt)
}
wrapped.cfg_scale = newValue.scale
}
}
9 changes: 5 additions & 4 deletions Sources/SpeziLLMLocal/LLMLocalPlatform.swift
@@ -8,8 +8,8 @@

import Foundation
import llama
import Semaphore
import Spezi
import SpeziFoundation
import SpeziLLM


@@ -62,15 +62,16 @@ public actor LLMLocalPlatform: LLMPlatform, DefaultInitializable {

public nonisolated func configure() {
// Initialize the llama.cpp backend
llama_backend_init(configuration.nonUniformMemoryAccess)
llama_backend_init()
llama_numa_init(configuration.nonUniformMemoryAccess.wrappedValue)
}

nonisolated public func callAsFunction(with llmSchema: LLMLocalSchema) -> LLMLocalSession {
public nonisolated func callAsFunction(with llmSchema: LLMLocalSchema) -> LLMLocalSession {
LLMLocalSession(self, schema: llmSchema)
}

nonisolated func exclusiveAccess() async throws {
try await semaphore.waitUnlessCancelled()
try await semaphore.waitCheckingCancellation()
await MainActor.run {
state = .processing
}
101 changes: 86 additions & 15 deletions Sources/SpeziLLMLocal/LLMLocalSchema+PromptFormatting.swift
@@ -13,7 +13,7 @@ extension LLMLocalSchema {
/// Holds default prompt formatting strategies for [Llama2](https://ai.meta.com/llama/) as well as [Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) models.
public enum PromptFormattingDefaults {
/// Prompt formatting closure for the [Llama2](https://ai.meta.com/llama/) model
public static let llama2: (@Sendable (Chat) throws -> String) = { chat in
public static let llama2: (@Sendable (Chat) throws -> String) = { chat in // swiftlint:disable:this closure_body_length
/// BOS token of the LLM, used at the start of each prompt passage.
let BOS = "<s>"
/// EOS token of the LLM, used at the end of each prompt passage.
@@ -27,14 +27,26 @@
/// EOINST token of the LLM, used at the end of the instruction part of the prompt.
let EOINST = "[/INST]"

// Ensure that system prompt as well as a first user prompt exist
guard let systemPrompt = chat.first,
systemPrompt.role == .system,
let initialUserPrompt = chat.indices.contains(1) ? chat[1] : nil,
initialUserPrompt.role == .user else {
guard chat.first?.role == .system else {
throw LLMLocalError.illegalContext
}

var systemPrompts: [String] = []
var initialUserPrompt: String = ""

for chatEntity in chat {
if chatEntity.role != .system {
if chatEntity.role == .user {
initialUserPrompt = chatEntity.content
break
} else {
throw LLMLocalError.illegalContext
}
}

systemPrompts.append(chatEntity.content)
}

/// Build the initial Llama2 prompt structure
///
/// A template of the prompt structure looks like:
@@ -47,10 +59,10 @@
/// """
var prompt = """
\(BOS)\(BOINST) \(BOSYS)
\(systemPrompt.content)
\(systemPrompts.joined(separator: " "))
\(EOSYS)
\(initialUserPrompt.content) \(EOINST)
\(initialUserPrompt) \(EOINST)
""" + " " // Add a spacer to the generated output from the model

for chatEntry in chat.dropFirst(2) {
@@ -78,14 +90,26 @@

/// Prompt formatting closure for the [Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) model
public static let phi2: (@Sendable (Chat) throws -> String) = { chat in
// Ensure that system prompt as well as a first user prompt exist
guard let systemPrompt = chat.first,
systemPrompt.role == .system,
let initialUserPrompt = chat.indices.contains(1) ? chat[1] : nil,
initialUserPrompt.role == .user else {
guard chat.first?.role == .system else {
throw LLMLocalError.illegalContext
}

var systemPrompts: [String] = []
var initialUserPrompt: String = ""

for chatEntity in chat {
if chatEntity.role != .system {
if chatEntity.role == .user {
initialUserPrompt = chatEntity.content
break
} else {
throw LLMLocalError.illegalContext
}
}

systemPrompts.append(chatEntity.content)
}

/// Build the initial Phi-2 prompt structure
///
/// A template of the prompt structure looks like:
@@ -95,8 +119,8 @@
/// Output: {model_reply_1}
/// """
var prompt = """
System: \(systemPrompt.content)
Instruct: \(initialUserPrompt.content)\n
System: \(systemPrompts.joined(separator: " "))
Instruct: \(initialUserPrompt)\n
"""

for chatEntry in chat.dropFirst(2) {
@@ -120,5 +144,52 @@

return prompt
}

/// Prompt formatting closure for the [Gemma](https://ai.google.dev/gemma/docs/formatting) models
/// - Important: System prompts are ignored as Gemma doesn't support them
public static let gemma: (@Sendable (Chat) throws -> String) = { chat in
/// Start token of Gemma
let startToken = "<start_of_turn>"
/// End token of Gemma
let endToken = "<end_of_turn>"

/// Build the initial Gemma prompt structure
///
/// A template of the prompt structure looks like:
/// """
/// <start_of_turn>user
/// knock knock<end_of_turn>
/// <start_of_turn>model
/// who is there<end_of_turn>
/// <start_of_turn>user
/// Gemma<end_of_turn>
/// <start_of_turn>model
/// Gemma who?<end_of_turn>
/// """
var prompt = ""

for chatEntry in chat {
if chatEntry.role == .assistant {
/// Append response from assistant to the Gemma prompt structure
prompt += """
\(startToken)model
\(chatEntry.content)\(endToken)\n
"""
} else if chatEntry.role == .user {
/// Append prompt from the user to the Gemma prompt structure
prompt += """
\(startToken)user
\(chatEntry.content)\(endToken)\n
"""
}
}

/// If the chat ends with a user turn, open a model turn so the model starts responding
if chat.last?.role == .user {
prompt += "\(startToken)model\n"
}

return prompt
}
}
}
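To make the `gemma` formatter above concrete, the string below is what it is expected to produce for a chat consisting of a single user message (the sample message is made up): system entries are skipped and a trailing `<start_of_turn>model` line invites the model to respond.

```swift
// Expected output of `PromptFormattingDefaults.gemma` for a one-message chat
// with the hypothetical user message "What is the capital of France?".
let expectedGemmaPrompt = """
<start_of_turn>user
What is the capital of France?<end_of_turn>
<start_of_turn>model

"""
```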
6 changes: 4 additions & 2 deletions Sources/SpeziLLMLocal/LLMLocalSession+Generation.swift
@@ -143,14 +143,16 @@ extension LLMLocalSession {
""")

// Automatically inject the yielded string piece into the `LLMLocal/context`
if schema.injectIntoContext {
if schema.injectIntoContext && nextTokenId != 0 {
let nextStringPiece = nextStringPiece
await MainActor.run {
context.append(assistantOutput: nextStringPiece)
}
}

continuation.yield(nextStringPiece)
if nextTokenId != 0 {
continuation.yield(nextStringPiece)
}

// Prepare the next batch
llama_batch_clear(&batch)
24 changes: 24 additions & 0 deletions Sources/SpeziLLMLocal/LLMLocalSession+Tokenization.swift
@@ -20,9 +20,33 @@ extension LLMLocalSession {
// Format the chat into a prompt that conforms to the prompt structure of the respective LLM
let formattedChat = try await schema.formatChat(self.context)

// C++ vector doesn't conform to Swift `Sequence` on the visionOS SDK (Swift C++ interop bug),
// therefore requiring a workaround for visionOS
#if !os(visionOS)
var tokens: [LLMLocalToken] = .init(
llama_tokenize_with_context(self.modelContext, std.string(formattedChat), schema.parameters.addBosToken, true)
)
#else
// Swift String to C++ `std::string` conversion is buggy on visionOS; work around via a C-based `char` array
guard let cString = formattedChat.cString(using: .utf8) else {
fatalError("SpeziLLMLocal: Couldn't bridge the LLM Swift-based String context to a C-based String.")
}

let cxxTokensVector = llama_tokenize_with_context_from_char_array(self.modelContext, cString, schema.parameters.addBosToken, true)

// Get C array from C++ vector containing the tokenized content
guard var cxxTokensArray = vectorToIntArray(cxxTokensVector) else {
fatalError("SpeziLLMLocal: Couldn't get C array containing the tokenized content from C++ vector.")
}

// Extract tokens from C array to a Swift array
var tokens: [LLMLocalToken] = []

// Iterate exactly `size()` elements to avoid reading past the end of the token buffer
for _ in 0..<cxxTokensVector.size() {
tokens.append(cxxTokensArray.pointee)
cxxTokensArray = cxxTokensArray.advanced(by: 1)
}
#endif

// Truncate tokens if there wouldn't be enough context size for the generated output
if tokens.count > Int(schema.contextParameters.contextWindowSize) - schema.parameters.maxOutputLength {
@@ -45,6 +45,28 @@ extension LLMLocalDownloadManager {
return url
}

/// Gemma 7B model with `Q4_K_M` quantization (~5GB)
public static var gemma7BModelUrl: URL {
guard let url = URL(string: "https://huggingface.co/rahuldshetty/gemma-7b-it-gguf-quantized/resolve/main/gemma-7b-it-Q4_K_M.gguf") else {
preconditionFailure("""
SpeziLLM: Invalid LLMUrlDefaults LLM download URL.
""")
}

return url
}

/// Gemma 2B model with `Q4_K_M` quantization (~1.5GB)
public static var gemma2BModelUrl: URL {
guard let url = URL(string: "https://huggingface.co/rahuldshetty/gemma-2b-gguf-quantized/resolve/main/gemma-2b-Q4_K_M.gguf") else {
preconditionFailure("""
SpeziLLM: Invalid LLMUrlDefaults LLM download URL.
""")
}

return url
}

/// Tiny LLama 1.1B model with `Q5_K_M` quantization in its chat variation (~800MB)
public static var tinyLLama2ModelUrl: URL {
guard let url = URL(string: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf") else {
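A brief aside on the new Gemma download defaults above: a consumer could hand these URLs to the existing download flow roughly as sketched below. The `LLMUrlDefaults` nesting (inferred from the precondition message) and the `SpeziLLMLocalDownload` module name are assumptions and not visible in this hunk.

```swift
import Foundation
import SpeziLLMLocalDownload

// Assumed access path; the nesting under `LLMUrlDefaults` and the module
// name are inferred, not shown in this diff.
let gemma2BSource: URL = LLMLocalDownloadManager.LLMUrlDefaults.gemma2BModelUrl
```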
@@ -39,12 +39,12 @@ public struct LLMOpenAIParameters: Sendable {
/// - overwritingToken: Separate OpenAI token that overrides the one defined within the ``LLMOpenAIPlatform``.
public init(
modelType: Model,
systemPrompt: String = Defaults.defaultOpenAISystemPrompt,
systemPrompt: String? = Defaults.defaultOpenAISystemPrompt,
modelAccessTest: Bool = false,
overwritingToken: String? = nil
) {
self.modelType = modelType
self.systemPrompts = [systemPrompt]
self.systemPrompts = systemPrompt.map { [$0] } ?? []
self.modelAccessTest = modelAccessTest
self.overwritingToken = overwritingToken
}
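With `systemPrompt` now optional, callers can opt out of any default system message entirely; a short sketch follows (the `.gpt4` constant comes from the MacPaw/OpenAI dependency and is assumed to be available at the call site).

```swift
import OpenAI
import SpeziLLMOpenAI

// Passing `nil` yields an empty `systemPrompts` array instead of injecting
// the default system prompt; `.gpt4` is a model constant from MacPaw/OpenAI.
let parameters = LLMOpenAIParameters(
    modelType: .gpt4,
    systemPrompt: nil
)
```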
2 changes: 1 addition & 1 deletion Sources/SpeziLLMOpenAI/Helpers/LLMOpenAIConstants.swift
@@ -9,6 +9,6 @@

/// Constants used throughout the `SpeziLLMOpenAI` target.
enum LLMOpenAIConstants {
static let credentialsServer = "openapi.org"
static let credentialsServer = "openai.com"
static let credentialsUsername = "OpenAIGPT"
}