From 12a4d89a7c2a6de71735a55ffc372b7f227eb6a9 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 6 Dec 2024 00:14:24 -0800 Subject: [PATCH 1/5] working qsdpa --- benchmarks/python/sdpa_vector_bench.py | 96 +++-- .../scaled_dot_product_attention.metal | 29 ++ mlx/backend/metal/kernels/sdpa_vector.h | 357 ++++++++++++++++++ .../metal/scaled_dot_product_attention.cpp | 221 ++++++++++- mlx/fast.cpp | 127 ++++++- mlx/fast.h | 15 + mlx/fast_primitives.h | 15 +- python/src/fast.cpp | 39 ++ 8 files changed, 853 insertions(+), 46 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 291f87203..6e56fa1ef 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,58 +1,94 @@ -import argparse -import math - import mlx.core as mx +import numpy as np +from mlx.utils import tree_map from time_utils import time_fn -L = 16384 +L = 32768 H = 32 H_k = H // 4 D = 128 dtype = mx.float16 -loops = 10 +bits = 8 + +loops = 20 def attention(q, k, v): - def _sdpa(q, k, v): + for _ in range(loops): B, Hq, L, D = q.shape _, Hk, S, _ = k.shape q = q.reshape(B, Hk, Hq // Hk, L, D) - k = k[:, :, None, :, :] - v = v[:, :, None, :, :] - s = q @ k.transpose(0, 1, 2, 4, 3) + ke = k[:, :, None, :, :] + ve = v[:, :, None, :, :] + s = q @ ke.transpose(0, 1, 2, 4, 3) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) - o = p @ v - return o.reshape(B, Hq, L, D) - - for i in range(loops): - q = _sdpa(q, k, v) + q = p @ ve + q = q.reshape(B, Hq, L, D) return q def sdpa(q, k, v): - for i in range(loops): - q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + for _ in range(loops): + q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) return q -def time_self_attention_primitives(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) - k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - mx.eval(q, k, v) +def quant_sdpa(q, k, v, bits=4): + for _ in range(loops): + q = mx.fast.quantized_scaled_dot_product_attention( + q, *k, *v, scale=1.0, mask=None, bits=bits + ) + return q + + +def quant_attention(q, k, v, bits=4): + for _ in range(loops): + B, Hq, L, D = q.shape + Hk = k[0].shape[1] + + q = q.reshape((B, Hk, Hq // Hk, L, D)) + ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k) + ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v) + + scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits) + scores = mx.softmax(scores, axis=-1) + + q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits) + q = q.reshape((B, Hq, L, D)) + return q + + +def time_self_attention_primitives(q, k, v): time_fn(attention, q, k, v) -def time_self_attention_sdpa(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) - k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - mx.eval(q, k, v) +def time_self_attention_sdpa(q, k, v): time_fn(sdpa, q, k, v) +def time_self_attention_quant_sdpa(q, k, v, bits=4): + time_fn(quant_sdpa, q, k, v, bits) + + +def time_self_attention_quant_primitives(q, k, v, bits=4): + time_fn(quant_attention, q, k, v, bits) + + if __name__ == "__main__": - time_self_attention_sdpa() - time_self_attention_primitives() + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) + k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) + v = mx.random.uniform(shape=(1, H_k, L, D), 
dtype=dtype) + mx.eval(q, k, v) + + k_quant = mx.quantize(k, bits=bits) + v_quant = mx.quantize(v, bits=bits) + mx.eval(k_quant, v_quant) + + k = mx.dequantize(*k_quant, bits=bits) + v = mx.dequantize(*v_quant, bits=bits) + + time_self_attention_sdpa(q, k, v) + time_self_attention_quant_sdpa(q, k_quant, v_quant, bits) + time_self_attention_primitives(q, k, v) + time_self_attention_quant_primitives(q, k_quant, v_quant, bits) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index b5bc9607e..9beec77b1 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -20,4 +20,33 @@ using namespace metal; instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +// Quantized SDPA vector instantiations +#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \ + instantiate_kernel( \ + #name "_" #type "_" #head_dim "_" #group_size "_" #bits, \ + name, type, head_dim, group_size, bits) + +#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \ + instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \ + instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits) + +#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \ + instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \ + instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8) + +#define instantiate_quant_sdpa_vector_group_size(type, heads) \ + instantiate_quant_sdpa_vector_bits(type, heads, 32) \ + instantiate_quant_sdpa_vector_bits(type, heads, 64) \ + instantiate_quant_sdpa_vector_bits(type, heads, 128) + +#define instantiate_quant_sdpa_vector_heads(type) \ + instantiate_quant_sdpa_vector_group_size(type, 64) \ + instantiate_quant_sdpa_vector_group_size(type, 96) \ + instantiate_quant_sdpa_vector_group_size(type, 128) + +instantiate_quant_sdpa_vector_heads(float) +instantiate_quant_sdpa_vector_heads(bfloat16_t) +instantiate_quant_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 8b6af638e..49eb35b1f 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -113,6 +113,208 @@ template } } +template +METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) { + U query_sum = 0; + if (bits == 4) { + for (int i = 0; i < elem_per_thread; i += 4) { + q[i] = scale * queries[i]; + q[i + 1] = scale * queries[i + 1]; + q[i + 2] = scale * queries[i + 2]; + q[i + 3] = scale * queries[i + 3]; + query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3]; + q[i + 1] /= 16.0f; + q[i + 2] /= 256.0f; + q[i + 3] /= 4096.0f; + } + } else if (bits == 8) { + for (int i = 0; i < elem_per_thread; i++) { + q[i] = scale * queries[i]; + query_sum += q[i]; + } + } + return query_sum; +} + +template +METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) { + if (bits == 4) { + auto ks = (const device uint16_t*)keys; + for (int i = 0; i < elem_per_thread / 4; i++) { + k[4 * i] = ks[i] & 0x000f; + k[4 * i + 1] = ks[i] & 0x00f0; + k[4 * i + 2] = ks[i] & 0x0f00; + k[4 * i + 3] = ks[i] & 0xf000; + } + } else if (bits == 8) { + auto ks = (const device uint8_t*)keys; + for (int i = 0; i < elem_per_thread; i++) { + k[i] = ks[i]; + } + } +} + +template 
+METAL_FUNC void load_values( + const device uint32_t* values, + thread U* v, + U value_scale, + U value_bias) { + auto vs = (const device uint8_t*)values; + if (bits == 4) { + U s[2] = {value_scale, value_scale / 16.0f}; + for (int i = 0; i < elem_per_thread / 2; i++) { + v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias; + v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias; + } + } else if (bits == 8) { + for (int i = 0; i < elem_per_thread; i++) { + v[i] = value_scale * vs[i] + value_bias; + } + } +} + +template +[[kernel]] void quant_sdpa_vector( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device T* key_scales [[buffer(2)]], + const device T* key_biases [[buffer(3)]], + const device uint32_t* values [[buffer(4)]], + const device T* value_scales [[buffer(5)]], + const device T* value_biases [[buffer(6)]], + device T* out [[buffer(7)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& group_stride, + const constant float& scale, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int BN = 32; + constexpr int BD = 4; + constexpr int elem_per_thread = D / BD; + constexpr int pack_factor = 32 / bits; + + const int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U v[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + quad_lid * elem_per_thread; + + const int kv_idx = quad_gid * D + quad_lid * elem_per_thread; + const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor; + const int group_idx = kv_head_idx * group_stride + kv_idx / group_size; + keys += packed_idx; + key_scales += group_idx; + key_biases += group_idx; + values += packed_idx; + value_scales += group_idx; + value_biases += group_idx; + + out += head_idx * D + simd_gid * elem_per_thread; + + // Read the query and 0 the output accumulator + U query_sum = load_queries( + queries, q, static_cast(scale)); + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (int i = quad_gid; i < N; i += BN) { + load_keys(keys, k); + + // Assume D % group_size == 0 so all the keys are in the same group + U key_scale = key_scales[0]; + U key_bias = key_biases[0]; + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = score * key_scale + query_sum * key_bias; + score = quad_sum(score); + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + U value_scale = value_scales[0]; + U value_bias = value_biases[0]; + + // Load the values + load_values(values, v, value_scale, value_bias); + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * v[i]; + } + + // Move the pointers to the next kv + keys += stride / pack_factor; + key_scales += stride 
/ group_size; + key_biases += stride / group_size; + values += stride / pack_factor; + value_scales += stride / group_size; + value_biases += stride / group_size; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + // Each quadgroup communicates it's max score + if (quad_lid == 0) { + max_scores[quad_gid] = max_score; + sum_exp_scores[quad_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + // 128 threads with 32 values per thread + outputs[simd_gid * BN + simd_lid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], @@ -290,3 +492,158 @@ template } } } + +template +[[kernel]] void quant_sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device T* key_scales [[buffer(2)]], + const device T* key_biases [[buffer(3)]], + const device uint32_t* values [[buffer(4)]], + const device T* value_scales [[buffer(5)]], + const device T* value_biases [[buffer(6)]], + device float* out [[buffer(7)]], + device float* sums [[buffer(8)]], + device float* maxs [[buffer(9)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant size_t& k_group_stride, + const constant size_t& v_group_stride, + const constant float& scale, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int BN = 8; + constexpr int BD = 4; + constexpr int elem_per_thread = D / BD; + const int stride = BN * D; + constexpr int blocks = 32; + constexpr int pack_factor = 32 / bits; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U v[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int block_idx = tid.z; + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + quad_lid * elem_per_thread; + + const int kv_idx = + (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread; + const int packed_idx = kv_idx / pack_factor; + const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; + const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; + + keys += kv_head_idx * k_stride + packed_idx; + key_scales += k_group_idx; + key_biases += k_group_idx; + values += kv_head_idx * v_stride + packed_idx; + value_scales += v_group_idx; + value_biases += v_group_idx; + + out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread; + sums += 
head_idx * blocks + block_idx; + maxs += head_idx * blocks + block_idx; + + // Read the query and 0 the output accumulator + U query_sum = load_queries( + queries, q, static_cast(scale)); + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -1e9; + U sum_exp_score = 0; + + // For each key + for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) { + // Read the key + load_keys(keys, k); + + // Assume D % group_size == 0 so all the keys are in the same group + U key_scale = key_scales[0]; + U key_bias = key_biases[0]; + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = score * key_scale + query_sum * key_bias; + score = quad_sum(score); + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + U value_scale = value_scales[0]; + U value_bias = value_biases[0]; + load_values(values, v, value_scale, value_bias); + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * v[i]; + } + + // Move the pointers to the next kv + keys += blocks * stride / pack_factor; + key_scales += blocks * stride / group_size; + key_biases += blocks * stride / group_size; + values += blocks * stride / pack_factor; + value_scales += blocks * stride / group_size; + value_biases += blocks * stride / group_size; + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (quad_lid == 0) { + max_scores[quad_gid] = max_score; + sum_exp_scores[quad_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? 
sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[quad_lid * BN + quad_gid] = + o[i] * fast::exp(max_scores[quad_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (quad_gid == 0) { + U output = outputs[quad_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[quad_lid * BN + j]; + } + out[i] = static_cast(output); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f600a4890..0f87e6027 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -242,6 +242,171 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void quant_sdpa_vector( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& k_scales, + const array& k_biases, + const array& v, + const array& v_scales, + const array& v_biases, + array& out, + float scale, + int group_size, + int bits) { + // Set the kernel name + std::string kname; + kname.reserve(96); + kname += "quant_sdpa_vector_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(group_size); + kname += "_"; + kname += std::to_string(bits); + + // Compute the necessary sizes + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int B = q.shape(0) * q.shape(1); + size_t stride = k.strides()[1]; + size_t group_stride = k_scales.strides()[1]; + MTL::Size group_dims(128, 1, 1); + MTL::Size grid_dims(1, B, 1); + + // Get the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set its arguments + compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? 
out : q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(k_scales, 2); + compute_encoder.set_input_array(k_biases, 3); + compute_encoder.set_input_array(v, 4); + compute_encoder.set_input_array(v_scales, 5); + compute_encoder.set_input_array(v_biases, 6); + compute_encoder.set_output_array(out, 7); + compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8); + compute_encoder.set_bytes(&N, sizeof(int), 9); + compute_encoder.set_bytes(&stride, sizeof(size_t), 10); + compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11); + compute_encoder.set_bytes(&scale, sizeof(float), 12); + + // Launch + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void quant_sdpa_vector_2pass( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& k_scales, + const array& k_biases, + const array& v, + const array& v_scales, + const array& v_biases, + array& out, + float scale, + int group_size, + int bits) { + // Set the kernel name + std::string kname; + kname.reserve(96); + kname += "quant_sdpa_vector_2pass_1_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(group_size); + kname += "_"; + kname += std::to_string(bits); + + // Compute the necessary sizes + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int blocks = 32; + int B = q.shape(0) * q.shape(1); + size_t k_stride = k.strides()[1]; + size_t v_stride = v.strides()[1]; + size_t k_group_stride = k_scales.strides()[1]; + size_t v_group_stride = v_scales.strides()[1]; + MTL::Size group_dims(8 * 4, 1, 1); + MTL::Size grid_dims(1, B, blocks); + + // Allocate the intermediates + std::vector intermediate_shape; + intermediate_shape.reserve(out.ndim() + 1); + intermediate_shape.insert( + intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); + intermediate_shape.push_back(blocks); + intermediate_shape.push_back(out.shape().back()); + array intermediate(intermediate_shape, float32, nullptr, {}); + intermediate_shape.pop_back(); + array sums(intermediate_shape, float32, nullptr, {}); + array maxs(std::move(intermediate_shape), float32, nullptr, {}); + intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + sums.set_data(allocator::malloc_or_wait(sums.nbytes())); + maxs.set_data(allocator::malloc_or_wait(maxs.nbytes())); + d.add_temporary(intermediate, s.index); + d.add_temporary(sums, s.index); + d.add_temporary(maxs, s.index); + + // Get the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set its arguments + compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? 
out : q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(k_scales, 2); + compute_encoder.set_input_array(k_biases, 3); + compute_encoder.set_input_array(v, 4); + compute_encoder.set_input_array(v_scales, 5); + compute_encoder.set_input_array(v_biases, 6); + compute_encoder.set_output_array(intermediate, 7); + compute_encoder.set_output_array(sums, 8); + compute_encoder.set_output_array(maxs, 9); + compute_encoder.set_bytes(gqa_factor, 10); + compute_encoder.set_bytes(N, 11); + compute_encoder.set_bytes(k_stride, 12); + compute_encoder.set_bytes(v_stride, 13); + compute_encoder.set_bytes(k_group_stride, 14); + compute_encoder.set_bytes(v_group_stride, 15); + compute_encoder.set_bytes(scale, 16); + + // Launch + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Final pass + kname.clear(); + kname += "sdpa_vector_2pass_2_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + + // Get the kernel + kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set its arguments + compute_encoder.set_input_array(intermediate, 0); + compute_encoder.set_input_array(sums, 1); + compute_encoder.set_input_array(maxs, 2); + compute_encoder.set_output_array(out, 3); + + // Launch + group_dims = MTL::Size(1024, 1, 1); + grid_dims = MTL::Size(1, B, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + } // namespace void ScaledDotProductAttention::eval_gpu( @@ -254,7 +419,6 @@ void ScaledDotProductAttention::eval_gpu( auto& q_pre = inputs[0]; auto& k_pre = inputs[1]; - auto& v_pre = inputs[2]; auto& o = out; std::vector copies; @@ -295,9 +459,7 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) == 1) { - const auto& q = copy_unless(is_contiguous, q_pre); - const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); - const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); + auto q = copy_unless(is_contiguous, q_pre); // Donate the query if possible if (q.is_donatable()) { @@ -306,20 +468,55 @@ void ScaledDotProductAttention::eval_gpu( o.set_data(allocator::malloc_or_wait(o.nbytes())); } - // We route to the 2 pass fused attention if - // - The device is large and the sequence length long - // - The sequence length is even longer and we have gqa - char devc = d.get_architecture().back(); - if ((devc == 'd' && k.shape(2) >= 1024) || - (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_); + if (quantized_) { + auto& k_scales_pre = inputs[2]; + auto& k_biases_pre = inputs[3]; + auto& v_pre = inputs[4]; + auto& v_scales_pre = inputs[5]; + auto& v_biases_pre = inputs[6]; + + auto k = copy_unless(is_contiguous_except_seq_len, k_pre); + auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre); + auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre); + auto v = copy_unless(is_contiguous_except_seq_len, v_pre); + auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre); + auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre); + + quant_sdpa_vector_2pass( + s, + d, + q, + k, + k_scales, + k_biases, + v, + v_scales, + v_biases, + o, + scale_, + group_size_, + bits_); } else { - sdpa_vector(s, d, q, k, v, o, scale_); + auto& k_pre = inputs[1]; + auto& v_pre = inputs[2]; + + const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); + const auto& v = copy_unless(is_contiguous_except_seq_len, 
v_pre); + + char devc = d.get_architecture().back(); + if ((devc == 'd' && k.shape(2) >= 1024) || + (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { + sdpa_vector_2pass(s, d, q, k, v, o, scale_); + } else { + sdpa_vector(s, d, q, k, v, o, scale_); + } } } // Full attention mode else { + auto& v_pre = inputs[2]; + const auto& q = copy_unless(is_matrix_contiguous, q_pre); const auto& k = copy_unless(is_matrix_contiguous, k_pre); const auto& v = copy_unless(is_matrix_contiguous, v_pre); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 731912d69..37a6ec47b 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -664,7 +664,7 @@ array scaled_dot_product_attention( std::move(out_shape), final_type, std::make_shared( - stream, fallback, scale, false), + stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false), {q, k, v}); } @@ -678,7 +678,130 @@ array scaled_dot_product_attention( bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); - return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_; + return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ && + quantized_ == a_other.quantized_; +} + +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const array& key_biases, + const array& values, + const array& value_scales, + const array& value_biases, + const float scale, + const std::optional& mask, + const int group_size, + const int bits, + StreamOrDevice s) { + int el_per_int = 32 / bits; + int out_dim = values.shape(-1) * el_per_int; + + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + + auto out_shape = std::vector( + {queries.shape(0), queries.shape(1), queries.shape(2), out_dim}); + auto stream = to_stream(s); + bool needs_mask = mask.has_value(); + auto fallback = + [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s]( + const std::vector& inputs) -> std::vector { + int n_repeats = n_q_heads / n_kv_heads; + + auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); + + auto k = inputs[1]; + auto k_scales = inputs[2]; + auto k_biases = inputs[3]; + auto v = inputs[4]; + auto v_scales = inputs[5]; + auto v_biases = inputs[6]; + + int B = q.shape(0); + int L = q.shape(2); + + if (n_repeats > 1) { + q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s); + k = expand_dims(k, 2, s); + k_scales = expand_dims(k_scales, 2, s); + k_biases = expand_dims(k_biases, 2, s); + v = expand_dims(v, 2, s); + v_scales = expand_dims(v_scales, 2, s); + v_biases = expand_dims(v_biases, 2, s); + } + + array scores = quantized_matmul( + q, + k, + k_scales, + k_biases, + /*transpose=*/true, + /*group_size=*/group_size, + /*bits=*/bits, + s); + if (needs_mask) { + scores = add(scores, inputs[7], s); + } + scores = softmax(scores, std::vector{-1}, true, s); + array out = quantized_matmul( + scores, + v, + v_scales, + v_biases, + /*transpose=*/false, + /*group_size=*/group_size, + /*bits=*/bits, + s); + if (n_repeats > 1) { + out = reshape(out, {B, n_q_heads, L, -1}, s); + } + return std::vector{out}; + }; + + int L = queries.shape(2); + if (L > 1) { + if (needs_mask) { + return fallback( + {queries, + keys, + key_scales, + key_biases, + values, + value_scales, + value_biases, + mask.value()})[0]; + } else { + return fallback( + {queries, + keys, + key_scales, + key_biases, + values, + value_scales, + value_biases})[0]; + } + } else { + return array( + std::move(out_shape), + queries.dtype(), + 
std::make_shared<ScaledDotProductAttention>(
+            stream,
+            fallback,
+            scale,
+            /*needs_mask=*/false,
+            /*quantized=*/true,
+            group_size,
+            bits),
+        {queries,
+         keys,
+         key_scales,
+         key_biases,
+         values,
+         value_scales,
+         value_biases});
+  }
 }
 
 array pack_and_quantize(
diff --git a/mlx/fast.h b/mlx/fast.h
index ddc3512b5..dc47e6c46 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -41,6 +41,21 @@ array scaled_dot_product_attention(
     const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});
 
+/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
+array quantized_scaled_dot_product_attention(
+    const array& queries,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    const float scale,
+    const std::optional<array>& mask = std::nullopt,
+    const int group_size = 64,
+    const int bits = 4,
+    StreamOrDevice s = {});
+
 std::tuple<array, array, array> affine_quantize(
     const array& w,
     int group_size = 64,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 30db282ff..830da18b4 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
-      const bool needs_mask)
-      : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
+      const bool needs_mask,
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
+      : Custom(stream, fallback),
+        scale_(scale),
+        needs_mask_(needs_mask),
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
   std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool needs_mask_;
+  bool quantized_;
+  int group_size_;
+  int bits_;
 };
 
 class AffineQuantize : public Custom {
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index cbc8b934d..188a93b22 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -161,6 +161,45 @@ void init_fast(nb::module_& parent_module) {
             array: The output array.
       )pbdoc");
 
+  m.def(
+      "quantized_scaled_dot_product_attention",
+      &fast::quantized_scaled_dot_product_attention,
+      "q"_a,
+      "k"_a,
+      "k_scales"_a,
+      "k_biases"_a,
+      "v"_a,
+      "v_scales"_a,
+      "v_biases"_a,
+      nb::kw_only(),
+      "scale"_a,
+      "mask"_a = nb::none(),
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, group_size: int = 64, bits: int = 4, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        A fast implementation of multi-head attention where the keys and values are quantized.
+
+        See :func:`scaled_dot_product_attention` for more details.
+
+        Args:
+            q (array): Input query array.
+            k (array): Input keys array.
+            k_scales (array): Scales for the quantized keys array.
+            k_biases (array): Biases for the quantized keys array.
+            v (array): Input values array.
+            v_scales (array): Scales for the quantized values array.
+            v_biases (array): Biases for the quantized values array.
+            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
+            mask (array, optional): An additive mask to apply to the query-key scores.
+            group_size (int): The group size used in the KV quantization.
+            bits (int): The bits used in the KV quantization.
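+
+        Example:
+            A minimal usage sketch; ``q``, ``k`` and ``v`` are assumed to be
+            unquantized float arrays shaped as in
+            :func:`scaled_dot_product_attention`, with the keys and values
+            packed by ``mx.quantize`` using the same ``group_size`` and
+            ``bits`` that are passed here.
+
+            .. code-block:: python
+
+                # Quantize K and V once; mx.quantize returns
+                # (packed_weights, scales, biases) for each array.
+                k_q = mx.quantize(k, group_size=64, bits=4)
+                v_q = mx.quantize(v, group_size=64, bits=4)
+
+                o = mx.fast.quantized_scaled_dot_product_attention(
+                    q,
+                    *k_q,
+                    *v_q,
+                    scale=1.0 / q.shape[-1] ** 0.5,
+                    mask=None,
+                    group_size=64,
+                    bits=4,
+                )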
+ Returns: + array: The output array. + )pbdoc"); + m.def( "metal_kernel", [](const std::string& name, From 3507c104a5c5932a2f66e14bfa2a6ce86b422251 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 6 Dec 2024 00:45:01 -0800 Subject: [PATCH 2/5] add test --- python/tests/test_fast_sdpa.py | 99 +++++++++++++++++----------------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 1df48bc7f..ad3262e7c 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -1,5 +1,6 @@ import math import unittest +from itertools import product import mlx.core as mx import mlx_tests @@ -113,61 +114,63 @@ def test_fast_sdpa(self): R = 1 Dk = 128 scale = float(1.0 / np.sqrt(128.0)) - q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32) - k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) - v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) + q = mx.random.normal(shape=(1, 32, R, Dk)) + k = mx.random.normal(shape=(1, 32, L, Dk)) + v = mx.random.normal(shape=(1, 32, L, Dk)) - q_mlx = mx.array(q_npy) - k_mlx = mx.array(k_npy) - v_mlx = mx.array(v_npy) + reference = mlx_primitives_sdpa(q, k, v, scale) - reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale) + o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) - o_mlx = mx.fast.scaled_dot_product_attention( - q_mlx, k_mlx, v_mlx, scale=scale, mask=None - ) - - self.assertListEqual(list(reference.shape), list(o_mlx.shape)) - self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4)) + self.assertListEqual(list(reference.shape), list(o.shape)) + self.assertTrue(mx.allclose(o, reference, atol=1e-4)) B = 1 H = 32 - dtypes = [np.float32] - if self.is_apple_silicon: - dtypes.append(np.half) - for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: - for DO_GQA in [0, 1]: - for DTYPE in dtypes: - n_kv_heads = 8 if DO_GQA else 32 - q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE) - k_npy = np.random.normal( - 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) - ).astype(DTYPE) - v_npy = np.random.normal( - 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) - ).astype(DTYPE) - - q_mlx = mx.array(q_npy) - k_mlx = mx.array(k_npy) - v_mlx = mx.array(v_npy) - - reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) - o_mlx = mx.fast.scaled_dot_product_attention( - q_mlx, k_mlx, v_mlx, scale=scale - ) - - self.assertListEqual(list(reference.shape), list(o_mlx.shape)) - rtol = 1e-5 - atol = 1e-1 - - if SEQUENCE_LENGTH > 500: - rtol = 1e-2 - - if DTYPE == np.half: - rtol = 1e-2 - - self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) + tests = product( + [1, 7, 9, 32, 63, 67, 129, 2000], # sequence length + [False, True], # gqa + [mx.float32, mx.float16], + [4, 8], # bits + ) + for sequence_length, do_gqa, dtype, bits in tests: + with self.subTest( + sequence_length=sequence_length, gqa=do_gqa, dtype=dtype, bits=bits + ): + n_kv_heads = 8 if do_gqa else 32 + q = mx.random.normal(shape=(B, H, R, Dk), dtype=dtype) + k = mx.random.normal( + shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype + ) + v = mx.random.normal( + shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype + ) + + k_q = mx.quantize(k, bits=bits) + v_q = mx.quantize(v, bits=bits) + k_d = mx.dequantize(*k_q, bits=bits) + v_d = mx.dequantize(*v_q, bits=bits) + + reference = mlx_primitives_sdpa_with_gqa(q, k_d, v_d, scale) + o = mx.fast.scaled_dot_product_attention(q, k_d, v_d, 
scale=scale) + o_q = mx.fast.quantized_scaled_dot_product_attention( + q, *k_q, *v_q, scale=scale, bits=bits + ) + + self.assertListEqual(list(reference.shape), list(o.shape)) + rtol = 1e-5 + atol = 1e-1 + + if sequence_length > 500: + rtol = 1e-2 + + if dtype == mx.float16: + rtol = 1e-2 + + # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol) + self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol)) + self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol)) q = mx.random.normal(shape=(1, 32, 1, Dk)) k = mx.random.normal(shape=(1, 32, 32, Dk)) From c89ddf62b41667c1914b01ce54a838d18264e155 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 6 Dec 2024 01:09:00 -0800 Subject: [PATCH 3/5] add checks --- .../scaled_dot_product_attention.metal | 15 +- mlx/backend/metal/kernels/sdpa_vector.h | 141 ------------------ .../metal/scaled_dot_product_attention.cpp | 59 -------- mlx/fast.cpp | 61 ++++++-- 4 files changed, 51 insertions(+), 225 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 9beec77b1..5292f8685 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -22,18 +22,14 @@ instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) // Quantized SDPA vector instantiations -#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \ +#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \ instantiate_kernel( \ - #name "_" #type "_" #head_dim "_" #group_size "_" #bits, \ - name, type, head_dim, group_size, bits) - -#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \ - instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \ - instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits) + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \ + quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits) #define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \ - instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \ - instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8) + instantiate_quant_sdpa_vector(type, heads, group_size, 4) \ + instantiate_quant_sdpa_vector(type, heads, group_size, 8) #define instantiate_quant_sdpa_vector_group_size(type, heads) \ instantiate_quant_sdpa_vector_bits(type, heads, 32) \ @@ -42,7 +38,6 @@ instantiate_sdpa_vector_heads(float16_t) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector_group_size(type, 64) \ - instantiate_quant_sdpa_vector_group_size(type, 96) \ instantiate_quant_sdpa_vector_group_size(type, 128) instantiate_quant_sdpa_vector_heads(float) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 49eb35b1f..a03ee3d9b 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -174,147 +174,6 @@ METAL_FUNC void load_values( } } -template -[[kernel]] void quant_sdpa_vector( - const device T* queries [[buffer(0)]], - const device uint32_t* keys [[buffer(1)]], - const device T* key_scales [[buffer(2)]], - const device T* key_biases [[buffer(3)]], - const device uint32_t* values [[buffer(4)]], - const device T* value_scales [[buffer(5)]], - const device T* value_biases [[buffer(6)]], - device T* out [[buffer(7)]], - const 
constant int& gqa_factor, - const constant int& N, - const constant size_t& k_stride, - const constant size_t& group_stride, - const constant float& scale, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - constexpr int BN = 32; - constexpr int BD = 4; - constexpr int elem_per_thread = D / BD; - constexpr int pack_factor = 32 / bits; - - const int stride = BN * D; - - typedef float U; - - thread U q[elem_per_thread]; - thread U k[elem_per_thread]; - thread U v[elem_per_thread]; - thread U o[elem_per_thread]; - - threadgroup U outputs[BN * BD]; - threadgroup U max_scores[BN]; - threadgroup U sum_exp_scores[BN]; - - // Adjust positions - const int head_idx = tid.y; - const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + quad_lid * elem_per_thread; - - const int kv_idx = quad_gid * D + quad_lid * elem_per_thread; - const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor; - const int group_idx = kv_head_idx * group_stride + kv_idx / group_size; - keys += packed_idx; - key_scales += group_idx; - key_biases += group_idx; - values += packed_idx; - value_scales += group_idx; - value_biases += group_idx; - - out += head_idx * D + simd_gid * elem_per_thread; - - // Read the query and 0 the output accumulator - U query_sum = load_queries( - queries, q, static_cast(scale)); - for (int i = 0; i < elem_per_thread; i++) { - o[i] = 0; - } - - U max_score = -INFINITY; - U sum_exp_score = 0; - - // For each key - for (int i = quad_gid; i < N; i += BN) { - load_keys(keys, k); - - // Assume D % group_size == 0 so all the keys are in the same group - U key_scale = key_scales[0]; - U key_bias = key_biases[0]; - - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = score * key_scale + query_sum * key_bias; - score = quad_sum(score); - - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); - - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; - - U value_scale = value_scales[0]; - U value_bias = value_biases[0]; - - // Load the values - load_values(values, v, value_scale, value_bias); - - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * v[i]; - } - - // Move the pointers to the next kv - keys += stride / pack_factor; - key_scales += stride / group_size; - key_biases += stride / group_size; - values += stride / pack_factor; - value_scales += stride / group_size; - value_biases += stride / group_size; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Each thread has a partial part of the output so we need to combine them. 
- - // First let's communicate the max and sum_exp - // Each quadgroup communicates it's max score - if (quad_lid == 0) { - max_scores[quad_gid] = max_score; - sum_exp_scores[quad_gid] = sum_exp_score; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - max_score = max_scores[simd_lid]; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); - - // Now we need to aggregate all the outputs - for (int i = 0; i < elem_per_thread; i++) { - // 128 threads with 32 values per thread - outputs[simd_gid * BN + simd_lid] = o[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // And write the output - if (simd_lid == 0) { - for (int i = 0; i < elem_per_thread; i++) { - out[i] = static_cast(o[i]); - } - } -} - template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 0f87e6027..d82d55bac 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -242,65 +242,6 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -void quant_sdpa_vector( - const Stream& s, - metal::Device& d, - const array& q, - const array& k, - const array& k_scales, - const array& k_biases, - const array& v, - const array& v_scales, - const array& v_biases, - array& out, - float scale, - int group_size, - int bits) { - // Set the kernel name - std::string kname; - kname.reserve(96); - kname += "quant_sdpa_vector_"; - kname += get_type_string(q.dtype()); - kname += "_"; - kname += std::to_string(q.shape(-1)); - kname += "_"; - kname += std::to_string(group_size); - kname += "_"; - kname += std::to_string(bits); - - // Compute the necessary sizes - int gqa_factor = q.shape(1) / k.shape(1); - int N = k.shape(2); - int B = q.shape(0) * q.shape(1); - size_t stride = k.strides()[1]; - size_t group_stride = k_scales.strides()[1]; - MTL::Size group_dims(128, 1, 1); - MTL::Size grid_dims(1, B, 1); - - // Get the kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set its arguments - compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? 
out : q, 0); - compute_encoder.set_input_array(k, 1); - compute_encoder.set_input_array(k_scales, 2); - compute_encoder.set_input_array(k_biases, 3); - compute_encoder.set_input_array(v, 4); - compute_encoder.set_input_array(v_scales, 5); - compute_encoder.set_input_array(v_biases, 6); - compute_encoder.set_output_array(out, 7); - compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8); - compute_encoder.set_bytes(&N, sizeof(int), 9); - compute_encoder.set_bytes(&stride, sizeof(size_t), 10); - compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11); - compute_encoder.set_bytes(&scale, sizeof(float), 12); - - // Launch - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); -} - void quant_sdpa_vector_2pass( const Stream& s, metal::Device& d, diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 37a6ec47b..3848bd58c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -516,15 +516,11 @@ bool RoPE::is_equivalent(const Primitive& other) const { offset_ == a_other.offset_ && forward_ == a_other.forward_); } -/** Computes: O = softmax(Q @ K.T) @ V **/ -array scaled_dot_product_attention( +void check_sdpa_arguments( const array& queries, const array& keys, const array& values, - const float scale, - const std::optional& mask, - const std::optional memory_efficient_threshold, - StreamOrDevice s) { + const std::optional& mask) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { std::ostringstream msg; @@ -550,14 +546,6 @@ array scaled_dot_product_attention( } } - // Q, K must have matching last dims (d_k aka 'head_dim'); - if (queries.shape(-1) != keys.shape(-1)) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " - << queries.shape() << " for keys shape " << keys.shape() << "."; - throw std::invalid_argument(msg.str()); - } - // K, V must have matching number of heads (n_kv_heads); auto n_q_heads = queries.shape(-3); auto n_kv_heads = keys.shape(-3); @@ -577,6 +565,26 @@ array scaled_dot_product_attention( << n_q_heads << " for n_kv_heads " << n_kv_heads << "."; throw std::invalid_argument(msg.str()); } +} + +/** Computes: O = softmax(Q @ K.T) @ V **/ +array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::optional& mask, + const std::optional memory_efficient_threshold, + StreamOrDevice s) { + check_sdpa_arguments(queries, keys, values, mask); + + // Q, K must have matching last dims (d_k aka 'head_dim'); + if (queries.shape(-1) != keys.shape(-1)) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " + << queries.shape() << " for keys shape " << keys.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto final_type = result_type(queries, keys, values); if (!issubdtype(final_type, floating)) { @@ -590,6 +598,9 @@ array scaled_dot_product_attention( auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + /* generic implementation for use cases that Metal implementation does not * support. 
For non-supported cases listed below, use MLX primitives:
   * * CPU implementation
@@ -696,6 +707,25 @@ array quantized_scaled_dot_product_attention(
     const int bits,
     StreamOrDevice s) {
   int el_per_int = 32 / bits;
+
+  check_sdpa_arguments(queries, keys, values, mask);
+
+  // Q, K must have matching last dims (d_k aka 'head_dim');
+  if (queries.shape(-1) != keys.shape(-1) * el_per_int) {
+    std::ostringstream msg;
+    msg << "[quantized_scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
+        << queries.shape() << " for keys shape " << keys.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto final_type = result_type(queries, key_scales, value_scales);
+  if (!issubdtype(final_type, floating)) {
+    std::ostringstream msg;
+    msg << "[quantized_scaled_dot_product_attention] Received unsupported type "
+        << final_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   int out_dim = values.shape(-1) * el_per_int;
 
   auto n_q_heads = queries.shape(-3);
@@ -760,8 +790,9 @@ array quantized_scaled_dot_product_attention(
     return std::vector<array>{out};
   };
 
+  int query_head_dim = queries.shape(-1);
   int L = queries.shape(2);
-  if (L > 1) {
+  if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
     if (needs_mask) {
       return fallback(
           {queries,

From 769704653adf1adb0e408caf03ab599f9c270da4 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Fri, 6 Dec 2024 01:22:50 -0800
Subject: [PATCH 4/5] cpu fallback

---
 mlx/fast.cpp                   | 3 ++-
 python/tests/test_fast_sdpa.py | 1 -
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 3848bd58c..bd58715e9 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -792,7 +792,8 @@ array quantized_scaled_dot_product_attention(
   int query_head_dim = queries.shape(-1);
   int L = queries.shape(2);
-  if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
+  bool compatible_head_dim = query_head_dim == 64 || query_head_dim == 128;
+  if (L > 1 || !compatible_head_dim || stream.device != Device::gpu) {
     if (needs_mask) {
       return fallback(
           {queries,
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index ad3262e7c..73aa5b61d 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -168,7 +168,6 @@ def test_fast_sdpa(self):
                 if dtype == mx.float16:
                     rtol = 1e-2
 
-                # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
                 self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
                 self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))

From 82a956c1d94872fd7030bc2ac6769c72661566e7 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Fri, 6 Dec 2024 10:26:54 -0800
Subject: [PATCH 5/5] fix test

---
 python/tests/test_fast_sdpa.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 73aa5b61d..528fc0274 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -128,10 +128,13 @@ def test_fast_sdpa(self):
         B = 1
         H = 32
+        dtypes = [mx.float32]
+        if self.is_apple_silicon:
+            dtypes.append(mx.float16)
         tests = product(
             [1, 7, 9, 32, 63, 67, 129, 2000],  # sequence length
             [False, True],  # gqa
-            [mx.float32, mx.float16],
+            dtypes,
             [4, 8],  # bits
         )
         for sequence_length, do_gqa, dtype, bits in tests: