Skip to content

Commit

Permalink
[xla:cpu] Add support for linking with external symbols to JitCompiler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701074278
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Nov 28, 2024
1 parent 697677f commit 946f70b
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 36 deletions.
3 changes: 3 additions & 0 deletions third_party/xla/xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,12 @@ xla_cc_test(
"//xla:util",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcShared",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@local_tsl//tsl/platform:env",
Expand Down
57 changes: 40 additions & 17 deletions third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,19 @@ absl::StatusOr<JitCompiler> JitCompiler::Create(
/*SSP=*/nullptr,
std::make_unique<TaskDispatcher>(std::move(task_runner))));

execution_session->setErrorReporter([](llvm::Error err) {
LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err));
});

// Create an instance of IrCompiler for lowering LLVM modules to machine code.
auto ir_compiler = std::make_unique<IrCompiler>(
target_machine_builder, std::move(options.ir_compiler_options),
std::move(options.ir_compiler_hooks));

return JitCompiler(std::move(target_machine_builder),
std::move(target_machine), std::move(execution_session),
std::move(ir_compiler), options.num_dylibs);
std::move(ir_compiler), options.num_dylibs,
std::move(options.definition_generator));
}

static std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer>
Expand All @@ -148,35 +153,39 @@ CreateObjectLinkingLayer(llvm::orc::ExecutionSession& execution_session) {

static std::unique_ptr<llvm::orc::IRCompileLayer> CreateCompileLayer(
llvm::orc::ExecutionSession& execution_session,
llvm::orc::RTDyldObjectLinkingLayer& object_linking_layer,
llvm::orc::RTDyldObjectLinkingLayer& object_layer,
std::unique_ptr<IrCompiler> ir_compiler) {
return std::make_unique<llvm::orc::IRCompileLayer>(
execution_session, object_linking_layer, std::move(ir_compiler));
execution_session, object_layer, std::move(ir_compiler));
}

JitCompiler::JitCompiler(
IrCompiler::TargetMachineBuilder target_machine_builder,
std::shared_ptr<llvm::TargetMachine> target_machine,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs)
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
DefinitionGenerator definition_generator)
: target_machine_builder_(std::move(target_machine_builder)),
target_machine_(std::move(target_machine)),
execution_session_(std::move(execution_session)),
object_linking_layer_(CreateObjectLinkingLayer(*execution_session_)),
compile_layer_(CreateCompileLayer(
*execution_session_, *object_linking_layer_, std::move(ir_compiler))),
object_layer_(CreateObjectLinkingLayer(*execution_session_)),
compile_layer_(CreateCompileLayer(*execution_session_, *object_layer_,
std::move(ir_compiler))),
gdb_(llvm::JITEventListener::createGDBRegistrationListener()),
perf_(llvm::JITEventListener::createPerfJITEventListener()) {
// Create at least one dynamic library for the given jit compiler.
dylibs_.resize(std::max<size_t>(1, num_dylibs));
for (size_t i = 0; i < dylibs_.size(); ++i) {
dylibs_[i] = &execution_session_->createBareJITDylib(
absl::StrCat("<xla_jit_dylib_", i, ">"));
if (definition_generator) {
dylibs_[i]->addGenerator(definition_generator(target_machine_.get()));
}
}

// Register GDB and perf event listeners with the object linking layer.
if (gdb_) object_linking_layer_->registerJITEventListener(*gdb_);
if (perf_) object_linking_layer_->registerJITEventListener(*perf_);
if (gdb_) object_layer_->registerJITEventListener(*gdb_);
if (perf_) object_layer_->registerJITEventListener(*perf_);
}

