[Llama] Remove DirectKVCache (#854)
From my reading of it, DirectKVCache was only needed back when PagedKVCache wasn't implemented properly. We now have separate LLM layers for the Paged and Direct variants: Paged variants explicitly use PagedKVCache, while Direct variants don't really need a cache object at all; they just need cache K/V tensors passed to them through the direct layer.

My main problem with DirectKVCache was that it was used inconsistently, with the paged variants special-casing whether the cache they were handed was a DirectKVCache.
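
To make the "Direct variants just need K/V tensors" point concrete, here is a minimal standalone torch sketch of the pattern. It is not sharktank API; shapes and names are illustrative, and it only mirrors the slice-based read/write that DirectKVCache used to wrap:

```python
import torch

# Caller-side allocation: one dense K and one dense V tensor per transformer
# block, shaped [bs, max_seq_len, attn_head_count_kv, attn_head_dim].
bs, max_seq_len, head_count_kv, head_dim = 2, 16, 4, 8
cache_k = torch.zeros(bs, max_seq_len, head_count_kv, head_dim)
cache_v = torch.zeros(bs, max_seq_len, head_count_kv, head_dim)

# Writing the current timestep inside a direct attention layer is a slice
# assignment (roughly what DirectKVCache.write_timestep used to do per row).
pos = 3
xk = torch.randn(bs, 1, head_count_kv, head_dim)
xv = torch.randn(bs, 1, head_count_kv, head_dim)
cache_k[:, pos : pos + 1] = xk
cache_v[:, pos : pos + 1] = xv

# Reading the populated prefix back is just a view (roughly DirectKVCache.read).
keys = cache_k[:, : pos + 1]
values = cache_v[:, : pos + 1]
assert keys.shape == (bs, pos + 1, head_count_kv, head_dim)
```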
Groverkss authored Jan 21, 2025
1 parent c2b365b commit af4018f
Showing 13 changed files with 61 additions and 502 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/export_layer/export_paged_attention.py
@@ -236,7 +236,7 @@ def main():
     model = PagedLlamaAttentionBlock(
         theta=attention_block_theta,
         block_index=0,
-        cache=create_kv_cache(llama_config),
+        cache=create_paged_kv_cache(llama_config),
         head_count=llama_config.hp.attention_head_count,
         head_dim=llama_config.hp.attn_head_dim,
         head_count_kv=llama_config.hp.attention_head_count_kv,

2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -6,7 +6,7 @@

 from .base import BaseLayer, ThetaLayer
 from .conv import Conv2DLayer
-from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
+from .kv_cache import PagedKVCache
 from .causal_llm import BaseCausalLMModel
 from .linear import LinearLayer
 from .norm import RMSNormLayer, LayerNorm

198 changes: 2 additions & 196 deletions sharktank/sharktank/layers/kv_cache.py
@@ -22,204 +22,10 @@
 from ..types import SplitPrimitiveTensor, ReplicatedTensor
 from .. import ops

-__all__ = [
-    "BaseKVCache",
-    "DirectKVCache",
-    "PagedKVCache",
-]
+__all__ = ["PagedKVCache"]


-class BaseKVCache(abc.ABC):
-    """Base class for a KV cache.
-    This doesn't do much on its own except to serve as a type-safe base class
-    unifying the PagedKVCache and DirectKVCache:
-    * PagedKVCache is a shared cache which can be used across an arbitrary
-      number of batches/sequences with random mapping of blocks within a
-      sequence to backing "page".
-    * DirectKVCache is a single-batch cache with a fixed batch size and
-      sequence length where the K/V cache tensors for each transformer block
-      are densely layed out in memory.
-    """
-
-    block_seq_stride: int
-    transformer_block_count: int
-    attn_head_count: int
-    attn_head_dim: int
-
-    @property
-    @abc.abstractmethod
-    def pad_sequence_stride(self) -> int:
-        """Stride that a sequence must be padded to in order to be valid for
-        the cache. For paged caches, this will typically be a multiple of the
-        block_seq_stride. For direct caches it may be 1 or a multiple that
-        is chosen for performance reasons.
-        """
-        ...
-
-    @property
-    def is_paged(self) -> bool:
-        return isinstance(self, PagedKVCache)
-
-    @property
-    def is_direct(self) -> bool:
-        return isinstance(self, DirectKVCache)
-
-    @property
-    def paged(self) -> "PagedKVCache":
-        assert isinstance(
-            self, PagedKVCache
-        ), f"Attempt to access cache {type(self)} as paged but it is not"
-        return self
-
-    @property
-    def direct(self) -> "DirectKVCache":
-        assert isinstance(
-            self, DirectKVCache
-        ), f"Attempt to access cache {type(self)} as direct but it is not"
-        return self
-
-
-class DirectKVCache(BaseKVCache):
-    """KVCache for a single batch where the cache tensors are densely laid out."""
-
-    def __init__(
-        self,
-        *,
-        block_seq_stride: int,
-        transformer_block_count: int,
-        attn_head_count: int,
-        attn_head_dim: int,
-        seq_length: int,
-        shard_count: int = 1,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-    ):
-        self.block_seq_stride = block_seq_stride
-        self.transformer_block_count = transformer_block_count
-        self.attn_head_count = attn_head_count
-        self.attn_head_dim = attn_head_dim
-        self.seq_length = seq_length
-        self.shard_count = shard_count
-        self.device = device
-        self.dtype = dtype
-
-    @property
-    def pad_sequence_stride(self) -> int:
-        return self.block_seq_stride
-
-    def allocate(self, *, bs: int) -> list[torch.Tensor]:
-        """Allocates 2*transformer_block_count K/V cache tensors for the
-        given batch size and sequence length.
-        Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
-        """
-        allocations = [
-            torch.empty(
-                [
-                    bs,
-                    self.seq_length,
-                    self.attn_head_count,
-                    self.attn_head_dim,
-                ],
-                dtype=self.dtype,
-                device=self.device,
-            )
-            for _ in range(2 * self.transformer_block_count)
-        ]
-
-        if self.shard_count == 1:
-            return allocations
-
-        return [
-            ops.reshard_split(allocation, dim=2, count=self.shard_count)
-            for allocation in allocations
-        ]
-
-    def read(
-        self,
-        state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        *,
-        read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        transformer_block_index: int,
-        seq_len: int,
-        page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
-    ):
-        """Reads cache partitions from the page table for the given page_ids.
-        Args:
-          state: State struct as returned from allocate().
-          read_into_partitions: List of cache partitions to read into in-place.
-          transformer_block_index: The index of the transformer block accessing
-              the cache.
-          page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
-              to access.
-        Returns a tuple of cache partitions (i.e. k and v caches for the transformer
-        block), linearized. Note that this reference approach to reading by
-        materializing linearly may not be terribly efficient unless if the
-        compiler can fuse the gather.
-        """
-        read_count = len(read_into_partitions)
-        reads = []
-        for i in range(read_count):
-            reads.append(
-                state[transformer_block_index * read_count + i][:, :seq_len, :, :]
-            )
-
-        return tuple(reads)
-
-    def write_timestep(
-        self,
-        state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        # List of [bs, 1, attn_head_count, attn_head_dim]
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        *,
-        transformer_block_index: int,
-        # [bs]
-        seq_positions: Union[torch.Tensor, ReplicatedTensor],
-        # [bs, max_seqlen // block_pos_stride]
-        page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
-    ):
-        """Writes a single batched timestep across all cache partitions.
-        Note that this internally loops over the batch size, which cannot be
-        dynamic.
-        """
-        bs, _, _, _ = cache_partitions[0].shape
-        update_count = len(cache_partitions)
-
-        for b in range(bs):
-            row_index = torch.tensor([b], dtype=torch.int64)
-            row_start_pos = seq_positions[row_index].unsqueeze(0)
-
-            for i, update in enumerate(cache_partitions):
-                cache = state[transformer_block_index * update_count + i]
-                cache.index_put_((row_index, row_start_pos), update[row_index, 0])
-
-    def write(
-        self,
-        state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        *,
-        transformer_block_index: int,
-        page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
-    ):
-        """Writes cache partitions from a linear layout to the page table.
-        This is the inverse of the linear read. The same caveat applies if the
-        in-place scatter cannot be fused.
-        """
-        update_count = len(cache_partitions)
-
-        for idx, update_src in enumerate(cache_partitions):
-            cache_dest = state[transformer_block_index * update_count + idx]
-            _, batch_seq_len, _, _ = update_src.shape
-            cache_dest[:, :batch_seq_len, :, :] = update_src
-
-
-class PagedKVCache(BaseKVCache):
+class PagedKVCache:
     """Implementation of a KV cache on top of a 'page table'.
     The page table slab is physically represented as a 2D tensor:

16 changes: 7 additions & 9 deletions sharktank/sharktank/layers/llama_attention_block.py
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 from typing import Optional
+import math

 import torch
 import torch.nn.functional as F
@@ -29,7 +30,6 @@ def __init__(
         head_count: int,
         head_dim: int,
         head_count_kv: int,
-        embedding: RotaryEmbeddingLayer,
         rms_epsilon: float,
     ):
         super().__init__(theta)
@@ -41,7 +41,6 @@ def __init__(
         self.add_module("attn_v", LinearLayer(theta("attn_v")))
         self.add_module("attn_output", LinearLayer(theta("attn_output")))

-        self.embedding = embedding
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
@@ -50,6 +49,7 @@ def forward(
         self,
         h: torch.Tensor,
         *,
+        embedding: RotaryEmbeddingLayer,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         start_index: int,
@@ -72,11 +72,11 @@ def forward(
         # Fast path to start_index based embedding lookup if available.
         # Falls back to a slower position based index lookup.
         if start_index is not None:
-            xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
+            xq = embedding.forward(xt=xq, start_index=start_index)
+            xk = embedding.forward(xt=xk, start_index=start_index)
         else:
-            xq, xk = embedding.apply_batched_mask(
-                xq=xq, xk=xk, mask=embedding_batch_mask
-            )
+            xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask)
+            xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask)

         # Expand kv heads for GQA.
         gqa_n_rep = self.head_count // self.head_count_kv
@@ -108,9 +108,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
         values = values.transpose(1, 2)

         # Flash attention.
-        attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
-            self.head_dim
-        )
+        attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

         # Apply attention mask.
         if attention_mask is not None:

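One note on the scaling change above: self.head_dim is a plain Python int, so math.sqrt applies to it directly, whereas torch.sqrt expects a tensor. A small standalone check of the new scaling, with made-up shapes:

```python
import math

import torch

head_dim = 128
xq = torch.randn(1, 4, 7, head_dim)    # [bs, head_count, q_len, head_dim]
keys = torch.randn(1, 4, 9, head_dim)  # [bs, head_count, kv_len, head_dim]

# Form used in the diff above: scalar division via math.sqrt.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
assert attn_weights.shape == (1, 4, 7, 9)

# A tensor-based equivalent would have to wrap the int first, e.g.
# torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(torch.tensor(float(head_dim)))
```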
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/grok/grok.py
@@ -59,7 +59,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         )
         self.config = config
         self.hp = hp
-        self.cache = create_kv_cache(self.config)
+        self.cache = create_paged_kv_cache(self.config)
         self.activation_dtype = config.activation_dtype
         self.add_module(
             "token_embedding",

2 changes: 1 addition & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -74,7 +74,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         )
         self.config = config
         self.hp = hp
-        self.cache = create_kv_cache(self.config)
+        self.cache = create_paged_kv_cache(self.config)
         self.activation_dtype = config.activation_dtype
         self.use_hf = config.use_hf
         self.attention_kernel = config.attention_kernel

2 changes: 1 addition & 1 deletion sharktank/sharktank/models/mixtral/mixtral.py
@@ -61,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         )
         self.config = config
         self.hp = hp
-        self.cache = create_kv_cache(self.config)
+        self.cache = create_paged_kv_cache(self.config)
         self.activation_dtype = config.activation_dtype
         self.add_module(
             "token_embedding",

38 changes: 14 additions & 24 deletions sharktank/sharktank/utils/create_cache.py
@@ -7,28 +7,18 @@
 from ..layers import *


-def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache:
+def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
+    if config.kv_cache_type != "paged":
+        raise ValueError("Model does not use paged kv cache, cannot create kv cache")
+
     hp = config.hp
-    if config.kv_cache_type == "direct":
-        return DirectKVCache(
-            block_seq_stride=config.block_seq_stride,
-            transformer_block_count=hp.block_count,
-            attn_head_count=hp.attention_head_count_kv,
-            attn_head_dim=hp.attn_head_dim,
-            seq_length=hp.context_length,
-            device=config.device,
-            dtype=config.attention_dtype,
-        )
-    elif config.kv_cache_type == "paged":
-        return PagedKVCache(
-            transformer_block_count=hp.block_count,
-            attn_head_count=hp.attention_head_count_kv,
-            attn_head_dim=hp.attn_head_dim,
-            cache_partition_count=2,  # One for each of K/V.
-            block_seq_stride=config.block_seq_stride,
-            device=config.device,
-            dtype=config.attention_dtype,
-            shard_count=config.tensor_parallelism_size,
-        )
-    else:
-        raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}")
+    return PagedKVCache(
+        transformer_block_count=hp.block_count,
+        attn_head_count=hp.attention_head_count_kv,
+        attn_head_dim=hp.attn_head_dim,
+        cache_partition_count=2,  # One for each of K/V.
+        block_seq_stride=config.block_seq_stride,
+        device=config.device,
+        dtype=config.attention_dtype,
+        shard_count=config.tensor_parallelism_size,
+    )

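Since the dispatching create_kv_cache() helper is gone, a caller that still wants to support both cache flavors has to branch on kv_cache_type itself. The sketch below is illustrative only (the helper name and the direct branch are not part of this commit); it uses just the config fields that appear in this diff:

```python
import torch

from sharktank.utils.create_cache import create_paged_kv_cache


def make_cache(config, *, bs: int, seq_len: int):
    """Illustrative stand-in for the dispatch create_kv_cache() used to do."""
    if config.kv_cache_type == "paged":
        return create_paged_kv_cache(config)
    # Direct flavor: no cache object, just two dense K/V tensors per
    # transformer block, as described in the commit message.
    hp = config.hp
    shape = [bs, seq_len, hp.attention_head_count_kv, hp.attn_head_dim]
    return [
        torch.zeros(shape, dtype=config.attention_dtype, device=config.device)
        for _ in range(2 * hp.block_count)
    ]
```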