diff --git a/Source/astcenc_weight_align.cpp b/Source/astcenc_weight_align.cpp index 4f302388..b2054164 100644 --- a/Source/astcenc_weight_align.cpp +++ b/Source/astcenc_weight_align.cpp @@ -175,24 +175,16 @@ static void compute_lowest_and_highest_weight( // unrounded weights in a straightforward way. vfloat min_weight(FLT_MAX); vfloat max_weight(-FLT_MAX); - unsigned int partial_weight_start = round_down_to_simd_multiple_vla(weight_count); - for (unsigned int i = 0; i < partial_weight_start; i += ASTCENC_SIMD_WIDTH) - { - vfloat weights = loada(dec_weight_ideal_value + i); - min_weight = min(min_weight, weights); - max_weight = max(max_weight, weights); - } - if (partial_weight_start != weight_count) + vint lane_id = vint::lane_id(); + for (unsigned int i = 0; i < weight_count; i += ASTCENC_SIMD_WIDTH) { - vfloat partial_weights = loada(dec_weight_ideal_value + partial_weight_start); - vmask active = vint::lane_id() < vint(weight_count - partial_weight_start); - - vmask smaller = active & (partial_weights < min_weight); - min_weight = select(min_weight, partial_weights, smaller); + vmask active = lane_id < vint(weight_count); + lane_id += vint(ASTCENC_SIMD_WIDTH); - vmask larger = active & (partial_weights > max_weight); - max_weight = select(max_weight, partial_weights, larger); + vfloat weights = loada(dec_weight_ideal_value + i); + min_weight = min(min_weight, select(min_weight, weights, active)); + max_weight = max(max_weight, select(max_weight, weights, active)); } min_weight = hmin(min_weight);