Fix quantized cache #30

Merged
merged 32 commits into main from max/fix_quanto_cache on Dec 11, 2024

Commits (32)
cfa4e58
fix generate_answer for quantized cache
maxjeblick Dec 9, 2024
190b3c4
fix value cache pruning
maxjeblick Dec 9, 2024
4255ba1
improve test
maxjeblick Dec 9, 2024
d766ce9
improve test
maxjeblick Dec 9, 2024
9c30df1
add integration tests
maxjeblick Dec 10, 2024
4df78ef
Merge branch 'main' into max/fix_quanto_cache
maxjeblick Dec 10, 2024
f753eae
get correct context length
maxjeblick Dec 10, 2024
7649ab9
fix quantized key cache
maxjeblick Dec 10, 2024
11a4c45
fix quantized key cache
maxjeblick Dec 10, 2024
cfec9b8
fix test
maxjeblick Dec 10, 2024
1bfe667
fix test
maxjeblick Dec 10, 2024
d46cc55
fix test
maxjeblick Dec 10, 2024
524fe47
add more asserts
maxjeblick Dec 10, 2024
c07cc4d
fix test
maxjeblick Dec 10, 2024
6ff4bc9
fix test
maxjeblick Dec 10, 2024
960e05e
fix test
maxjeblick Dec 10, 2024
5c5e1b9
Merge branch 'main' into max/fix_quanto_cache
maxjeblick Dec 10, 2024
58bc8ae
fix merge conflicts
maxjeblick Dec 10, 2024
16e8671
fix failing tests
maxjeblick Dec 10, 2024
d9887ea
import flash attn skip
maxjeblick Dec 10, 2024
f5d9d59
fix test
maxjeblick Dec 10, 2024
c620cb0
add integration tests
maxjeblick Dec 10, 2024
ca4f0ba
add integration tests
maxjeblick Dec 10, 2024
702888a
add integration tests
maxjeblick Dec 10, 2024
4369c02
add fixture
maxjeblick Dec 10, 2024
22c9549
relax test
maxjeblick Dec 10, 2024
4b66477
undo variable extraction
maxjeblick Dec 10, 2024
8006a75
undo newlines
maxjeblick Dec 10, 2024
8eb3b3d
Merge branch 'main' into max/fix_quanto_cache
maxjeblick Dec 11, 2024
519222b
address pr feedback
maxjeblick Dec 11, 2024
f75cb61
fix broken test
maxjeblick Dec 11, 2024
261bcea
fix broken test
maxjeblick Dec 11, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -102,7 +102,7 @@ pipe(..., cache=cache)
By default, the `DynamicCache` is used (no quantization).

> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
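A minimal usage sketch, assuming `optimum-quanto` is installed; the checkpoint and the `nbits=4` setting below mirror this repository's tests and are illustrative rather than required:

```python
# Sketch only: pass a quantized cache to the kvpress pipeline.
# Assumes optimum-quanto is installed; model and nbits are illustrative choices.
import torch
from transformers import QuantizedCacheConfig, QuantoQuantizedCache, pipeline

from kvpress import ExpectedAttentionPress  # importing kvpress registers the pipeline

pipe = pipeline(
    "kv-press-text-generation",
    model="h2oai/h2o-danube3-500m-chat",
    device=0 if torch.cuda.is_available() else -1,
)
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))

context = "This is a test article. It was written on 2022-01-01."
question = "When was this article written?"
answer = pipe(context, question=question, press=press, cache=cache)["answer"]
```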


## FAQ
5 changes: 2 additions & 3 deletions kvpress/__init__.py
@@ -4,6 +4,7 @@

from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
@@ -14,7 +15,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.tova_press import TOVAPress

__all__ = [
"BasePress",
@@ -32,5 +33,3 @@
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
]

from kvpress.presses.tova_press import TOVAPress
26 changes: 18 additions & 8 deletions kvpress/pipeline.py
@@ -7,7 +7,7 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

@@ -248,13 +248,23 @@ def generate_answer(
answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)

# Remove the generated tokens from the cache
if isinstance(cache, QuantizedCache):
key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
else:
key_attr, value_attr = "key_cache", "value_cache"

setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
cache.key_cache = [
cache.key_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
cache.value_cache = [
cache.value_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
if hasattr(cache, "_quantized_key_cache"):
cache._quantized_key_cache = [
cache._quantized_key_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
cache._quantized_value_cache = [
cache._quantized_value_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]

return answer
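For reference, a self-contained sketch of the pruning step above using a plain `DynamicCache` and made-up tensor shapes: each layer's key and value tensors are truncated back to the length recorded before generation, so the compressed context can be reused for the next question.

```python
# Sketch only: prune a DynamicCache back to its pre-generation length.
# Shapes are arbitrary placeholders (batch, kv_heads, seq_len, head_dim).
import torch
from transformers import DynamicCache

cache = DynamicCache()
keys = torch.randn(1, 2, 16, 64)
values = torch.randn(1, 2, 16, 64)
cache.update(keys, values, layer_idx=0)  # the "compressed context" lives in the cache

cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]

# Generation appends answer tokens to the same cache ...
cache.update(torch.randn(1, 2, 4, 64), torch.randn(1, 2, 4, 64), layer_idx=0)
assert cache.get_seq_length(0) == 20

# ... so they are removed afterwards, mirroring generate_answer above.
cache.key_cache = [key[:, :, :length] for key, length in zip(cache.key_cache, cache_seq_lengths)]
cache.value_cache = [value[:, :, :length] for value, length in zip(cache.value_cache, cache_seq_lengths)]
assert cache.get_seq_length(0) == 16
```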

5 changes: 4 additions & 1 deletion kvpress/presses/base_press.py
@@ -14,8 +14,8 @@
MistralForCausalLM,
Phi3ForCausalLM,
PreTrainedModel,
Qwen2ForCausalLM,
QuantizedCache,
Qwen2ForCausalLM,
)

logger = logging.getLogger(__name__)
@@ -111,6 +111,9 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache._seen_tokens = keys.shape[2]
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values
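A standalone sketch of the quantized branch above, assuming `optimum-quanto` is installed; the shapes, `layer_idx`, and pruned length are placeholders for what the forward hook actually receives. Re-quantizing the pruned keys and values, emptying the residual float buffers, and updating `_seen_tokens` keeps the `QuantoQuantizedCache` bookkeeping consistent after compression.

```python
# Sketch under stated assumptions: optimum-quanto installed; keys, values and
# layer_idx stand in for the tensors the forward hook receives.
import torch
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))
layer_idx = 0
keys = torch.randn(1, 8, 64, 128)    # (batch, kv_heads, seq_len, head_dim)
values = torch.randn(1, 8, 64, 128)
cache.update(keys, values, layer_idx)  # prefill: the full prompt is quantized

# Pretend the press kept only the first 32 tokens, then store them back
# the same way the hook does for a quantized cache.
keys, values = keys[:, :, :32].contiguous(), values[:, :, :32].contiguous()
cache._quantized_key_cache[layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[layer_idx] = cache._quantize(values, axis=cache.axis_value)
cache.key_cache[layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache.value_cache[layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache._seen_tokens = keys.shape[2]
```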
1 change: 1 addition & 0 deletions kvpress/presses/composed_press.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass

from kvpress.presses.base_press import BasePress


26 changes: 25 additions & 1 deletion tests/fixtures.py
@@ -25,9 +25,33 @@ def danube_500m_model():


@pytest.fixture(scope="session")
def kv_press_pipeline():
def kv_press_unit_test_pipeline():
return pipeline(
"kv-press-text-generation",
model="maxjeblick/llama2-0b-unit-test",
device=0 if torch.cuda.is_available() else -1,
)


@pytest.fixture(scope="session")
def kv_press_danube_pipeline():
return pipeline(
"kv-press-text-generation",
model="h2oai/h2o-danube3-500m-chat",
device=0 if torch.cuda.is_available() else -1,
)


@pytest.fixture(scope="session")
def kv_press_llama3_1_flash_attn_pipeline():
device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
model=ckpt,
device=device,
torch_dtype="auto",
model_kwargs={"attn_implementation": attn_implementation},
)
return pipe
Empty file added tests/integration/__init__.py
56 changes: 56 additions & 0 deletions tests/integration/test_ruler.py
@@ -0,0 +1,56 @@
import datasets
import pytest
import torch
from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from kvpress import (
ExpectedAttentionPress,
KnormPress,
SimLayerKVPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401


@pytest.fixture(scope="session")
def df_ruler():
df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas()
df = df.loc[df["task"] == "niah_multikey_1"].reset_index(drop=True)
return df


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
@pytest.mark.parametrize(
"cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress]
)
@pytest.mark.parametrize("compression_ratio", [0.1, 0.2])
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache): # noqa: F811
if cls == ThinKPress:
press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
elif cls == SimLayerKVPress:
press = cls(lazy_threshold=1 - compression_ratio)
else:
press = cls(compression_ratio=compression_ratio)
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
elif cache == "quantized" and not is_optimum_quanto_available():
pytest.skip("Quanto is not installed")
else:
raise ValueError(f"Unknown cache type: {cache}")

idx = 0
context = df_ruler.iloc[idx]["context"]
question = df_ruler.iloc[idx]["question"]
true_answer = df_ruler.iloc[idx]["answer"][0]

pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"]
assert true_answer in pred_answer
Empty file added tests/presses/__init__.py
2 changes: 1 addition & 1 deletion tests/presses/test_presses.py
@@ -12,8 +12,8 @@
KnormPress,
ObservedAttentionPress,
RandomPress,
SnapKVPress,
SimLayerKVPress,
SnapKVPress,
StreamingLLMPress,
TOVAPress,
)
10 changes: 5 additions & 5 deletions tests/test_generate.py
@@ -3,20 +3,20 @@


from kvpress import KnormPress
from tests.fixtures import kv_press_pipeline # noqa: F401
from tests.fixtures import kv_press_unit_test_pipeline # noqa: F401


def test_generate(kv_press_pipeline): # noqa: F811
def test_generate(kv_press_unit_test_pipeline): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
press = KnormPress(compression_ratio=0.4)

# Answer with pipeline
pipe_answer = kv_press_pipeline(context, press=press, max_new_tokens=10)["answer"]
pipe_answer = kv_press_unit_test_pipeline(context, press=press, max_new_tokens=10)["answer"]

# Answer with model.generate
context += "\n" # kv press pipeline automatically adds a newline if no chat template
model = kv_press_pipeline.model
tokenizer = kv_press_pipeline.tokenizer
model = kv_press_unit_test_pipeline.model
tokenizer = kv_press_unit_test_pipeline.tokenizer
with press(model):
inputs = tokenizer(context, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
2 changes: 1 addition & 1 deletion tests/test_per_layer_compression_press.py
@@ -7,7 +7,7 @@

from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from tests.fixtures import kv_press_pipeline, unit_test_model # noqa: F401
from tests.fixtures import unit_test_model # noqa: F401


def test_per_layer_compression_press(unit_test_model): # noqa: F811
58 changes: 47 additions & 11 deletions tests/test_pipeline.py
@@ -6,19 +6,23 @@

import pytest
import torch
from transformers import AutoTokenizer, DynamicCache
from transformers import AutoTokenizer, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
from transformers.utils import is_optimum_quanto_available

from kvpress import ExpectedAttentionPress
from kvpress.pipeline import KVPressTextGenerationPipeline
from tests.fixtures import danube_500m_model, kv_press_pipeline, unit_test_model # noqa: F401
from tests.fixtures import danube_500m_model # noqa: F401
from tests.fixtures import kv_press_danube_pipeline # noqa: F401
from tests.fixtures import kv_press_unit_test_pipeline # noqa: F401
from tests.fixtures import unit_test_model # noqa: F401


def test_pipeline(kv_press_pipeline, caplog): # noqa: F811
def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
answers = kv_press_pipeline(context, questions=questions, press=press)["answers"]
answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]

assert len(answers) == 1
assert isinstance(answers[0], str)
@@ -28,12 +32,23 @@ def test_pipeline(kv_press_pipeline, caplog):  # noqa: F811
assert "Compressed Context Length: 13" in messages, messages


def test_pipeline_with_cache(kv_press_unit_test_pipeline, caplog): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = DynamicCache()
answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"]

assert len(answers) == 1
assert isinstance(answers[0], str)


@pytest.mark.parametrize("question", ["When was this article written?", ""])
def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog): # noqa: F811
def test_pipeline_single_or_no_question(kv_press_unit_test_pipeline, question, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
press = ExpectedAttentionPress(compression_ratio=0.4)
answer = kv_press_pipeline(context, question=question, press=press)["answer"]
answer = kv_press_unit_test_pipeline(context, question=question, press=press)["answer"]

assert isinstance(answer, str)

@@ -42,10 +57,10 @@ def test_pipeline_single_or_no_question(kv_press_pipeline, question, caplog): #
assert "Compressed Context Length: 13" in messages, messages


def test_pipeline_no_press_works(kv_press_pipeline, caplog): # noqa: F811
def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
question = "When was this article written?"
kv_press_pipeline(context, question=question)
kv_press_unit_test_pipeline(context, question=question)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
Expand All @@ -61,11 +76,32 @@ def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811
assert "Compressed Context Length: 16" in messages


@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]

assert len(answers) == 1
assert isinstance(answers[0], str)

for answer in answers:
assert answer == "This article was written on January 1, 2022."

messages = [record.message for record in caplog.records]
assert "Context Length: 28" in messages
assert "Compressed Context Length: 16" in messages


def test_pipeline_compresses_context(unit_test_model, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
answers = generate_answer(unit_test_model)

assert len(answers) == 1
assert len(answers) == 2
assert isinstance(answers[0], str)

messages = [record.message for record in caplog.records]
@@ -79,7 +115,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model):  # noqa: F811
questions = ["When was this article written?"]
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)

compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu"))
input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"]

seq_len = 256
@@ -100,7 +136,7 @@ def generate_answer(model):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
questions = ["When was this article written?", "When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)(