From af4018f9768aaad504e0245db16358972dbd6ebe Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Tue, 21 Jan 2025 22:23:52 +0000
Subject: [PATCH] [Llama] Remove DirectKVCache (#854)

From my reading, DirectKVCache existed for when PagedKVCache was not yet
properly implemented, but we now have separate LLM layers for the Paged and
Direct variants. The Paged variants explicitly use PagedKVCache, while the
Direct variants don't need a cache object at all; they just need cache k/v
tensors passed to them through the direct layer.

My main problem with DirectKVCache was that it was used inconsistently, with
the paged variants special-casing on whether the cache is a DirectKVCache.
(A brief usage sketch follows the diff below.)
---
 .../export_layer/export_paged_attention.py    |   2 +-
 sharktank/sharktank/layers/__init__.py        |   2 +-
 sharktank/sharktank/layers/kv_cache.py        | 198 +---------------
 .../sharktank/layers/llama_attention_block.py |  16 +-
 sharktank/sharktank/models/grok/grok.py       |   2 +-
 sharktank/sharktank/models/llama/llama.py     |   2 +-
 sharktank/sharktank/models/mixtral/mixtral.py |   2 +-
 sharktank/sharktank/utils/create_cache.py     |  38 ++-
 sharktank/tests/layers/kv_cache_test.py       | 224 ------------------
 .../paged_llama_attention_block_test.py       |   4 +-
 .../tests/models/llama/attention_test.py      |   2 +-
 sharktank/tests/models/llama/kv_cache_test.py |  57 ++---
 .../tests/models/llama/sharded_llama_test.py  |  14 +-
 13 files changed, 61 insertions(+), 502 deletions(-)

diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py
index cb28371bb..186fb5154 100644
--- a/sharktank/sharktank/export_layer/export_paged_attention.py
+++ b/sharktank/sharktank/export_layer/export_paged_attention.py
@@ -236,7 +236,7 @@ def main():
     model = PagedLlamaAttentionBlock(
         theta=attention_block_theta,
         block_index=0,
-        cache=create_kv_cache(llama_config),
+        cache=create_paged_kv_cache(llama_config),
         head_count=llama_config.hp.attention_head_count,
         head_dim=llama_config.hp.attn_head_dim,
         head_count_kv=llama_config.hp.attention_head_count_kv,
diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py
index 3caf7631d..9842a8291 100644
--- a/sharktank/sharktank/layers/__init__.py
+++ b/sharktank/sharktank/layers/__init__.py
@@ -6,7 +6,7 @@
 
 from .base import BaseLayer, ThetaLayer
 from .conv import Conv2DLayer
-from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
+from .kv_cache import PagedKVCache
 from .causal_llm import BaseCausalLMModel
 from .linear import LinearLayer
 from .norm import RMSNormLayer, LayerNorm
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index f62002f46..6af0e5183 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -22,204 +22,10 @@
 from ..types import SplitPrimitiveTensor, ReplicatedTensor
 from .. import ops
 
-__all__ = [
-    "BaseKVCache",
-    "DirectKVCache",
-    "PagedKVCache",
-]
+__all__ = ["PagedKVCache"]
 
 
-class BaseKVCache(abc.ABC):
-    """Base class for a KV cache.
-
-    This doesn't do much on its own except to serve as a type-safe base class
-    unifying the PagedKVCache and DirectKVCache:
-
-    * PagedKVCache is a shared cache which can be used across an arbitrary
-      number of batches/sequences with random mapping of blocks within a
-      sequence to backing "page".
-    * DirectKVCache is a single-batch cache with a fixed batch size and
-      sequence length where the K/V cache tensors for each transformer block
-      are densely layed out in memory.
- """ - - block_seq_stride: int - transformer_block_count: int - attn_head_count: int - attn_head_dim: int - - @property - @abc.abstractmethod - def pad_sequence_stride(self) -> int: - """Stride that a sequence must be padded to in order to be valid for - the cache. For paged caches, this will typically be a multiple of the - block_seq_stride. For direct caches it may be 1 or a multiple that - is chosen for performance reasons. - """ - ... - - @property - def is_paged(self) -> bool: - return isinstance(self, PagedKVCache) - - @property - def is_direct(self) -> bool: - return isinstance(self, DirectKVCache) - - @property - def paged(self) -> "PagedKVCache": - assert isinstance( - self, PagedKVCache - ), f"Attempt to access cache {type(self)} as paged but it is not" - return self - - @property - def direct(self) -> "DirectKVCache": - assert isinstance( - self, DirectKVCache - ), f"Attempt to access cache {type(self)} as direct but it is not" - return self - - -class DirectKVCache(BaseKVCache): - """KVCache for a single batch where the cache tensors are densely laid out.""" - - def __init__( - self, - *, - block_seq_stride: int, - transformer_block_count: int, - attn_head_count: int, - attn_head_dim: int, - seq_length: int, - shard_count: int = 1, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - ): - self.block_seq_stride = block_seq_stride - self.transformer_block_count = transformer_block_count - self.attn_head_count = attn_head_count - self.attn_head_dim = attn_head_dim - self.seq_length = seq_length - self.shard_count = shard_count - self.device = device - self.dtype = dtype - - @property - def pad_sequence_stride(self) -> int: - return self.block_seq_stride - - def allocate(self, *, bs: int) -> list[torch.Tensor]: - """Allocates 2*transformer_block_count K/V cache tensors for the - given batch size and sequence length. - - Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] - """ - allocations = [ - torch.empty( - [ - bs, - self.seq_length, - self.attn_head_count, - self.attn_head_dim, - ], - dtype=self.dtype, - device=self.device, - ) - for _ in range(2 * self.transformer_block_count) - ] - - if self.shard_count == 1: - return allocations - - return [ - ops.reshard_split(allocation, dim=2, count=self.shard_count) - for allocation in allocations - ] - - def read( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - transformer_block_index: int, - seq_len: int, - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Reads cache partitions from the page table for the given page_ids. - - Args: - state: State struct as returned from allocate(). - read_into_partitions: List of cache partitions to read into in-place. - transformer_block_index: The index of the transformer block accessing - the cache. - page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids - to access. - - Returns a tuple of cache partitions (i.e. k and v caches for the transformer - block), linearized. Note that this reference approach to reading by - materializing linearly may not be terribly efficient unless if the - compiler can fuse the gather. 
- """ - read_count = len(read_into_partitions) - reads = [] - for i in range(read_count): - reads.append( - state[transformer_block_index * read_count + i][:, :seq_len, :, :] - ) - - return tuple(reads) - - def write_timestep( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - # List of [bs, 1, attn_head_count, attn_head_dim] - cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - transformer_block_index: int, - # [bs] - seq_positions: Union[torch.Tensor, ReplicatedTensor], - # [bs, max_seqlen // block_pos_stride] - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Writes a single batched timestep across all cache partitions. - - Note that this internally loops over the batch size, which cannot be - dynamic. - """ - bs, _, _, _ = cache_partitions[0].shape - update_count = len(cache_partitions) - - for b in range(bs): - row_index = torch.tensor([b], dtype=torch.int64) - row_start_pos = seq_positions[row_index].unsqueeze(0) - - for i, update in enumerate(cache_partitions): - cache = state[transformer_block_index * update_count + i] - cache.index_put_((row_index, row_start_pos), update[row_index, 0]) - - def write( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - transformer_block_index: int, - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Writes cache partitions from a linear layout to the page table. - - This is the inverse of the linear read. The same caveat applies if the - in-place scatter cannot be fused. - """ - update_count = len(cache_partitions) - - for idx, update_src in enumerate(cache_partitions): - cache_dest = state[transformer_block_index * update_count + idx] - _, batch_seq_len, _, _ = update_src.shape - cache_dest[:, :batch_seq_len, :, :] = update_src - - -class PagedKVCache(BaseKVCache): +class PagedKVCache: """Implementation of a KV cache on top of a 'page table'. The page table slab is physically represented as a 2D tensor: diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 0cdb5d713..38cb7c0fa 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Optional +import math import torch import torch.nn.functional as F @@ -29,7 +30,6 @@ def __init__( head_count: int, head_dim: int, head_count_kv: int, - embedding: RotaryEmbeddingLayer, rms_epsilon: float, ): super().__init__(theta) @@ -41,7 +41,6 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - self.embedding = embedding self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv @@ -50,6 +49,7 @@ def forward( self, h: torch.Tensor, *, + embedding: RotaryEmbeddingLayer, cache_k: torch.Tensor, cache_v: torch.Tensor, start_index: int, @@ -72,11 +72,11 @@ def forward( # Fast path to start_index based embedding lookup if available. # Falls back to a slower position based index lookup. 
if start_index is not None: - xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + xq = embedding.forward(xt=xq, start_index=start_index) + xk = embedding.forward(xt=xk, start_index=start_index) else: - xq, xk = embedding.apply_batched_mask( - xq=xq, xk=xk, mask=embedding_batch_mask - ) + xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask) + xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -108,9 +108,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( - self.head_dim - ) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 077e4e064..0be9ede05 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -59,7 +59,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 6fef6704e..0bb6985e7 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -74,7 +74,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.use_hf = config.use_hf self.attention_kernel = config.attention_kernel diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index e2995dfde..b597c6e99 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -61,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py index c1691c8a8..f462d9c00 100644 --- a/sharktank/sharktank/utils/create_cache.py +++ b/sharktank/sharktank/utils/create_cache.py @@ -7,28 +7,18 @@ from ..layers import * -def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache: +def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: + if config.kv_cache_type != "paged": + raise ValueError("Model does not use paged kv cache, cannot create kv cache") + hp = config.hp - if config.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=config.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=config.device, - dtype=config.attention_dtype, - ) - elif config.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One 
for each of K/V. - block_seq_stride=config.block_seq_stride, - device=config.device, - dtype=config.attention_dtype, - shard_count=config.tensor_parallelism_size, - ) - else: - raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}") + return PagedKVCache( + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + cache_partition_count=2, # One for each of K/V. + block_seq_stride=config.block_seq_stride, + device=config.device, + dtype=config.attention_dtype, + shard_count=config.tensor_parallelism_size, + ) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 65b42c986..8512c8768 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -13,230 +13,6 @@ from sharktank.types import * -def test_direct(): - bs = 4 - seq_length = 24 - attn_head_count = 4 - attn_head_dim = 16 - transformer_block_count = 4 - cache = DirectKVCache( - block_seq_stride=4, - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - attn_head_dim=attn_head_dim, - seq_length=seq_length, - dtype=torch.float32, - device=None, - ) - - allocation = cache.allocate(bs=bs) - allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] - - write_seq_length = seq_length - 5 - - # Write a prefill in: - write_ones = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 - ) - write_twos = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 - ) - cache.write( - allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 - ) - - # Check the written values have updated: - read_empty = [ - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length, - ) - torch.testing.assert_close(write_ones, read_back[0]) - torch.testing.assert_close(write_twos, read_back[1]) - - # Check the others are still zero: - for i in range(transformer_block_count): - if i == 1: - continue - read_ones = [ - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_ones = cache.read( - allocation, - read_into_partitions=read_ones, - transformer_block_index=i, - seq_len=write_seq_length, - ) - torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) - torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) - - # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 - ) - write_fours = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 - ) - write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) - cache.write_timestep( - allocation, - cache_partitions=[write_threes, write_fours], - transformer_block_index=1, - seq_positions=write_pos, - ) - - read_empty = [ - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_back = cache.read( - 
allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length + 1, - ) - - check_concat_0 = torch.concat([write_ones, write_threes], dim=1) - check_concat_1 = torch.concat([write_twos, write_fours], dim=1) - - torch.testing.assert_close(check_concat_0, read_back[0]) - torch.testing.assert_close(check_concat_1, read_back[1]) - - -def test_sharded_direct(): - bs = 4 - seq_length = 24 - attn_head_count = 8 - attn_head_dim = 16 - transformer_block_count = 4 - shard_count = 4 - cache = DirectKVCache( - block_seq_stride=4, - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - attn_head_dim=attn_head_dim, - seq_length=seq_length, - shard_count=shard_count, - dtype=torch.float32, - device=None, - ) - - allocation = cache.allocate(bs=bs) - # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] - - write_seq_length = seq_length - 5 - - # Write a prefill in: - write_ones = reshard_split( - torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), - 1.0, - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - - write_twos = reshard_split( - torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), - 2.0, - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - - cache.write( - allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 - ) - - # Check the written values have updated: - read_empty = [ - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length, - ) - torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) - torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) - - # Write timestep - write_threes = reshard_split( - torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), - dim=2, - count=shard_count, - ) - write_fours = reshard_split( - torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), - dim=2, - count=shard_count, - ) - - write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count - ) - cache.write_timestep( - allocation, - cache_partitions=[write_threes, write_fours], - transformer_block_index=1, - seq_positions=write_pos, - ) - - read_empty = [ - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length + 1, - ) - - check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) - check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) - - torch.testing.assert_close(check_concat_0, unshard(read_back[0])) - torch.testing.assert_close(check_concat_1, unshard(read_back[1])) - - def test_paged(): bs = 4 seq_length = 24 diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index bbb52f235..e74a14ad5 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -57,7 +57,7 @@ def 
testExportDecomposed(self): dtype=dtype, ) - cache_state = cache.paged.allocate(self.page_count) + cache_state = cache.allocate(self.page_count) cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype) theta = make_llama_attention_block_theta( @@ -140,7 +140,7 @@ def testExportNondecomposed(self): dtype=dtype, ) - cache_state = cache.paged.allocate(self.page_count) + cache_state = cache.allocate(self.page_count) cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype) theta = make_llama_attention_block_theta( diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index 211fab5a0..22013635b 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -77,7 +77,7 @@ def test(self): input_tensor, embedding=attention_embedding, start_index=0, - cache_state=paged_kv_cache.paged.allocate(128), + cache_state=paged_kv_cache.allocate(128), seq_block_ids=torch.arange(seq_len).view(1, -1), ) diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index a80575951..3d43243b0 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn from sharktank.models.llama.llama import ( + LlamaAttentionBlock, PagedLlamaAttentionBlock, PagedKVCache, - DirectKVCache, ) from sharktank.models.llama.testing import * from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer @@ -48,15 +48,14 @@ def setUp(self): device=self.device, dtype=self.attention_dtype, ) - self.direct_kv_cache = DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=self.head_count, - attn_head_count=self.head_count, - attn_head_dim=self.head_dim, - seq_length=self.max_seq_len, - device=self.device, - dtype=self.attention_dtype, - ) + self.direct_k_cache = [ + torch.empty([self.bs, self.max_seq_len, self.head_count_kv, self.head_dim]) + for _ in range(self.block_count) + ] + self.direct_v_cache = [ + torch.empty([self.bs, self.max_seq_len, self.head_count_kv, self.head_dim]) + for _ in range(self.block_count) + ] self.attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=self.rope_dimension_count, rope_freq_base=self.rope_freq_base, @@ -80,10 +79,8 @@ def setUp(self): ) self.direct_attn_blocks = nn.ModuleList( [ - PagedLlamaAttentionBlock( + LlamaAttentionBlock( theta=self.attention_block_theta, - block_index=n, - cache=self.direct_kv_cache, head_count=self.head_count, head_dim=self.head_dim, head_count_kv=self.head_count_kv, @@ -98,12 +95,6 @@ def setUp(self): [127], ] ) - self.direct_cache_state = self.direct_kv_cache.allocate(bs=1) - self.direct_seq_block_ids = torch.tensor( - [ - [0], - ] - ) self.embedding_batch_mask = self.attention_embedding.compute_batch_mask( self.start_positions, batch_seq_len=1 ) @@ -139,8 +130,8 @@ def testDirectAndPagedKVCachePrefill(self): embedding=self.attention_embedding, start_index=0, attention_mask=self.prefill_attention_mask, - cache_state=self.direct_cache_state, - seq_block_ids=self.direct_seq_block_ids, + cache_k=self.direct_k_cache[block_idx], + cache_v=self.direct_v_cache[block_idx], ) page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) index_written = self.start_positions.item() @@ -150,13 +141,13 @@ def testDirectAndPagedKVCachePrefill(self): """ page_id = self.paged_seq_block_ids[0][0].item() """ - direct_cache_state is a list of num_transformer_blocks * 2 (one 
for K and one for V), - so here we index into the first transformer block's keys with self.direct_cache_state[0] - and the first transformer block's values with self.direct_cache_state[1]. Each row + direct_cache_state is a list of num_transformer_blocks (one for K and one for V), + so here we index into the first transformer block's keys with self.direct_k_cache[0] + and the first transformer block's values with self.direct_v_cache[0]. Each row in direct_cache_state is a tensor of [bs, seq_len , attn_heads, attn_dim], so we make sure the first 8 (start_position) tensors starting at sequence 0 of the seq_len are written to. """ - updated_direct_cache_state = self.direct_cache_state[0][ + updated_direct_k_cache_state = self.direct_k_cache[0][ :, :index_written ].squeeze(0) """ @@ -172,10 +163,10 @@ def testDirectAndPagedKVCachePrefill(self): first transformer block's K cache for the first 8 (start_positions) tensors starting at sequence 0. """ - updated_paged_cache_state = page_table[page_id][0, 0, :index_written] - assert updated_direct_cache_state.shape == updated_paged_cache_state.shape + updated_paged_k_cache_state = page_table[page_id][0, 0, :index_written] + assert updated_direct_k_cache_state.shape == updated_paged_k_cache_state.shape torch.testing.assert_close( - updated_direct_cache_state, updated_paged_cache_state + updated_direct_k_cache_state, updated_paged_k_cache_state ) paged_prefill_attn_output = paged_input_tensor @@ -246,20 +237,18 @@ def testDirectAndPagedKVCacheDecode(self): embedding=self.attention_embedding, embedding_batch_mask=self.embedding_batch_mask, attention_mask=decode_attention_mask, - cache_state=self.direct_cache_state, - seq_block_ids=self.direct_seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, + cache_k=self.direct_k_cache, + cache_v=self.direct_v_cache, ) page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) index_written = self.start_positions.item() page_id = self.paged_seq_block_ids[0][0].item() - updated_direct_cache_state_keys = self.direct_cache_state[0][ + updated_direct_cache_state_keys = self.direct_k_cache[0][ :, index_written ].squeeze(0) updated_paged_cache_state_keys = page_table[page_id][0, 0, index_written] - updated_direct_cache_state_values = self.direct_cache_state[1][ + updated_direct_cache_state_values = self.direct_v_cache[0][ :, index_written ].squeeze(0) updated_paged_cache_state_values = page_table[page_id][0, 1, index_written] diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..e78be1cbe 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -94,7 +94,7 @@ def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: seq_block_ids = torch.arange( self.batch_size * batch_seq_len // self.config.block_seq_stride ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = model.cache.allocate(page_count=self.cache_page_count) cache_state = [torch.rand_like(cache_state[0])] return OrderedDict( [ @@ -109,14 +109,14 @@ def make_equal_unsharded_and_sharded_prefill_args( self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: prefill_kwargs = self.make_prefill_args(model) - sharded_cache_state = sharded_model.cache.paged.allocate( + sharded_cache_state = sharded_model.cache.allocate( page_count=self.cache_page_count ) 
assert iterables_equal( prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape ) sharded_prefill_kwargs = deepcopy(prefill_kwargs) - sharded_cache_state = sharded_model.cache.paged.shard_state( + sharded_cache_state = sharded_model.cache.shard_state( sharded_prefill_kwargs["cache_state"] ) sharded_prefill_kwargs["cache_state"] = sharded_cache_state @@ -149,7 +149,7 @@ def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: seq_block_ids = torch.arange( self.batch_size * batch_seq_len // self.config.block_seq_stride ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = model.cache.allocate(page_count=self.cache_page_count) cache_state = [torch.rand_like(cache_state[0])] return OrderedDict( [ @@ -166,7 +166,7 @@ def make_equal_unsharded_and_sharded_decode_args( ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: decode_kwargs = self.make_decode_args(model) sharded_decode_kwargs = deepcopy(decode_kwargs) - sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state( + sharded_decode_kwargs["cache_state"] = sharded_model.cache.shard_state( sharded_decode_kwargs["cache_state"] ) @@ -203,7 +203,7 @@ def testCompareToySizedModelToUnsharded(self): ) expected_cache_state = prefill_kwargs["cache_state"][0] actual_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( + sharded_model.cache.unflatten_page_table( sharded_prefill_kwargs["cache_state"] ) ).flatten(start_dim=1) @@ -224,7 +224,7 @@ def testCompareToySizedModelToUnsharded(self): ) expected_decode_cache_state = decode_kwargs["cache_state"][0] actual_decode_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( + sharded_model.cache.unflatten_page_table( sharded_decode_kwargs["cache_state"] ) ).flatten(start_dim=1)
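
A minimal usage sketch of the two code paths after this change, based only on
the signatures visible in this diff (create_paged_kv_cache, PagedKVCache.allocate,
and the cache_k/cache_v arguments of LlamaAttentionBlock.forward). The config
object, the toy dimensions, the page count, and the attn_blocks/rotary/mask
names in the final comment are illustrative assumptions, not values taken from
the patch.

    import torch

    # Paged path: models whose config has kv_cache_type == "paged" build the
    # cache through the renamed helper and allocate the page table directly on
    # the PagedKVCache (the old .paged indirection is gone):
    #
    #     cache = create_paged_kv_cache(config)          # config: LlamaModelConfig
    #     cache_state = cache.allocate(page_count=128)   # page count is illustrative

    # Direct path: there is no cache object any more. Callers allocate plain
    # K/V tensors per transformer block and pass them to
    # LlamaAttentionBlock.forward as cache_k / cache_v, mirroring the updated
    # kv_cache_test.py setup. Toy dimensions:
    bs, max_seq_len, head_count_kv, head_dim, block_count = 1, 64, 8, 16, 4

    direct_k_cache = [
        torch.empty([bs, max_seq_len, head_count_kv, head_dim])
        for _ in range(block_count)
    ]
    direct_v_cache = [
        torch.empty([bs, max_seq_len, head_count_kv, head_dim])
        for _ in range(block_count)
    ]

    # Per-block call, sketched; attn_blocks, rotary, and mask are assumed to
    # exist in the caller:
    #
    #     h = attn_blocks[i](
    #         h,
    #         embedding=rotary,
    #         start_index=0,
    #         attention_mask=mask,
    #         cache_k=direct_k_cache[i],
    #         cache_v=direct_v_cache[i],
    #     )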