
Weirdness with tokenization in Phi-3 #12

Open

uogbuji opened this issue Jul 23, 2024 · 3 comments

uogbuji (Contributor) commented Jul 23, 2024

Server:

toolio_server --model=mlx-community/Phi-3-mini-128k-instruct-4bit

Client:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow?'

You can run the above any number of times without trouble, but as soon as you run a variant that tries to reuse the prior prompt cache:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow? Where have I heard that before?'

It blows up. Server exception tail:

  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/cli/server.py", line 271, in post_v1_chat_completions_impl
    for result in app.state.model.completion(
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 296, in completion
    logits, cache = self._evaluate_prompt(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 92, in _evaluate_prompt
    logits = self.model(mx.array(tokens)[None], cache)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 202, in __call__
    out = self.model(inputs, cache)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 184, in __call__
    h = layer(h, mask, c)
        ^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 148, in __call__
    r = self.self_attn(self.input_layernorm(x), mask, cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 110, in __call__
    output = mx.fast.scaled_dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Shapes (1,32,9,24) and (9,9) cannot be broadcast.

Here is _evaluate_prompt in schema_helper.py, modified to add a trace:

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # Need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
            print('CACHED', tokens, prompt)
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
            print('UNCACHED', tokens)

        logits = self.model(mx.array(tokens)[None], cache)
        return logits, cache

First run of the shorter prompt displays:

UNCACHED [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

Notice the repeated 32007 right away; that's Phi-3's '<|end|>' token, and the doubling is probably not good. An identical run again:

CACHED [32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

That's the expected logic: nothing but that end token left to evaluate after the cache hit. Now the longer prompt:

CACHED [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]

The end token comes out doubled again.
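
For reference, re-running the prefix-matching loop from _evaluate_prompt standalone on the two token lists above shows it diverging at index 15, which is where those 9 post-cache tokens come from:

prior_prompt = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443,
                4528, 264, 2381, 9536, 29973, 32007, 32007]
prompt = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264,
          2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
i = 0
for i, t in enumerate(prior_prompt):
    # Same loop as in _evaluate_prompt above
    if i >= len(prompt) - 1 or prompt[i] != t:
        break
print(i)            # 15: the prompts diverge exactly where the shorter one ended
print(prompt[i:])   # the 9 post-cache tokens shown in the CACHED trace line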

At this point I don't know whether this tokenization oddness is what leads to the shape error, but it's a starting point for investigation.
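
One way to narrow down where the doubled '<|end|>' comes from (a sketch only; I'm assuming the server builds its prompt from the tokenizer's chat template, which is not confirmed against Toolio's actual request path): apply the chat template outside the server and see whether the ids already end in 32007, 32007.

# Diagnostic sketch (assumption: prompt construction goes through apply_chat_template)
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    'mlx-community/Phi-3-mini-128k-instruct-4bit')
messages = [{'role': 'user',
             'content': 'What is the average airspeed of an unladen swallow?'}]
ids = tokenizer.apply_chat_template(messages, tokenize=True)
print(ids)                           # does this already end in 32007, 32007?
print(repr(tokenizer.decode(ids)))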

uogbuji self-assigned this Jul 23, 2024

uogbuji (Contributor, Author) commented Jul 23, 2024

Quick look at the Phi-3 tokenizer:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    # Should be the same tokenizer as microsoft/Phi-3-mini-128k-instruct
    'mlx-community/Phi-3-mini-128k-instruct-4bit'
)
S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=False)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))

S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=True)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))

Output:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[15043, 32007]
'Hello<|end|>'
[15043, 32007]
'Hello<|end|>'

The 'Special tokens' warning comes up as soon as you load the tokenizer, and has nothing to do with add_special_tokens=True|False later on.

repr of tokenizer:

LlamaTokenizerFast(name_or_path='mlx-community/Phi-3-mini-128k-instruct-4bit', vocab_size=32000, model_max_length=131072, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '<|end|>', 'unk_token': '<unk>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("</s>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=False),
        32000: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32001: AddedToken("<|assistant|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32002: AddedToken("<|placeholder1|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32003: AddedToken("<|placeholder2|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32004: AddedToken("<|placeholder3|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32005: AddedToken("<|placeholder4|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32006: AddedToken("<|system|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32007: AddedToken("<|end|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32008: AddedToken("<|placeholder5|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32009: AddedToken("<|placeholder6|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32010: AddedToken("<|user|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
}

So yes, Phi-3 uses the Llama tokenizer. Notice that the chat special tokens, including '<|end|>', are added with rstrip=True, i.e. with whitespace normalization on their right side.
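
As one more check (a sketch continuing from the tokenizer loaded above): encoding strings with a single vs. a doubled '<|end|>' should round-trip cleanly, which would mean the extra 32007 is being introduced before the text ever reaches encode().

# Does encode() itself ever inject an extra <|end|>?
for s in ('Hello<|end|>', 'Hello<|end|><|end|>'):
    ids = tokenizer.encode(s, add_special_tokens=False)
    print(ids, repr(tokenizer.decode(ids)))
# Expected (my assumption): [15043, 32007] and [15043, 32007, 32007]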

uogbuji (Contributor, Author) commented Jul 23, 2024

A trimmed-down repro case:

import mlx.core as mx
from toolio.schema_helper import Model, ReusableKVCache

m = Model()
m.load('mlx-community/Phi-3-mini-128k-instruct-4bit')
from mlx_lm.models.base import KVCache
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)

logits = m.model(mx.array(tokens2_postcache)[None], cache)

Result: ValueError: Shapes (1,32,9,32) and (9,9) cannot be broadcast.
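
Reading those shapes (my interpretation, not confirmed against the mlx_lm internals): the attention scores for the 9 new tokens span 32 key positions (the 23 reused cache positions plus the 9 new ones), but the causal mask is being built as (9, 9), i.e. without the cache offset. A minimal sketch of a mask that would broadcast, assuming the offset is known:

# Sketch only: an additive causal mask that also covers the cached key positions.
# Assumes scores have shape (batch, heads, num_new, offset + num_new).
import mlx.core as mx

num_new, offset = 9, 23   # 9 new tokens after 23 reused cache positions
linds = mx.arange(offset, offset + num_new)[:, None]   # absolute query positions
rinds = mx.arange(offset + num_new)[None, :]            # absolute key positions
mask = (linds < rinds) * -1e9                            # shape (9, 32), broadcastable
print(mask.shape)

The same reading fits the (1,32,9,24) vs (9,9) mismatch in the server traceback above, where the reused prefix was 15 tokens (15 + 9 = 24).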

Note: just blindly replacing every doubled 32007, 32007 with a single 32007 merely tweaked the error: ValueError: Shapes (1,32,8,30) and (8,8) cannot be broadcast.

cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)

logits = m.model(mx.array(tokens2_postcache)[None], cache)

uogbuji added a commit that referenced this issue Jul 23, 2024
…ching until we figure out #12. Another apparent shape fix.
uogbuji (Contributor, Author) commented Aug 2, 2024

For now I've worked around this by disabling prompt caching by default. I'll leave the ticket open, though, because it would be nice to work out a proper fix.
