From bf52d93e646d623ae58d90e94a73856040a944d6 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 20 Jan 2025 20:10:47 +0000 Subject: [PATCH] [Llama] Remove inplace read for KVCache --- .../export_layer/export_paged_attention.py | 36 +--------- sharktank/sharktank/layers/kv_cache.py | 31 +++----- .../layers/paged_llama_attention_block.py | 11 --- sharktank/sharktank/models/grok/grok.py | 26 ------- sharktank/sharktank/models/llama/llama.py | 52 -------------- sharktank/sharktank/models/mixtral/mixtral.py | 25 ------- sharktank/tests/layers/kv_cache_test.py | 71 ------------------- .../layers/sharded_paged_kv_cache_test.py | 26 +------ 8 files changed, 15 insertions(+), 263 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py index 186fb5154..e9d284111 100644 --- a/sharktank/sharktank/export_layer/export_paged_attention.py +++ b/sharktank/sharktank/export_layer/export_paged_attention.py @@ -37,8 +37,6 @@ def paged_attention( start_positions: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cache_state: list[torch.Tensor] = None, - xk_temp: Optional[torch.Tensor] = None, - xv_temp: Optional[torch.Tensor] = None, ): bs, batch_seq_len, _, _ = xq.shape @@ -54,8 +52,6 @@ def paged_attention( kv_seq_len=kv_seq_len, start_positions=start_positions, cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, ) elif attention_block.cache.is_direct: xk, xv = attention_block.transact_cache_direct( @@ -112,35 +108,7 @@ def run_llama( start_positions: Optional[torch.Tensor] = None, ): - if phase == "decode": - bs, _, _, _ = xq.shape - - # Allocate per-block temporary K/V tensors. These temporaries hold - # one block's K/V state for the maximum context length. - xk_temp = torch.empty( - [ - bs, - config.hp.context_length, - config.hp.attention_head_count_kv, - config.hp.attn_head_dim, - ], - dtype=config.activation_dtype, - device=config.device, - ) - xv_temp = torch.empty( - [ - bs, - config.hp.context_length, - config.hp.attention_head_count_kv, - config.hp.attn_head_dim, - ], - dtype=config.activation_dtype, - device=config.device, - ) - elif phase == "prefill": - xk_temp = None - xv_temp = None - else: + if phase not in ["prefill", "decode"]: raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'") h = paged_attention( @@ -153,8 +121,6 @@ def run_llama( attention_mask=attention_mask, cache_state=cache_state, seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, ) return h diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 6af0e5183..be8c66fb4 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -158,25 +158,22 @@ def read( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, - read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, seq_len: int, - page_ids: Union[torch.Tensor, ReplicatedTensor], + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, ): - """Reads cache partitions from the page table for the given page_ids. + """Reads K/V caches the page table for the given page_ids. Args: state: State struct as returned from allocate(). - read_into_partitions: List of cache partitions to read into in-place. transformer_block_index: The index of the transformer block accessing the cache. page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids to access. 
- Returns a tuple of cache partitions (i.e. k and v caches for the transformer - block), linearized. Note that this reference approach to reading by - materializing linearly may not be terribly efficient unless if the - compiler can fuse the gather. + Returns the K/V cache partitions, linearized. Note that this reference + approach to reading by materializing linearly may not be terribly + efficient unless if the compiler can fuse the gather. """ page_table = self.unflatten_page_table(state) # 6D @@ -204,12 +201,8 @@ def read( transformer_block_index * transformer_block_stride ) - def read_cache_partition( - index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor] - ): - subblock_ids = ( - (base_subblock_ids + index) if index > 0 else base_subblock_ids - ) + def read_cache_partition(index: int): + subblock_ids = base_subblock_ids + index # TODO: Potentially clamp all page 0 indices to the mask value. # Or even better, require that the ids are replicated such that access is # legal. @@ -217,19 +210,17 @@ def read_cache_partition( # index into the sub-pages, we flatten to do a linear index_select # copy of the sub-blocks by collapsing the first two dims so we have # a linear list. - # TODO: Can be rewritten into inplace with out= on index_select. selected = ( ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1)) .unflatten(0, blocked_shape[0:2]) .flatten(1, 2) ) - # trace_tensor("kv.selected", selected) - into_partition[...] = selected + return selected - for index, read_into_partition in enumerate(read_into_partitions): - read_cache_partition(index, read_into_partition) + key = read_cache_partition(0) + value = read_cache_partition(1) - return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + return key[:, :seq_len], value[:, :seq_len] def write_timestep( self, diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index d74e2a92d..64c7c65d9 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -98,8 +98,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, embedding_batch_mask: Optional[torch.Tensor] = None, cache_state: list[torch.Tensor] = None, - xk_temp: Optional[torch.Tensor] = None, - xv_temp: Optional[torch.Tensor] = None, ): assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None) @@ -158,8 +156,6 @@ def forward( kv_seq_len=kv_seq_len, start_positions=start_positions, cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, ) # Expand kv heads for GQA. @@ -245,8 +241,6 @@ def transact_cache( seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, - xk_temp: Optional[torch.Tensor] = None, - xv_temp: Optional[torch.Tensor] = None, ): cache = self.cache # Manage the cache. @@ -266,7 +260,6 @@ def transact_cache( # use a memory efficient attention kernel that can do indirect # reads, skipping this materialization. This path is taken for # a decode step. - assert xk_temp is not None and xv_temp is not None assert xk_cache_update.shape[1] == 1 assert xv_cache_update.shape[1] == 1 assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride @@ -286,10 +279,6 @@ def transact_cache( # Restore from the cache. 
xk, xv = cache.read( cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], transformer_block_index=self.block_index, page_ids=seq_block_ids, seq_len=kv_seq_len, diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 0be9ede05..cf38aa9da 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -170,7 +170,6 @@ def decode( self._assert_device(attention_mask, dtype=self.activation_dtype) self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) - bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. embedding_batch_mask = self.attention_embedding.compute_batch_mask( @@ -178,29 +177,6 @@ def decode( ) self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask) - # Allocate per-block temporary K/V tensors. These temporaries hold - # one block's K/V state for the maximum context length. - xk_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - xv_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - h = self.token_embedding(tokens) h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) @@ -220,8 +196,6 @@ def decode( attention_mask=attention_mask, cache_state=cache_state, seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, ) self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 0bb6985e7..160d4c922 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -186,7 +186,6 @@ def decode( self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) - bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. embedding_batch_mask = self.attention_embedding.compute_batch_mask( @@ -194,51 +193,6 @@ def decode( ) self.trace_tensor("llama.embedding_batch_mask", embedding_batch_mask) - # Allocate per-block temporary K/V tensors. These temporaries hold - # one block's K/V state for the maximum context length. 
- if self.config.tensor_parallelism_size == 1: - xk_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - xv_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - else: - shard_size = [ - bs, - self.context_length, - self.hp.attention_head_count_kv // self.config.tensor_parallelism_size, - self.hp.attn_head_dim, - ] - xk_temp_shard = [ - torch.empty( - shard_size, dtype=self.config.activation_dtype, device=self.device - ) - for _ in range(self.config.tensor_parallelism_size) - ] - xv_temp_shard = [ - torch.empty( - shard_size, dtype=self.config.activation_dtype, device=self.device - ) - for _ in range(self.config.tensor_parallelism_size) - ] - xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2) - xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2) - h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -254,8 +208,6 @@ def decode( attention_mask=attention_mask, cache_state=cache_state, seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, ) self.trace_tensor(f"llama.attn_block.{block_idx}.output", h) @@ -323,8 +275,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, embedding_batch_mask: Optional[torch.Tensor] = None, cache_state: list[torch.Tensor] = None, - xk_temp: Optional[torch.Tensor] = None, - xv_temp: Optional[torch.Tensor] = None, ): h = self.attn( h, @@ -335,8 +285,6 @@ def forward( attention_mask=attention_mask, embedding_batch_mask=embedding_batch_mask, cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, ) # Feed forward network. diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index b597c6e99..0b0464733 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -177,29 +177,6 @@ def decode( ) self.trace_tensor("mixtral.embedding_batch_mask", embedding_batch_mask) - # Allocate per-block temporary K/V tensors. These temporaries hold - # one block's K/V state for the maximum context length. 
- xk_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - xv_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - h = self.token_embedding(tokens) self.trace_tensor("mixtral.token_embedding", h) @@ -218,8 +195,6 @@ def decode( attention_mask=attention_mask, cache_state=cache_state, seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, ) self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 8512c8768..d59d0a85b 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -53,18 +53,8 @@ def test_paged(): page_ids=write_page_ids, ) - # Check the written values have updated: - read_empty = [ - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - ] read_back = cache.read( allocation, - read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, page_ids=write_page_ids, @@ -76,19 +66,8 @@ def test_paged(): for i in range(transformer_block_count): if i == 1: continue - read_ones = [ - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] read_ones = cache.read( allocation, - read_into_partitions=read_ones, transformer_block_index=i, seq_len=write_seq_length, page_ids=write_page_ids, @@ -112,19 +91,8 @@ def test_paged(): page_ids=page_ids, ) - read_empty = [ - torch.zeros( - (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] read_back = cache.read( allocation, - read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length + 1, page_ids=page_ids, @@ -191,28 +159,8 @@ def test_sharded_paged(): page_ids=write_page_ids, ) - # Check the written values have updated: - empty_k = reshard_split( - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - dim=2, - count=shard_count, - ) - - empty_v = reshard_split( - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - dim=2, - count=shard_count, - ) - - read_empty = [empty_k, empty_v] - read_back = cache.read( allocation, - read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, page_ids=write_page_ids, @@ -245,27 +193,8 @@ def test_sharded_paged(): page_ids=page_ids, ) - empty_k = reshard_split( - torch.zeros( - (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - - empty_v = reshard_split( - torch.zeros( - (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - read_back = cache.read( allocation, - read_into_partitions=[empty_k, empty_v], transformer_block_index=1, seq_len=write_seq_length + 1, page_ids=page_ids, diff --git 
a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d7b6a0b33..833e2ce71 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -104,44 +104,24 @@ def testRead(self): sharded_cache_state, ) = self.make_unsharded_and_sharded_equal_cache_states() - read_into_partitions_snapshot = [ - torch.rand( - self.batch_size, - self.block_seq_len * self.block_seq_stride, - self.attn_head_count, - self.attn_head_dim, - ) - for _ in range(self.cache_partition_count) - ] - read_into_partitions = deepcopy(read_into_partitions_snapshot) transformer_block_index = 1 page_ids = torch.randint( low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len] ).reshape([self.batch_size, self.block_seq_len]) - self.cache.read( + unsharded_read = self.cache.read( state=cache_state, - read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, seq_len=self.block_seq_len * self.block_seq_stride, ) - sharded_read_into_partitions = deepcopy( - [ - ops.reshard_split(t, dim=2, count=self.shard_count) - for t in read_into_partitions_snapshot - ] - ) sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) - self.sharded_cache.read( + sharded_read = self.sharded_cache.read( state=sharded_cache_state, - read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, seq_len=self.block_seq_len * self.block_seq_stride, ) - for unsharded, sharded in zip( - read_into_partitions, sharded_read_into_partitions - ): + for unsharded, sharded in zip(unsharded_read, sharded_read): assert ops.equal(unsharded, ops.unshard(sharded)) def testWriteTimestep(self):
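
Note for reviewers (not part of the diff above): the only API change callers see is that PagedKVCache.read() no longer takes read_into_partitions, and callers no longer need to preallocate xk_temp/xv_temp scratch tensors; read() now returns the linearized K and V partitions directly, already sliced to seq_len. A minimal caller-side sketch follows. Only the read() keyword arguments are taken from this patch; the function name, argument names, and the way cache, cache_state and page_ids are obtained are illustrative assumptions.

def restore_kv_from_cache(cache, cache_state, page_ids, block_index, kv_seq_len):
    # Before this patch: the caller preallocated xk_temp/xv_temp for the full
    # context length and passed views of them via read_into_partitions=[...],
    # which read() then filled in place.
    # After this patch: read() simply returns the gathered K/V partitions.
    xk, xv = cache.read(
        cache_state,
        transformer_block_index=block_index,
        page_ids=page_ids,
        seq_len=kv_seq_len,
    )
    # Both results are [bs, kv_seq_len, attn_head_count, attn_head_dim]; the
    # sharded cache returns sharded tensors, as exercised in
    # sharded_paged_kv_cache_test.py above.
    return xk, xv

Returning fresh tensors keeps read() out-of-place and leaves it to the compiler to fuse the gather into the consumer, rather than forcing a copy into caller-owned scratch buffers sized for the maximum context length.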