Skip to content

Commit

Permalink
[Python API] Fix python api for bytecode (iree-org#17343)
Browse files Browse the repository at this point in the history
This fixes an issue seen in iree-org#17278.
MLIR bytecode can have a zero as the last byte depending on the IR. If
this happens by accident parsing will fail. This PR checks if a buffer
begins with the _magic number_ from
https://mlir.llvm.org/docs/BytecodeFormat/#magic-number and sets
zero-termination to false if it does.

However I could not generate MLIR bytecode that ends with a zero other
than the one in the PR. Even if I save the IR used in textual format and
parse it again, some additional bytes are appended and it is not zero
terminated anymore. I think this is because of some behavior in TF.

Still after reading a bit through
https://mlir.llvm.org/docs/BytecodeFormat I think the check implemented
in this PR is the correct behavior for bytecode.

---------

Signed-off-by: Maximilian Bartel <bartel@roofline.ai>
  • Loading branch information
maxbartel authored Jun 25, 2024
1 parent 1f69b85 commit 247de36
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 14 deletions.
12 changes: 11 additions & 1 deletion compiler/bindings/python/iree/compiler/api/ctypes_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ def _initializeGlobalCL(*cl_args: str):
_dylib.ireeCompilerSetupGlobalCL(len(cl_args), arg_pointers, b"ctypes", False)


def _is_null_terminated(view: memoryview):
return view.nbytes > 0 and view[-1] == 0


def _is_mlir_bytecode(view: memoryview):
"""Compares the first 4 bytes of the view against the magic number 4d4cef52.
See https://mlir.llvm.org/docs/BytecodeFormat/#magic-number for more info."""
return len(view) >= 4 and view[:4].hex() == "4d4cef52"


class Session:
def __init__(self):
self._global_init = _global_init
Expand Down Expand Up @@ -339,7 +349,7 @@ def wrap_buffer(
buffer,
buffer_len,
# Detect if nul terminated.
True if buffer_len > 0 and view[-1] == 0 else False,
_is_null_terminated(view) and not _is_mlir_bytecode(view),
byref(source_p),
)
)
Expand Down
69 changes: 68 additions & 1 deletion compiler/bindings/python/test/api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import tempfile
import unittest

from iree.compiler.api import *
from iree.compiler.api import (
Session,
Source,
Output,
)
from iree.compiler import ir

class DlFlagsTest(unittest.TestCase):
Expand Down Expand Up @@ -81,6 +85,69 @@ def testInputBuffer(self):
self.assertIn(b"module", bytes(mem))
out.close()

def testInputBytecode(self):
this_dir = os.path.dirname(__file__)
with open(
os.path.join(this_dir, "testdata", "bytecode_testfile.bc"), "rb"
) as f:
bytecode = f.read()
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(session, bytecode)
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir(out)
mem = out.map_memory()
self.assertIn(b"module", bytes(mem))
out.close()

def testInputZeroTerminatedBytecode(self):
this_dir = os.path.dirname(__file__)
with open(
os.path.join(
this_dir, "testdata", "bytecode_zero_terminated_testfile.bc"
),
"rb",
) as f:
bytecode = f.read()
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(session, bytecode)
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir(out)
mem = out.map_memory()
self.assertIn(b"module", bytes(mem))
out.close()

def testInputRoundtrip(self):
test_ir = b"builtin.module {}"
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(
session,
bytes(test_ir),
)
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir_bytecode(out)
mem = out.map_memory()
bytecode = bytes(mem)
out.close()
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(
session,
bytecode,
)
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir(out)
mem = out.map_memory()
text_out = bytes(mem)
out.close()
self.assertIn(b"module", text_out)

def testOutputBytecode(self):
session = Session()
inv = session.invocation()
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.compiler.api import (
Session,
Source,
Output,
)
import os
import iree.compiler.tools.tflite


def generate_test_bytecode():
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(session, b"builtin.module {}")
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir_bytecode(out)
mem = out.map_memory()

this_dir = os.path.dirname(__file__)
with open(os.path.join(this_dir, "bytecode_testfile.bc"), "wb") as file:
file.write(bytes(mem))


def generate_zero_terminated_bytecode():
"""MLIR Bytecode can also be zero terminated. I couldn't find a way to generate zero terminated
bytecode apart from this. Printing as textual IR and then reparsing and printing as bytecode
removes the zero termination on this IR. This might very well be an odity of TF."""
if not iree.compiler.tools.tflite.is_available():
return
this_dir = os.path.dirname(__file__)
path = os.path.join(this_dir, "..", "..", "tools", "testdata", "tflite_sample.fb")
bytecode = iree.compiler.tools.tflite.compile_file(path, import_only=True)
with open(
os.path.join(this_dir, "bytecode_zero_terminated_testfile.bc"), "wb"
) as file:
file.write(bytecode)


if __name__ == "__main__":
generate_test_bytecode()
generate_zero_terminated_bytecode()
26 changes: 14 additions & 12 deletions compiler/bindings/python/test/tools/compiler_tflite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import tempfile
import unittest

from iree.compiler.tools.ir_tool import __main__ as ir_tool
from iree.compiler.api import (
Session,
Source,
Output,
)

# TODO: No idea why pytype cannot find names from this module.
# pytype: disable=name-error
Expand All @@ -24,18 +28,16 @@
sys.exit(0)


def mlir_bytecode_file_to_text(bytecode_file):
with tempfile.NamedTemporaryFile() as temp_file:
args = ir_tool.parse_arguments(["copy", bytecode_file, "-o", temp_file.name])
ir_tool.main(args)
return str(temp_file.read())


def mlir_bytecode_to_text(bytecode):
with tempfile.NamedTemporaryFile("wb") as temp_bytecode_file:
temp_bytecode_file.write(bytecode)
temp_bytecode_file.flush()
return mlir_bytecode_file_to_text(temp_bytecode_file.name)
session = Session()
inv = session.invocation()
source = Source.wrap_buffer(session, bytecode)
inv.parse_source(source)
out = Output.open_membuffer()
inv.output_ir(out)
text_ir = str(bytes(out.map_memory()))
out.close()
return text_ir


class CompilerTest(unittest.TestCase):
Expand Down

0 comments on commit 247de36

Please sign in to comment.