diff --git a/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp b/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
index ed4f2297564..58cc1623fc6 100644
--- a/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
@@ -114,7 +114,7 @@ static inline gqa_parameters init_params(const std::vector& inputs, const
     gqa_params.num_heads = num_heads;
     gqa_params.max_sequence_length = sequence_length;
     gqa_params.seq_stride = head_size;
-    gqa_params.head_stride = sequence_length * gqa_params.seq_stride;
+    gqa_params.head_stride = head_stride;
     gqa_params.batch_stride = batch_stride;
     gqa_params.position_ids_format = position_ids_format;
     gqa_params.seqlen_present_kv_cache = present_kv_seqlen;
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/group_query_attention.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/group_query_attention.hpp
index 095fb823a09..dbb60e7bd43 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/group_query_attention.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/group_query_attention.hpp
@@ -28,36 +28,9 @@
 #include
 #include
 #include
-#include
 #include
 
 namespace migraphx {
 
-// struct gqa_parameters
-// {
-//     float scale;
-//     index_int batch_size;            // Batch size used by input
-//     index_int sequence_length;       // Sequence length used by input
-//     index_int hidden_size;           // Hidden size used by input
-//     index_int head_size;             // Head size
-//     index_int rotary_embedding_dim;  // Rotary embedding dimension.
-//     index_int num_heads;             // num_heads = hidden_size / head_size
-//     index_int max_sequence_length;   // Sequence length used by cos/sin cache
-//     index_int head_stride;           // Head stride
-//     index_int seq_stride;            // Sequence stride
-//     index_int batch_stride;          // Batch stride
-//     index_int position_ids_format;   // Format of position ids - 0 is (1), 1 is (batch_size,
-//                                      // sequence_length)
-//     bool transposed;                 // Whether the input tensor has been transposed into (batch, num_heads,
-//                                      // seq_len, hidden)
-//     index_int seqlen_present_kv_cache; // Sequence length of present kv-cache (4096 when using
-//                                        // shared buffer)
-//     bool do_rotary;                  // Whether to use rotary position embedding. Default value is 0.
-//     index_int kv_num_heads;          // Number of attention heads for k and v
-//     int local_window_size;           // left_window_size for local attention. Default value is -1 meaning
-//                                      // unused.
-//     bool rotary_interleaved;         // Rotate using interleaved pattern. Default value is 0 (False).
-//     bool past_present_share_buffer;  // Whether to use same buffer for KV-cache inputs and outputs
-// };
 
 template
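
For reviewers, a minimal standalone sketch of why the two expressions in the first hunk can differ. It is not MIGraphX code: the layout (batch, kv_num_heads, seq_len, head_size), the variable names, and the numeric values are assumptions chosen only to illustrate the stride arithmetic, on the assumption that the replacement `head_stride` is derived from the KV cache's own sequence length rather than the input's.

```cpp
// Hypothetical example, not MIGraphX code: row-major strides for a KV-cache-like
// buffer laid out as (batch, kv_num_heads, seq_len, head_size).
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t head_size       = 128;  // per-head embedding size
    const std::size_t sequence_length = 1;    // input (query) length, e.g. one decode step
    const std::size_t present_seq_len = 4096; // sequence length of the preallocated cache (assumed)

    // Stride between consecutive sequence positions within one head.
    const std::size_t seq_stride = head_size;

    // Expression removed by the patch: tied to the *input* sequence length.
    const std::size_t stride_from_input = sequence_length * seq_stride; // 128

    // Assumed intent of the replacement: tied to the buffer's own sequence length,
    // so consecutive heads are spaced by a full cache row, not a query row.
    const std::size_t stride_from_cache = present_seq_len * seq_stride; // 524288

    std::cout << "from input seq_len: " << stride_from_input
              << ", from cache seq_len: " << stride_from_cache << '\n';
}
```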