Export and run LLMs in C++ #1197

Open · wants to merge 4 commits into base: main

8 changes: 4 additions & 4 deletions bert/convert.py
@@ -1,6 +1,6 @@
import argparse

-import numpy
+import mlx.core as mx
from transformers import AutoModel


@@ -23,9 +23,9 @@ def convert(bert_model: str, mlx_model: str) -> None:
model = AutoModel.from_pretrained(bert_model)
# save the tensors
tensors = {
-        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
+        replace_key(key): mx.array(tensor) for key, tensor in model.state_dict().items()
}
-    numpy.savez(mlx_model, **tensors)
+    mx.save_safetensors(mlx_model, tensors)


if __name__ == "__main__":
@@ -39,7 +39,7 @@ def convert(bert_model: str, mlx_model: str) -> None:
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
default="bert-base-uncased.safetensors",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()
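For context on the new format: the converted weights can be read back with a plain `mx.load`, which picks the safetensors reader from the `.safetensors` extension. A minimal sketch (the file name matches the new default above; the loop is only illustrative):

```python
import mlx.core as mx

# Load the BERT weights written by convert.py; mx.load infers the
# format from the .safetensors extension.
weights = mx.load("bert-base-uncased.safetensors")

# Each entry is an mx.array keyed by the renamed parameter name.
for name, param in list(weights.items())[:5]:
    print(name, param.shape, param.dtype)
```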
11 changes: 4 additions & 7 deletions bert/model.py
@@ -136,10 +136,7 @@ def load_model(

def run(bert_model: str, mlx_model: str, batch: List[str]):
model, tokenizer = load_model(bert_model, mlx_model)

tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

tokens = tokenizer(batch, return_tensors="mlx", padding=True)
return model(**tokens)


@@ -149,13 +146,13 @@ def run(bert_model: str, mlx_model: str, batch: List[str]):
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
help="The huggingface name of the BERT model.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
default="bert-base-uncased.safetensors",
help="The path of the stored MLX BERT weights.",
)
parser.add_argument(
"--text",
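The switch to `return_tensors="mlx"` assumes a transformers release whose tokenizers can return MLX arrays directly; on older releases the two removed lines above are still the way to go. A minimal sketch of that fallback, using the same `bert-base-uncased` tokenizer (the try/except is illustrative):

```python
import mlx.core as mx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = ["This is an example of BERT working in MLX."]

try:
    # Recent transformers versions can hand back MLX arrays directly.
    tokens = tokenizer(batch, return_tensors="mlx", padding=True)
except (ImportError, ValueError):
    # Fallback mirroring the removed lines: tokenize to numpy, then wrap.
    tokens = tokenizer(batch, return_tensors="np", padding=True)
    tokens = {key: mx.array(v) for key, v in tokens.items()}
```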
1 change: 0 additions & 1 deletion bert/requirements.txt
@@ -1,3 +1,2 @@
mlx>=0.0.5
transformers
-numpy
4 changes: 2 additions & 2 deletions bert/test.py
@@ -29,8 +29,8 @@ def run_torch(bert_model: str, batch: List[str]):
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
default="bert-base-uncased.safetensors",
help="The path of the stored MLX BERT weights.",
)
parser.add_argument(
"--text",
1 change: 0 additions & 1 deletion bert/weights/.gitignore

This file was deleted.

1 change: 1 addition & 0 deletions llms/export/.gitignore
@@ -0,0 +1 @@
build/
33 changes: 33 additions & 0 deletions llms/export/CMakeLists.txt
@@ -0,0 +1,33 @@
cmake_minimum_required(VERSION 3.27)

project(mlxlm LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_library(mlxlm)
target_link_libraries(mlxlm PUBLIC mlx)

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/third_party)

target_sources(mlxlm
PRIVATE
mlxlm.cpp
tokenizer.cpp)

add_executable(main main.cpp)
target_link_libraries(main PRIVATE mlxlm)

add_executable(test test.cpp)
target_link_libraries(test PRIVATE mlxlm)
34 changes: 34 additions & 0 deletions llms/export/README.md
@@ -0,0 +1,34 @@
# Export LLMs to C++

Export language model inference from Python to run directly in C++.

To run, first install the requirements:

```bash
pip install -U mlx-lm
```

Then generate text from Python with:

```bash
python export.py generate "How tall is K2?"
```

To export the generation function, run:

```bash
python export.py export
```

Then build the C++ code (requires CMake):

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

And run the generation from C++ with:

```bash
./build/main llama3.1-instruct-4bit "How tall is K2?"
```
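Before building the C++ binary, the exported graph can be sanity-checked from Python: `mx.exporter` (used in `export.py` below) has a matching `mx.import_function` that loads `model.mlxfn` as a callable. A minimal sketch, assuming the export step above was run with its default output path and that your MLX version ships `mx.import_function`:

```python
import mlx.core as mx

# Load the exported step function (written by `python export.py export`).
step = mx.import_function("llama3.1-instruct-4bit/model.mlxfn")

# Mirror the prompt-processing signature recorded in export.py:
# a [1, 2] uint32 prompt of dummy token ids and a causal mask over it.
prompt = mx.array([[0, 0]], mx.uint32)
idx = mx.arange(prompt.size)
mask = idx[:, None] >= idx

logits, *cache = step(prompt, mask)
print(logits.shape)  # expected: (1, 2, vocab_size)
```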
171 changes: 171 additions & 0 deletions llms/export/export.py
@@ -0,0 +1,171 @@
import time
from pathlib import Path

import fire
import mlx.core as mx
from mlx_lm import load


class ExportableCache:

def __init__(self, keys=None, values=None, offset=0):
self.offset = offset
self.keys = keys
self.values = values

def update_and_fetch(self, keys, values):
if self.keys is not None:
self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
else:
self.keys = keys
self.values = values
return self.keys, self.values

@property
def state(self):
return self.keys, self.values


def expand(cache, mask=None, cache_step_size=256):
cache_size = cache[0].shape[-2]
new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)

def expand_kv(x):
B, n_heads, _, head_dim = x.shape
new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype)
new_x[..., : x.shape[2], :] = x
return new_x

cache = [expand_kv(c) for c in cache]
if mask is None:
mask = mx.full(new_size, False)
mask[:cache_size] = True
else:
mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
return cache, mask


def causal_mask(N):
idx = mx.arange(N)
return idx[:, None] >= idx


def step(model, y, *state):
mask = state[-1]
if len(state) > 1:
cache, offset = state[:-2], state[-2]
cache = [
ExportableCache(keys, values, offset)
for keys, values in zip(cache[::2], cache[1::2])
]
else:
cache = [ExportableCache() for i in range(len(model.model.layers))]
logits = model(y, cache=cache, mask=mask)
cache = [y for x in cache for y in x.state]
return logits, *cache


def generate_step(prompt, model, max_tokens):
mx.eval(model)

compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)

def _step(*args):
logits, *cache = compiled_step(*args)
return mx.argmax(logits[:, -1], axis=-1), *cache

y, *cache = _step(prompt, causal_mask(prompt.size))
mx.async_eval(y)
offset = mx.array(prompt.size, mx.uint32)
cache, mask = expand(cache)
n = 0
while True:
if n < max_tokens - 1:
if mask.size <= (prompt.size + n):
cache, mask = expand(cache, mask)
mask[prompt.size + n] = True
next_y, *cache = _step(y[None], *cache, offset, mask)
mx.async_eval(next_y)
offset += 1
n += 1
yield y.item()
if n == max_tokens:
break
y = next_y


def export(
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
path="llama3.1-instruct-4bit",
):
model, tokenizer = load(model)

mx.eval(model)

tokenizer.save_pretrained(path)

_step = lambda *args: step(model, *args)

# Make example inputs
y_prompt = mx.array([[0, 0]], mx.uint32)
y_gen = mx.array([[0]], mx.uint32)
offset = mx.array([0], mx.uint32)

mask = causal_mask(y_prompt.size)
_, *cache = _step(y_prompt, mask)

model_path = str(Path(path) / "model.mlxfn")
with mx.exporter(model_path, _step, shapeless=True) as exporter:
exporter(y_prompt, mask)
cache, mask = expand(cache)
exporter(y_gen, *cache, offset, mask)


def generate(
prompt,
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
max_tokens=128,
):
print("[INFO] Loading model from disk.")
model, tokenizer = load(model)
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="mlx",
)

print("[INFO] Starting generation...")
tic = time.time()
tokens = []

detokenizer = tokenizer.detokenizer
detokenizer.reset()

for n, token in enumerate(generate_step(prompt, model, max_tokens)):
if n == 0:
prompt_tps = prompt.size / (time.time() - tic)
tic = time.time()

if token in tokenizer.eos_token_ids:
break
detokenizer.add_token(token)
print(detokenizer.last_segment, end="", flush=True)

detokenizer.finalize()
print(detokenizer.last_segment, flush=True)
gen_tps = (n + 1) / (time.time() - tic)
peak_memory = mx.metal.get_peak_memory() / 1e9
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
print(f"Peak RAM: {peak_memory:.3f} GB")


if __name__ == "__main__":
fire.Fire(
{
"generate": generate,
"export": export,
}
)
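To make the cache bookkeeping above concrete: `causal_mask` builds a lower-triangular boolean mask for the prompt, and `expand` pads the key/value buffers along the sequence axis to the next multiple of `cache_step_size`, tracking which positions hold real tokens with a flat boolean mask. A small worked sketch using the same arithmetic (the head count and head dimension are made up for illustration):

```python
import mlx.core as mx

# causal_mask(4): True at and below the diagonal.
idx = mx.arange(4)
prompt_mask = idx[:, None] >= idx

# A toy cache entry: batch=1, 2 heads, 4 cached positions, head_dim=8.
keys = mx.zeros((1, 2, 4, 8))

# expand() rounds the sequence axis up to the next multiple of 256 ...
cache_step_size = 256
cache_size = keys.shape[-2]
new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)
print(new_size)  # 256

# ... and keeps a flat mask marking which slots hold real tokens.
flat_mask = mx.full(new_size, False)
flat_mask[:cache_size] = True
print(flat_mask.sum())  # 4 valid positions so far
```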
18 changes: 18 additions & 0 deletions llms/export/main.cpp
@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlxlm.h"

int main(int argc, char *argv[]) {
if (argc < 3) {
std::cerr << "Must provide the model path and prompt." << std::endl;
return 1;
}
auto path = std::string(argv[1]);
auto prompt = std::string(argv[2]);

auto model = load_model(path + "/model.mlxfn");
auto tokenizer = load_tokenizer(path);
generate(model, tokenizer, prompt);
}