JitCompiler::~JitCompiler() {
Expand Down Expand Up @@ -210,6 +219,22 @@ absl::Status JitCompiler::AddModule(llvm::orc::ThreadSafeModule module,
return absl::OkStatus();
}

absl::Status JitCompiler::AddObjFile(
std::unique_ptr<llvm::MemoryBuffer> obj_file, size_t dylib_index) {
if (dylib_index >= dylibs_.size()) {
return Internal("Invalid dylib index %d (num dylibs: %d))", dylib_index,
dylibs_.size());
}

llvm::orc::JITDylib* dylib = dylibs_[dylib_index];
if (auto err = object_layer_->add(*dylib, std::move(obj_file))) {
return Internal("Failed to add object file to dylib %d: %s", dylib_index,
llvm::toString(std::move(err)));
}

return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
absl::Span<const Symbol> symbols) && {
TraceMe trace([&] {
Expand Down Expand Up @@ -243,8 +268,7 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
auto symbol_map = execution_session_->lookup(std::move(search_order),
std::move(lookup_set));
if (auto err = symbol_map.takeError()) {
return Internal("Failed to lookup symbols: %s",
llvm::toString(std::move(err)));
return Internal("%s", llvm::toString(std::move(err)));
}

// Resolve type-erased symbol pointers from the symbol map.
Expand All @@ -260,17 +284,14 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
}

return std::make_unique<CompiledFunctionLibrary>(
std::move(execution_session_), std::move(resolved_map));
std::move(execution_session_), std::move(object_layer_),
std::move(resolved_map));
}

JitCompiler::TaskDispatcher::TaskDispatcher(TaskRunner task_runner)
: task_runner_(std::move(task_runner)) {}

JitCompiler::TaskDispatcher::~TaskDispatcher() {
absl::MutexLock lock(&mu_);
DCHECK(num_dispatched_tasks_ == 0)
<< "TaskDispatcher is still dispatching tasks";
}
JitCompiler::TaskDispatcher::~TaskDispatcher() { shutdown(); }

void JitCompiler::TaskDispatcher::dispatch(
std::unique_ptr<llvm::orc::Task> task) {
Expand Down Expand Up @@ -309,8 +330,10 @@ void JitCompiler::TaskDispatcher::shutdown() {

JitCompiler::CompiledFunctionLibrary::CompiledFunctionLibrary(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer,
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map)
: execution_session_(std::move(execution_session)),
object_layer_(std::move(object_layer)),
symbols_map_(std::move(symbols_map)) {
DCHECK(execution_session_) << "Execution session must not be null";
}
Expand Down
39 changes: 32 additions & 7 deletions third_party/xla/xla/backends/cpu/codegen/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class JitCompiler {
using Task = std::function<void()>; // NOLINT (must be copyable)
using TaskRunner = absl::AnyInvocable<void(Task)>;

// A callback that returns a definition generator that will be added to all
// dynamic libraries created by the jit compiler. Definition generator enables
// linking host runtime symbols into the jit-compiled function library.
using DefinitionGenerator =
std::function<std::unique_ptr<llvm::orc::DefinitionGenerator>(
llvm::TargetMachine*)>;

JitCompiler(JitCompiler&&) = default;
JitCompiler& operator=(JitCompiler&&) = default;

Expand Down Expand Up @@ -93,6 +100,10 @@ class JitCompiler {
// multiple dynamic libraries we enable parallel compilation.
size_t num_dylibs = 1;

// Optional definition generator to inject host runtime symbols into the
// jit-compiled function library.
DefinitionGenerator definition_generator;

// Maximum CPU instruction set for wich the compiler should generate code.
// If instruction set is empty, compiler will generate code for all ISA
// extensions detected on the current machine.
Expand All @@ -109,11 +120,19 @@ class JitCompiler {
absl::Status AddModule(llvm::orc::ThreadSafeModule module,
size_t dylib_index = 0);

// Compiles all added LLVM modules into the FunctionLibrary by resolving all
// symbols in `symbols`. After this method returns, the FunctionLibrary will
// contain compiled functions that can be invoked via function calls. Returned
// FunctionLibrary track type ids of the resolved symbols, but the compiler
// doesn't verify that LLVM IR function signature matches the type id.
// Adds an object file to the dynamic library at `dylib_index`.
absl::Status AddObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,
size_t dylib_index = 0);

// Compiles all added LLVM modules and object files into the FunctionLibrary
// by resolving all symbols in `symbols`.
//
// After this method returns, the FunctionLibrary will contain compiled
// functions that can be invoked via function calls. Returned FunctionLibrary
// tracks type ids of the resolved symbols, but the compiler doesn't verify
// that LLVM IR function signature matches the type id, and it's up to the
// user to make sure that function types actually match, otherwise it will
// lead to run-time crashes.
//
// TODO(ezhulenev): Add an option to pass symbol (function) types at compile
// time together with names and type-check LLVM function signature against the
Expand All @@ -123,11 +142,14 @@ class JitCompiler {
absl::StatusOr<std::unique_ptr<FunctionLibrary>> Compile(
absl::Span<const Symbol> symbols) &&;

llvm::TargetMachine* target_machine() { return target_machine_.get(); }

private:
JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder,
std::shared_ptr<llvm::TargetMachine> target_machine,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs);
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
DefinitionGenerator definition_generator);

// LLVM ORC task dispatcher that uses `TaskRunner` to run compilation tasks.
class TaskDispatcher : public llvm::orc::TaskDispatcher {
Expand Down Expand Up @@ -156,6 +178,7 @@ class JitCompiler {

CompiledFunctionLibrary(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer,
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map);

~CompiledFunctionLibrary() final;
Expand All @@ -165,6 +188,7 @@ class JitCompiler {

private:
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map_;
};

Expand All @@ -174,9 +198,10 @@ class JitCompiler {
std::shared_ptr<llvm::TargetMachine> target_machine_;

std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_linking_layer_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
std::unique_ptr<llvm::orc::IRCompileLayer> compile_layer_;

// Non-owning pointers to dynamic libraries created for the execution session.
std::vector<llvm::orc::JITDylib*> dylibs_;

// Non owning pointer to JIT event listeners for gdb and perf.
Expand Down
116 changes: 104 additions & 12 deletions third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@ limitations under the License.
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/CoreContainers.h"
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "xla/backends/cpu/codegen/function_library.h"
#include "xla/tsl/lib/core/status_test_util.h"
Expand All @@ -42,6 +52,31 @@ limitations under the License.

namespace xla::cpu {

// We use static function to compile the function library, because we transfer
// compiler object into the function and make sure that it gets destroyed before
// returning the function library to the caller, as we test that we don't
// accidentally reference freed objects owned by the compiler.
static absl::StatusOr<std::unique_ptr<FunctionLibrary>> Compile(
JitCompiler compiler, absl::Span<const FunctionLibrary::Symbol> symbols) {
return std::move(compiler).Compile(symbols);
};

// Parses the LLVM IR into a ThreadSafeModule.
static absl::StatusOr<llvm::orc::ThreadSafeModule> ParseModule(
llvm::orc::ThreadSafeContext& context, std::string_view ir,
std::string_view name) {
llvm::SMDiagnostic diagnostic;
llvm::MemoryBufferRef ir_buffer(ir, name);

auto m = llvm::parseAssembly(ir_buffer, diagnostic, *context.getContext());
if (m == nullptr) {
return Internal("Failed to parse LLVM IR: %s",
diagnostic.getMessage().str());
}

return llvm::orc::ThreadSafeModule(std::move(m), context);
}

TEST(JitCompilerTest, Compile) {
auto context = std::make_unique<llvm::LLVMContext>();
llvm::orc::ThreadSafeContext tsc(std::move(context));
Expand Down Expand Up @@ -80,18 +115,9 @@ TEST(JitCompilerTest, Compile) {

auto add_module = [&](std::string_view ir, std::string_view name,
size_t dylib_index) -> absl::Status {
llvm::SMDiagnostic diagnostic;
llvm::MemoryBufferRef ir_buffer(ir, name);

auto m = llvm::parseAssembly(ir_buffer, diagnostic, *tsc.getContext());
if (m == nullptr) {
return Internal("Failed to parse LLVM IR: %s",
diagnostic.getMessage().str());
}

llvm::orc::ThreadSafeModule tsm(std::move(m), tsc);
TF_ASSIGN_OR_RETURN(llvm::orc::ThreadSafeModule tsm,
ParseModule(tsc, ir, name));
TF_RETURN_IF_ERROR(compiler.AddModule(std::move(tsm), dylib_index));

return absl::OkStatus();
};

Expand All @@ -104,7 +130,7 @@ TEST(JitCompilerTest, Compile) {
FunctionLibrary::Sym<ScalarFn>("MulInplace")};

TF_ASSERT_OK_AND_ASSIGN(auto function_library,
std::move(compiler).Compile(symbols));
Compile(std::move(compiler), symbols));

EXPECT_GE(num_tasks, 2);

Expand All @@ -127,4 +153,70 @@ TEST(JitCompilerTest, Compile) {
EXPECT_EQ(value, 4.0f);
}

class ExternalDefinitionGenerator : public llvm::orc::DefinitionGenerator {
public:
static void AddInplace(float* value) { *value += *value; }

llvm::Error tryToGenerate(llvm::orc::LookupState&, llvm::orc::LookupKind,
llvm::orc::JITDylib& jit_dylib,
llvm::orc::JITDylibLookupFlags,
const llvm::orc::SymbolLookupSet& names) final {
llvm::orc::SymbolMap new_defs;
for (auto& [name, flags] : names) {
if (*name == "__external_fn") {
new_defs[name] = llvm::orc::ExecutorSymbolDef{
llvm::orc::ExecutorAddr(reinterpret_cast<uint64_t>(&AddInplace)),
llvm::JITSymbolFlags::None};
}
}

cantFail(jit_dylib.define(llvm::orc::absoluteSymbols(std::move(new_defs))));
return llvm::Error::success();
}
};

TEST(JitCompilerTest, ExternalDefinitionGenerator) {
auto context = std::make_unique<llvm::LLVMContext>();
llvm::orc::ThreadSafeContext tsc(std::move(context));

JitCompiler::Options options;
options.definition_generator = [](llvm::TargetMachine*) {
return std::make_unique<ExternalDefinitionGenerator>();
};

TF_ASSERT_OK_AND_ASSIGN(
auto compiler,
JitCompiler::Create(llvm::TargetOptions(), llvm::CodeGenOptLevel::None,
std::move(options), /*task_runner=*/nullptr));

constexpr std::string_view call_external_fn_ir = R"(
declare void @__external_fn(ptr %arg)
define void @CallExternalFn(ptr %arg) {
call void @__external_fn(ptr %arg)
ret void
})";

TF_ASSERT_OK_AND_ASSIGN(
llvm::orc::ThreadSafeModule tsm,
ParseModule(tsc, call_external_fn_ir, "CallExternalFn"));

TF_ASSERT_OK(compiler.AddModule(std::move(tsm)));

using ScalarFn = void(float*);
std::vector<FunctionLibrary::Symbol> symbols = {
FunctionLibrary::Sym<ScalarFn>("CallExternalFn")};

TF_ASSERT_OK_AND_ASSIGN(auto function_library,
Compile(std::move(compiler), symbols));

TF_ASSERT_OK_AND_ASSIGN(
ScalarFn * call_external_fn,
function_library->ResolveFunction<ScalarFn>("CallExternalFn"));

float value = 1.0f;
call_external_fn(&value);
EXPECT_EQ(value, 2.0f);
}

} // namespace xla::cpu
Loading

0 comments on commit 946f70b

Please sign in to comment.