[XLA:GPU] Use f32->bfloat conversion instructions on sm_80+
We tried this before with an intrinsic, but that breaks vectorization. Relying
on native LLVM types doesn't break it, while delivering the same code
improvements. The downside is that LLVM now knows the value is a bfloat instead
of an i16 and will optimize based on that. While making this change I had to
patch a bunch of holes in the NVPTX LLVM backend; there may be more.
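
For reference, the change boils down to an fptrunc through LLVM's native bfloat type followed by a bitcast, since XLA carries bf16 values in i16 registers. Below is a minimal standalone sketch of that pattern with a plain llvm::IRBuilder; the function name is illustrative and not part of this commit.

#include "llvm/IR/IRBuilder.h"

// Lower an f32 value to the i16 bit pattern of its bf16 truncation.
llvm::Value* EmitF32ToBF16Bits(llvm::IRBuilder<>& b, llvm::Value* f32_value) {
  // fptrunc float -> bfloat; on sm_80+ the NVPTX backend can select
  // cvt.rn.bf16.f32 for this instead of emulating the rounding.
  llvm::Value* bf16 = b.CreateFPTrunc(f32_value, b.getBFloatTy());
  // Reinterpret the bfloat bits as i16, the storage type XLA uses for bf16.
  return b.CreateBitCast(bf16, b.getInt16Ty());
}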

Depends on llvm/llvm-project#74827

PiperOrigin-RevId: 590118269
d0k authored and tensorflower-gardener committed Dec 12, 2023
1 parent d820ba9 commit a28a99a
Showing 3 changed files with 33 additions and 0 deletions.
12 changes: 12 additions & 0 deletions third_party/xla/xla/service/gpu/elemental_ir_emitter.cc
@@ -353,6 +353,18 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitF32ToBF16(
    llvm::Value* f32_value) {
  // sm_80 and newer have an instruction to convert f32 into bf16.
  if (ir_emitter_context_.cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::AMPERE)) {
    return BitCast(
        FPTrunc(BitCast(f32_value, b()->getFloatTy()), b()->getBFloatTy()),
        b()->getInt16Ty());
  }
  return ElementalIrEmitter::EmitF32ToBF16(f32_value);
}

StatusOr<std::vector<llvm::Value*>> GpuElementalIrEmitter::EmitThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view, bool /*is_reducer*/) {
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/elemental_ir_emitter.h
@@ -96,6 +96,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {

  llvm::Value* EmitThreadId() override;

  StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value) override;

  bool fast_min_max() override {
    return ir_emitter_context_.debug_options().xla_gpu_enable_fast_min_max();
  }
19 changes: 19 additions & 0 deletions third_party/xla/xla/service/gpu/tests/single_instruction.hlo
@@ -1,5 +1,6 @@
// RUN: hlo_to_llvm_ir --ptx %s | FileCheck %s
// RUN: hlo_to_llvm_ir --ptx %s --sm=80 | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-SM80
// RUN: hlo_to_llvm_ir --ptx %s --sm=90 | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-SM90

// CHECK-DAG: sqrt.approx.f32

@@ -80,3 +81,21 @@ ENTRY main {
  b = f32[] parameter(1)
  ROOT wrapped_b = f32[] fusion(f32[] a, f32[] b), kind=kLoop, calls=fused_computation
}

// -----

// CHECK-SM80: cvt.rn.f32.s16
// CHECK-SM80: cvt.rn.bf16.f32
// CHECK-SM90: cvt.rn.bf16.s16

HloModule Test, is_scheduled=true

fused_computation {
  param_0 = s16[] parameter(0)
  ROOT b.1 = bf16[] convert(s16[] param_0)
}

ENTRY main {
  a = s16[] parameter(0)
  ROOT wrapped_b = bf16[] fusion(s16[] a), kind=kLoop, calls=fused_computation
}
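
A note on the per-SM check prefixes above: with this change, sm_80 lowers the s16-to-bf16 convert in two steps, an integer-to-float conversion (cvt.rn.f32.s16) followed by the native f32-to-bf16 truncation (cvt.rn.bf16.f32), while on sm_90 the test expects a single direct conversion (cvt.rn.bf16.s16).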
