Patch summary (preamble text; patch tools skip everything before the first
file header):

  * elemental_ir_emitter.cc / .h: add a GpuElementalIrEmitter override of
    EmitF32ToBF16. When the CUDA compute capability is at least AMPERE, the
    value is bitcast to float, FPTrunc'ed to bfloat, and bitcast to i16;
    otherwise the call is forwarded to ElementalIrEmitter::EmitF32ToBF16.
  * single_instruction.hlo: add an --sm=90 RUN line (prefix CHECK-SM90) and a
    new s16 -> bf16 convert module. The checks expect cvt.rn.f32.s16 followed
    by cvt.rn.bf16.f32 under CHECK-SM80, and cvt.rn.bf16.s16 under CHECK-SM90.

NOTE(review): this patch was recovered from a copy whose line breaks and
template argument lists had been stripped. The hunks below were rebuilt to
match their recorded line counts. The reinstated template arguments
(StatusOr<llvm::Value*>, StatusOr<std::vector<llvm::Value*>>,
absl::Span<llvm::Value* const>) are inferred; confirm them against the tree
before applying.

diff --git a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc
index c0beed76cb282e..a82753243dec65 100644
--- a/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc
+++ b/third_party/xla/xla/service/gpu/elemental_ir_emitter.cc
@@ -353,6 +353,18 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
   return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
 }
 
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitF32ToBF16(
+    llvm::Value* f32_value) {
+  // sm_80 and up has an instruction to convert f32 into bf16.
+  if (ir_emitter_context_.cuda_compute_capability().IsAtLeast(
+          se::CudaComputeCapability::AMPERE)) {
+    return BitCast(
+        FPTrunc(BitCast(f32_value, b()->getFloatTy()), b()->getBFloatTy()),
+        b()->getInt16Ty());
+  }
+  return ElementalIrEmitter::EmitF32ToBF16(f32_value);
+}
+
 StatusOr<std::vector<llvm::Value*>> GpuElementalIrEmitter::EmitThreadLocalCall(
     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
     absl::string_view, bool /*is_reducer*/) {
diff --git a/third_party/xla/xla/service/gpu/elemental_ir_emitter.h b/third_party/xla/xla/service/gpu/elemental_ir_emitter.h
index f97861ba2e7afc..76a4ec19dbba67 100644
--- a/third_party/xla/xla/service/gpu/elemental_ir_emitter.h
+++ b/third_party/xla/xla/service/gpu/elemental_ir_emitter.h
@@ -96,6 +96,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
 
   llvm::Value* EmitThreadId() override;
 
+  StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value) override;
+
   bool fast_min_max() override {
     return ir_emitter_context_.debug_options().xla_gpu_enable_fast_min_max();
   }
diff --git a/third_party/xla/xla/service/gpu/tests/single_instruction.hlo b/third_party/xla/xla/service/gpu/tests/single_instruction.hlo
index b45dc9a7b021f2..51daf82bb7ceb8 100644
--- a/third_party/xla/xla/service/gpu/tests/single_instruction.hlo
+++ b/third_party/xla/xla/service/gpu/tests/single_instruction.hlo
@@ -1,5 +1,6 @@
 // RUN: hlo_to_llvm_ir --ptx %s | FileCheck %s
 // RUN: hlo_to_llvm_ir --ptx %s --sm=80 | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-SM80
+// RUN: hlo_to_llvm_ir --ptx %s --sm=90 | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-SM90
 
 // CHECK-DAG: sqrt.approx.f32
 
@@ -80,3 +81,21 @@ ENTRY main {
   b = f32[] parameter(1)
   ROOT wrapped_b = f32[] fusion(f32[] a, f32[] b), kind=kLoop, calls=fused_computation
 }
+
+// -----
+
+// CHECK-SM80: cvt.rn.f32.s16
+// CHECK-SM80: cvt.rn.bf16.f32
+// CHECK-SM90: cvt.rn.bf16.s16
+
+HloModule Test, is_scheduled=true
+
+fused_computation {
+  param_0 = s16[] parameter(0)
+  ROOT b.1 = bf16[] convert(s16[] param_0)
+}
+
+ENTRY main {
+  a = s16[] parameter(0)
+  ROOT wrapped_b = bf16[] fusion(s16[] a), kind=kLoop, calls=fused_computation
+}