diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py index f0de9b04..bdb595b6 100644 --- a/auto_round/auto_quantizer.py +++ b/auto_round/auto_quantizer.py @@ -201,7 +201,7 @@ def __init__( dataset: str = None, group_size: int = 128, sym: bool = False, - backend="gptq:exllamav2", + backend="gptq:triton", iters: int = 200, weight_config: dict = None, enable_quanted_input=True, @@ -340,16 +340,17 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend): use_triton, disable_exllama, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config( backend, bits ) - QuantLinear = dynamically_import_QuantLinear( - use_triton=False, - desc_act=False, - group_size=group_size, - bits=bits, - disable_exllama=True, - disable_exllamav2=False, - use_qigen=use_qigen, - disable_marlin=disable_marlin, - ) + # QuantLinear = dynamically_import_QuantLinear( + # use_triton=True, + # desc_act=False, + # group_size=group_size, + # bits=bits, + # disable_exllama=disable_exllama, + # disable_exllamav2=disable_exllamav2, + # use_qigen=use_qigen, + # disable_marlin=disable_marlin, + # ) + from auto_round.export.export_to_autoround.qliner_triton import QuantLinear layer = get_module(module, layer_name) device = get_device(layer) if isinstance(layer, nn.Linear): diff --git a/auto_round/export/export_to_autoround/export_to_autoround.py b/auto_round/export/export_to_autoround/export_to_autoround.py index 7f2253f7..e7688f38 100644 --- a/auto_round/export/export_to_autoround/export_to_autoround.py +++ b/auto_round/export/export_to_autoround/export_to_autoround.py @@ -72,7 +72,7 @@ def get_autogptq_backend_config(backend, bits=4): @register_format("autoround") -def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav2", **kwargs): +def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:triton", **kwargs): from auto_gptq.utils.import_utils import dynamically_import_QuantLinear model = kwargs["model"] @@ -96,16 +96,17 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav layer = get_module(model, name) device = "cpu" - QuantLinear = dynamically_import_QuantLinear( - use_triton=use_triton, - desc_act=False, - group_size=group_size, - bits=bits, - disable_exllama=disable_exllamav1, - disable_exllamav2=disable_exllamav2, - use_qigen=use_qigen, - disable_marlin=disable_marlin, - ) + # QuantLinear = dynamically_import_QuantLinear( + # use_triton=use_triton, + # desc_act=False, + # group_size=group_size, + # bits=bits, + # disable_exllama=disable_exllamav1, + # disable_exllamav2=disable_exllamav2, + # use_qigen=use_qigen, + # disable_marlin=disable_marlin, + # ) + from .qliner_triton import QuantLinear if isinstance(layer, nn.Linear): in_features = layer.in_features diff --git a/auto_round/export/export_to_autoround/qliner_triton.py b/auto_round/export/export_to_autoround/qliner_triton.py new file mode 100644 index 00000000..225071c8 --- /dev/null +++ b/auto_round/export/export_to_autoround/qliner_triton.py @@ -0,0 +1,218 @@ +import math +from logging import getLogger + +import numpy as np +import torch +import torch.nn as nn +import transformers + +from .triton_utils.mixin import TritonModuleMixin + + +logger = getLogger(__name__) + +try: + from .triton_utils.kernels import ( + QuantLinearFunction, + QuantLinearInferenceOnlyFunction, + quant_matmul_248, + quant_matmul_inference_only_248, + transpose_quant_matmul_248, + ) +except ImportError as e: + triton_import_exception = e + + def 
error_raiser_triton(*args, **kwargs): + raise ValueError( + f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}" + ) + + class FakeTriton: + def __getattr__(self, name): + raise ImportError( + f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}" + ) + + quant_matmul_248 = error_raiser_triton + transpose_quant_matmul_248 = error_raiser_triton + quant_matmul_inference_only_248 = error_raiser_triton + QuantLinearFunction = FakeTriton + QuantLinearInferenceOnlyFunction = FakeTriton + + +class QuantLinear(nn.Module, TritonModuleMixin): + QUANT_TYPE = "triton" + + def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + if infeatures % 32 != 0 or outfeatures % 32 != 0: + raise NotImplementedError("in_feature and out_feature must be divisible by 32.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + + self.register_buffer( + "qweight", + torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros( + ( + math.ceil(infeatures / self.group_size), + outfeatures // 32 * self.bits, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(infeatures / self.group_size), outfeatures), + dtype=torch.float16, + ), + ) + self.register_buffer( + "g_idx", + torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), + ) + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.trainable = trainable + + def post_init(self): + pass + + def pack(self, linear, scales, zeros, g_idx=None): + W = linear.weight.data.clone() + if isinstance(linear, nn.Conv2d): + W = W.flatten(1) + if isinstance(linear, transformers.pytorch_utils.Conv1D): + W = W.t() + + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ + :, None + ] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + + i = 0 + row = 0 + qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + # zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] 
|= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction + out = quant_linear_fn.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out.half().reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + + @classmethod + def warmup(cls, model, transpose=False, seqlen=2048): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, cls): + continue + + k = m.infeatures + n = m.outfeatures + + if (k, n) not in kn_values: + kn_values[(k, n)] = ( + m.qweight, + m.scales, + m.qzeros, + m.g_idx, + m.bits, + m.maxq, + ) + + logger.info(f"Found {len(kn_values)} unique KN Linear values.") + logger.info("Warming up autotune cache ...") + with torch.no_grad(): + for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): + m = 2**m + for (k, n), ( + qweight, + scales, + qzeros, + g_idx, + bits, + maxq, + ) in kn_values.items(): + if transpose: + a = torch.randn(m, k, dtype=torch.float16, device=model.device) + quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) + a = torch.randn(m, n, dtype=torch.float16, device=model.device) + transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) + else: + a = torch.randn(m, k, dtype=torch.float16, device=model.device) + quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq) + del kn_values + + +__all__ = ["QuantLinear"] diff --git a/auto_round/export/export_to_autoround/triton_utils/__init__.py b/auto_round/export/export_to_autoround/triton_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/auto_round/export/export_to_autoround/triton_utils/custom_autotune.py b/auto_round/export/export_to_autoround/triton_utils/custom_autotune.py new file mode 100644 index 00000000..ff2d14a3 --- /dev/null +++ b/auto_round/export/export_to_autoround/triton_utils/custom_autotune.py @@ -0,0 +1,219 @@ +import builtins +import math +import time +from typing import Dict + +import triton + + +# code based https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 
+""" + + +class CustomizedTritonAutoTuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40) + except triton.OutOfResources: + return (float("inf"), float("inf"), float("inf")) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, 
float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + def decorator(fn): + return CustomizedTritonAutoTuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + ) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: + continue + + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + + +__all__ = ["autotune"] diff --git a/auto_round/export/export_to_autoround/triton_utils/dequant.py b/auto_round/export/export_to_autoround/triton_utils/dequant.py new file mode 100644 index 00000000..7f13a88f --- /dev/null +++ b/auto_round/export/export_to_autoround/triton_utils/dequant.py @@ -0,0 +1,145 @@ +import itertools + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + + +def make_dequant_configs(block_sizes, num_warps): + configs = [] + for bs, ws in itertools.product(block_sizes, num_warps): + configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws)) + return configs + + +DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8]) + + +@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"]) +@triton.jit +def dequant_kernel_248( + g_idx_ptr, + scales_ptr, + qweight_ptr, + qzeros_ptr, + out_ptr, + numels, + maxq: tl.constexpr, + bits: tl.constexpr, + outfeatures: tl.constexpr, + num_groups: tl.constexpr, + X_BLOCK: tl.constexpr, +): + # Block indexing + xoffset = tl.program_id(0) * X_BLOCK + x_index = xoffset + tl.arange(0, X_BLOCK) + xmask = x_index < numels + row_idx = x_index // outfeatures + col_idx = x_index % outfeatures + + elements_per_feature: tl.constexpr = 32 // bits + + # Load parameters + g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last") + qweights = tl.load( + qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))), + 
None, + ) + + wf_weights = (row_idx % elements_per_feature) * bits + + wf_zeros = (col_idx % elements_per_feature) * bits + + tmp1 = g_idx + num_groups + tmp2 = g_idx < 0 + tl.device_assert(g_idx >= 0, "index out of bounds: 0 <= tmp0 < 0") + groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx + + scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to( + tl.float32 + ) + + # Unpack weights + weights = qweights >> wf_weights # bit shift qweight + + weights = weights & maxq + + # Unpack zeros + qzero_ncols: tl.constexpr = outfeatures // elements_per_feature + qzeros = tl.load( + qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)), + None, + eviction_policy="evict_last", + ) + zeros = qzeros >> wf_zeros + zeros = zeros & maxq + + # Dequantize + # zeros = zeros + 1 + weights = weights - zeros + weights = weights.to(tl.float32) + weights = scales * weights + + tl.store(out_ptr + (x_index), weights, mask=xmask) + + +def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None): + """ + Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8 + """ + + num_groups = scales.shape[0] + outfeatures = scales.shape[1] + infeatures = g_idx.shape[0] + + out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16) + numels = out.numel() + maxq = 2**bits - 1 if maxq is None else maxq + grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),) # noqa: E731 + + dequant_kernel_248[grid]( + g_idx, + scales, + qweight, + qzeros, + out, + numels, + maxq=maxq, + bits=bits, + outfeatures=outfeatures, + num_groups=num_groups, + ) + return out + + +def quant_matmul_248( + input, qweight, scales, qzeros, g_idx, bits, maxq=None, transpose=False +): + W = dequant248(qweight, scales, qzeros, g_idx, bits, maxq=maxq) + if transpose: + return input @ W.t() + return input @ W + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = quant_matmul_248( + grad_output, qweight, scales, qzeros, g_idx, bits, maxq, transpose=True + ) + return grad_input, None, None, None, None, None, None diff --git a/auto_round/export/export_to_autoround/triton_utils/kernels.py b/auto_round/export/export_to_autoround/triton_utils/kernels.py new file mode 100644 index 00000000..c7a8874e --- /dev/null +++ b/auto_round/export/export_to_autoround/triton_utils/kernels.py @@ -0,0 +1,464 @@ +from logging import getLogger + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from . 
import custom_autotune + + +logger = getLogger(__name__) + + +# code based https://github.com/fpgaminer/GPTQ-triton + + +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def quant_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + 
g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + # zeros = zeros + 1 + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, +) +@triton.jit +def transpose_quant_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for k in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + # zeros = zeros + 1 + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype) + grid = lambda META: ( # noqa: E731 + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + quant_matmul_248_kernel[grid]( + input, + qweight, + output, + scales.to(input.dtype), + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +def transpose_quant_matmul_248(input, qweight, scales, qzeros, 
g_idx, bits, maxq): + with torch.cuda.device(input.device): + output_dim = (qweight.shape[0] * 32) // bits + output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype) + grid = lambda META: ( # noqa: E731 + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]), + ) + transpose_quant_matmul_248_kernel[grid]( + input, + qweight, + output, + scales.to(input.dtype), + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + output_dim, + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) + return grad_input, None, None, None, None, None, None + + +def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: ( # noqa: E731 + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + quant_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearInferenceOnlyFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq) + return output diff --git a/auto_round/export/export_to_autoround/triton_utils/mixin.py b/auto_round/export/export_to_autoround/triton_utils/mixin.py new file mode 100644 index 00000000..16161183 --- /dev/null +++ b/auto_round/export/export_to_autoround/triton_utils/mixin.py @@ -0,0 +1,4 @@ +class TritonModuleMixin: + @classmethod + def warmup(cls, model, transpose=False, seqlen=2048): + pass
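
The change above makes "gptq:triton" the default backend and routes quantized layers through the new Triton-backed QuantLinear in qliner_triton.py instead of dynamically importing one from AutoGPTQ. Below is a minimal usage sketch of that class, assuming a CUDA device with triton installed; the layer sizes and the dummy scales/zero points are illustrative placeholders, not values produced by the AutoRound quantization flow.

import math
import torch
import torch.nn as nn

from auto_round.export.export_to_autoround.qliner_triton import QuantLinear

bits, group_size, in_features, out_features = 4, 128, 4096, 4096
fp_linear = nn.Linear(in_features, out_features, bias=True)

# pack() expects per-group scales/zero points shaped (out_features, n_groups);
# it transposes them and packs qweight/qzeros into int32 buffers on the CPU.
n_groups = math.ceil(in_features / group_size)
scales = torch.full((out_features, n_groups), 0.01)              # placeholder scales
zeros = torch.full((out_features, n_groups), float(2 ** (bits - 1)))  # placeholder zero points

qlayer = QuantLinear(bits, group_size, in_features, out_features, bias=True)
qlayer.pack(fp_linear, scales, zeros)

# Inference goes through the fused Triton kernel (QuantLinearInferenceOnlyFunction).
qlayer = qlayer.to("cuda")
qlayer.post_init()
x = torch.randn(2, in_features, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = qlayer(x)
print(y.shape)  # torch.Size([2, 4096])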
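
pack() stores 32 // bits quantized integers per int32 word: within a word, value j occupies bits [bits*j, bits*(j+1)), and the kernels recover it with a shift followed by a mask with maxq = 2**bits - 1. A tiny self-contained illustration of that layout, using toy values rather than anything from the diff:

import numpy as np

bits = 4
maxq = (1 << bits) - 1
vals = list(range(32 // bits))   # eight 4-bit integers: 0..7

packed = 0
for j, v in enumerate(vals):
    packed |= v << (bits * j)    # same loop structure as the qweight/qzeros packing
packed = np.uint32(packed)       # 0x76543210

# Unpack the way dequant_kernel_248 / quant_matmul_248_kernel do: shift, then mask.
unpacked = [(int(packed) >> (bits * j)) & maxq for j in range(32 // bits)]
assert unpacked == vals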
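
triton_utils provides two execution paths over the same packed buffers: the fused quant_matmul_248_kernel in kernels.py, which dequantizes tiles on the fly inside the matmul, and dequant248 in dequant.py, which materializes the full fp16 weight matrix first. The sketch below compares the two; the buffer shapes follow what QuantLinear registers, but the random contents and sizes (k=512, n=256) are illustrative assumptions, and it requires CUDA plus triton.

import torch

from auto_round.export.export_to_autoround.triton_utils import dequant, kernels

bits, group_size, k, n = 4, 128, 512, 256
maxq = 2 ** bits - 1
n_groups = k // group_size

# Randomly packed buffers with QuantLinear's layouts (qweight: (k*bits//32, n), qzeros: (n_groups, n*bits//32)).
qweight = torch.randint(0, 2 ** 31 - 1, (k * bits // 32, n), dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2 ** 31 - 1, (n_groups, n * bits // 32), dtype=torch.int32, device="cuda")
scales = torch.rand(n_groups, n, dtype=torch.float16, device="cuda") * 0.01
g_idx = torch.arange(k, dtype=torch.int32, device="cuda") // group_size

x = torch.randn(4, k, dtype=torch.float16, device="cuda")

# Fused path: never materializes the dequantized weight matrix.
y_fused = kernels.quant_matmul_248(x, qweight, scales, qzeros, g_idx, bits, maxq)

# Dequantize-then-matmul path.
w_fp16 = dequant.dequant248(qweight, scales, qzeros, g_idx, bits, maxq=maxq)
y_ref = x @ w_fp16

print((y_fused.float() - y_ref.float()).abs().max())  # expected to be small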
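
CustomizedTritonAutoTuner keeps tuning time down by rounding each element of the autotune key (M, N, K) to the nearest power of two before the cache lookup, so nearby problem sizes reuse one benchmarked config. A plain-Python restatement of that rounding; the helper name is mine, not from the diff:

import math

def nearest_power_of_two(x: int) -> int:
    # Same expression CustomizedTritonAutoTuner.run() applies to each key element.
    return 2 ** int(math.log2(x) + 0.5)

for m in (12, 100, 1000, 1200, 2048):
    print(m, "->", nearest_power_of_two(m))
# 12 -> 16, 100 -> 128, 1000 -> 1024, 1200 -> 1024, 2048 -> 2048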