[Llama] Remove inplace read for KVCache
Groverkss committed Jan 21, 2025
1 parent af4018f commit bf52d93
Showing 8 changed files with 15 additions and 263 deletions.
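To summarize the API change before the per-file diffs: the paged KV cache's read() previously copied gathered pages into caller-preallocated xk_temp/xv_temp buffers in place; after this commit it returns the gathered K/V tensors directly, so callers drop those temporaries. A rough, self-contained PyTorch toy of the two conventions (made-up shapes; page_table, read_inplace, and read_returning are illustrative names, not sharktank APIs):

```python
import torch

# Toy "page table": 4 pages, each holding 3 positions of a 2-wide K/V state.
page_table = torch.arange(4 * 3 * 2, dtype=torch.float32).reshape(4, 3, 2)
page_ids = torch.tensor([2, 0, 3])  # pages that make up one sequence

def read_inplace(into: torch.Tensor) -> None:
    # Old convention: gather the pages and copy into a caller-provided buffer.
    into[...] = page_table.index_select(0, page_ids).flatten(0, 1)

def read_returning() -> torch.Tensor:
    # New convention: gather the pages and return the linearized result.
    return page_table.index_select(0, page_ids).flatten(0, 1)

buf = torch.empty(3 * 3, 2)  # the xk_temp/xv_temp-style temporary the old path required
read_inplace(buf)
assert torch.equal(buf, read_returning())
```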
36 changes: 1 addition & 35 deletions sharktank/sharktank/export_layer/export_paged_attention.py
@@ -37,8 +37,6 @@ def paged_attention(
start_positions: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):

bs, batch_seq_len, _, _ = xq.shape
@@ -54,8 +52,6 @@
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
elif attention_block.cache.is_direct:
xk, xv = attention_block.transact_cache_direct(
@@ -112,35 +108,7 @@ def run_llama(
start_positions: Optional[torch.Tensor] = None,
):

if phase == "decode":
bs, _, _, _ = xq.shape

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
config.hp.context_length,
config.hp.attention_head_count_kv,
config.hp.attn_head_dim,
],
dtype=config.activation_dtype,
device=config.device,
)
xv_temp = torch.empty(
[
bs,
config.hp.context_length,
config.hp.attention_head_count_kv,
config.hp.attn_head_dim,
],
dtype=config.activation_dtype,
device=config.device,
)
elif phase == "prefill":
xk_temp = None
xv_temp = None
else:
if phase not in ["prefill", "decode"]:
raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'")

h = paged_attention(
Expand All @@ -153,8 +121,6 @@ def run_llama(
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

return h
31 changes: 11 additions & 20 deletions sharktank/sharktank/layers/kv_cache.py
@@ -158,25 +158,22 @@ def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Reads cache partitions from the page table for the given page_ids.
"""Reads K/V caches the page table for the given page_ids.
Args:
state: State struct as returned from allocate().
read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
to access.
Returns a tuple of cache partitions (i.e. k and v caches for the transformer
block), linearized. Note that this reference approach to reading by
materializing linearly may not be terribly efficient unless if the
compiler can fuse the gather.
Returns the K/V cache partitions, linearized. Note that this reference
approach to reading by materializing linearly may not be terribly
efficient unless the compiler can fuse the gather.
"""
page_table = self.unflatten_page_table(state) # 6D

@@ -204,32 +201,26 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
def read_cache_partition(index: int):
subblock_ids = base_subblock_ids + index
# TODO: Potentially clamp all page 0 indices to the mask value.
# Or even better, require that the ids are replicated such that access is
# legal.
# Now for each of the k/v attn_block_ids, which have been adjusted to
# index into the sub-pages, we flatten to do a linear index_select
# copy of the sub-blocks by collapsing the first two dims so we have
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
# trace_tensor("kv.selected", selected)
into_partition[...] = selected
return selected

for index, read_into_partition in enumerate(read_into_partitions):
read_cache_partition(index, read_into_partition)
key = read_cache_partition(0)
value = read_cache_partition(1)

return tuple([p[:, :seq_len, :] for p in read_into_partitions])
return key[:, :seq_len], value[:, :seq_len]

def write_timestep(
self,
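For reference, a self-contained toy (made-up shapes, plain torch) of the gather that read_cache_partition above performs: flatten the [bs, blocks] subblock ids, index_select the matching rows, restore the batch dimension, and collapse blocks and block_seq_stride into one sequence axis. The result is now returned to the caller rather than copied into a preallocated partition.

```python
import torch

bs, blocks, block_seq_stride, dim = 2, 3, 4, 8
num_subblocks = 16
subblock_table = torch.randn(num_subblocks, block_seq_stride, dim)
subblock_ids = torch.randint(0, num_subblocks, (bs, blocks))

selected = (
    torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
    .unflatten(0, (bs, blocks))  # [bs, blocks, block_seq_stride, dim]
    .flatten(1, 2)               # [bs, blocks * block_seq_stride, dim]
)
assert selected.shape == (bs, blocks * block_seq_stride, dim)
```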
11 changes: 0 additions & 11 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -98,8 +98,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
embedding_batch_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)

@@ -158,8 +156,6 @@ def forward(
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Expand kv heads for GQA.
@@ -245,8 +241,6 @@ def transact_cache(
seq_block_ids: Optional[torch.Tensor],
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
cache = self.cache
# Manage the cache.
@@ -266,7 +260,6 @@
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride
@@ -286,10 +279,6 @@
# Restore from the cache.
xk, xv = cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
seq_len=kv_seq_len,
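A toy paged cache (hypothetical helper names and shapes, not the sharktank API) mirroring the decode path described above: one new token's K/V is written into its page, then the whole block-padded history is read back as a single linear tensor, which is the shape of result that cache.read now returns.

```python
import torch

block_seq_stride, dim = 4, 8
pages = torch.zeros(10, block_seq_stride, dim)  # stand-in page table
seq_block_ids = torch.tensor([3, 7, 1])         # pages backing one sequence
kv_seq_len = 9                                  # the new token sits at position kv_seq_len - 1

def write_timestep(value: torch.Tensor, pos: int) -> None:
    # Write a single position's K/V state into its page.
    block, offset = divmod(pos, block_seq_stride)
    pages[seq_block_ids[block], offset] = value

def read_linearized() -> torch.Tensor:
    # Gather the sequence's pages and flatten them into [blocks * stride, dim].
    return pages.index_select(0, seq_block_ids).flatten(0, 1)

write_timestep(torch.ones(dim), kv_seq_len - 1)
history = read_linearized()
assert torch.equal(history[kv_seq_len - 1], torch.ones(dim))
```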
26 changes: 0 additions & 26 deletions sharktank/sharktank/models/grok/grok.py
@@ -170,37 +170,13 @@ def decode(
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
embedding_batch_mask = self.attention_embedding.compute_batch_mask(
start_positions, batch_seq_len=1
)
self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
h *= math.sqrt(h.shape[-1])
self.trace_tensor("grok.token_embedding", h)
@@ -220,8 +196,6 @@
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)

52 changes: 0 additions & 52 deletions sharktank/sharktank/models/llama/llama.py
@@ -186,59 +186,13 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
embedding_batch_mask = self.attention_embedding.compute_batch_mask(
start_positions, batch_seq_len=1
)
self.trace_tensor("llama.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
if self.config.tensor_parallelism_size == 1:
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
else:
shard_size = [
bs,
self.context_length,
self.hp.attention_head_count_kv // self.config.tensor_parallelism_size,
self.hp.attn_head_dim,
]
xk_temp_shard = [
torch.empty(
shard_size, dtype=self.config.activation_dtype, device=self.device
)
for _ in range(self.config.tensor_parallelism_size)
]
xv_temp_shard = [
torch.empty(
shard_size, dtype=self.config.activation_dtype, device=self.device
)
for _ in range(self.config.tensor_parallelism_size)
]
xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2)
xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2)

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

@@ -254,8 +208,6 @@
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"llama.attn_block.{block_idx}.output", h)

@@ -323,8 +275,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
embedding_batch_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
h = self.attn(
h,
@@ -335,8 +285,6 @@
attention_mask=attention_mask,
embedding_batch_mask=embedding_batch_mask,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Feed forward network.
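The llama.py hunk above also drops a tensor-parallel branch that preallocated one temporary per shard. A rough sketch of what that deleted branch did (plain torch.empty stands in for the SplitPrimitiveTensor wrapper; tp size and shapes are illustrative):

```python
import torch

bs, context_length, head_count_kv, head_dim, tp_size = 2, 1024, 8, 128, 4

# Each shard held attention_head_count_kv // tensor_parallelism_size KV heads;
# the shards were then wrapped as a SplitPrimitiveTensor split along dim 2.
shard_shape = [bs, context_length, head_count_kv // tp_size, head_dim]
xk_temp_shards = [torch.empty(shard_shape) for _ in range(tp_size)]
xv_temp_shards = [torch.empty(shard_shape) for _ in range(tp_size)]

# After this commit, cache.read() returns the K/V tensors itself, so neither
# the sharded nor the unsharded temporaries are allocated by the caller.
```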
25 changes: 0 additions & 25 deletions sharktank/sharktank/models/mixtral/mixtral.py
@@ -177,29 +177,6 @@ def decode(
)
self.trace_tensor("mixtral.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
self.trace_tensor("mixtral.token_embedding", h)

@@ -218,8 +195,6 @@
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)

(2 more changed files not shown)
