From 247de366d03d7a4dc0eb7a960e089abaa48a496d Mon Sep 17 00:00:00 2001 From: maxbartel Date: Tue, 25 Jun 2024 17:25:30 +0200 Subject: [PATCH] [Python API] Fix python api for bytecode (#17343) This fixes an issue seen in https://github.com/iree-org/iree/pull/17278. MLIR bytecode can have a zero as the last byte depending on the IR. If this happens by accident, parsing will fail. This PR checks if a buffer begins with the _magic number_ from https://mlir.llvm.org/docs/BytecodeFormat/#magic-number and sets zero-termination to false if it does. However, I could not generate MLIR bytecode that ends with a zero other than the one in the PR. Even if I save the IR used in textual format and parse it again, some additional bytes are appended and it is no longer zero-terminated. I think this is because of some behavior in TF. Still, after reading a bit through https://mlir.llvm.org/docs/BytecodeFormat, I think the check implemented in this PR is the correct behavior for bytecode. --------- Signed-off-by: Maximilian Bartel --- .../python/iree/compiler/api/ctypes_dl.py | 12 ++- compiler/bindings/python/test/api/api_test.py | 69 +++++++++++++++++- .../test/api/testdata/bytecode_testfile.bc | Bin 0 -> 90 bytes .../bytecode_zero_terminated_testfile.bc | Bin 0 -> 507 bytes .../api/testdata/generate_mlir_bytecode.py | 47 ++++++++++++ .../python/test/tools/compiler_tflite_test.py | 26 ++++--- 6 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 compiler/bindings/python/test/api/testdata/bytecode_testfile.bc create mode 100644 compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc create mode 100644 compiler/bindings/python/test/api/testdata/generate_mlir_bytecode.py diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py index 755788483390..784fc6160f7e 100644 --- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py +++ 
b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py @@ -179,6 +179,16 @@ def _initializeGlobalCL(*cl_args: str): _dylib.ireeCompilerSetupGlobalCL(len(cl_args), arg_pointers, b"ctypes", False) +def _is_null_terminated(view: memoryview): + return view.nbytes > 0 and view[-1] == 0 + + +def _is_mlir_bytecode(view: memoryview): + """Compares the first 4 bytes of the view against the magic number 4d4cef52. + See https://mlir.llvm.org/docs/BytecodeFormat/#magic-number for more info.""" + return len(view) >= 4 and view[:4].hex() == "4d4cef52" + + class Session: def __init__(self): self._global_init = _global_init @@ -339,7 +349,7 @@ def wrap_buffer( buffer, buffer_len, # Detect if nul terminated. - True if buffer_len > 0 and view[-1] == 0 else False, + _is_null_terminated(view) and not _is_mlir_bytecode(view), byref(source_p), ) ) diff --git a/compiler/bindings/python/test/api/api_test.py b/compiler/bindings/python/test/api/api_test.py index ab70fcd0858e..0d52bd3ba004 100644 --- a/compiler/bindings/python/test/api/api_test.py +++ b/compiler/bindings/python/test/api/api_test.py @@ -17,7 +17,11 @@ import tempfile import unittest - from iree.compiler.api import * + from iree.compiler.api import ( + Session, + Source, + Output, + ) from iree.compiler import ir class DlFlagsTest(unittest.TestCase): @@ -81,6 +85,69 @@ def testInputBuffer(self): self.assertIn(b"module", bytes(mem)) out.close() + def testInputBytecode(self): + this_dir = os.path.dirname(__file__) + with open( + os.path.join(this_dir, "testdata", "bytecode_testfile.bc"), "rb" + ) as f: + bytecode = f.read() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, bytecode) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + self.assertIn(b"module", bytes(mem)) + out.close() + + def testInputZeroTerminatedBytecode(self): + this_dir = os.path.dirname(__file__) + with open( + os.path.join( + this_dir, "testdata", 
"bytecode_zero_terminated_testfile.bc" + ), + "rb", + ) as f: + bytecode = f.read() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, bytecode) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + self.assertIn(b"module", bytes(mem)) + out.close() + + def testInputRoundtrip(self): + test_ir = b"builtin.module {}" + session = Session() + inv = session.invocation() + source = Source.wrap_buffer( + session, + bytes(test_ir), + ) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir_bytecode(out) + mem = out.map_memory() + bytecode = bytes(mem) + out.close() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer( + session, + bytecode, + ) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + text_out = bytes(mem) + out.close() + self.assertIn(b"module", text_out) + def testOutputBytecode(self): session = Session() inv = session.invocation() diff --git a/compiler/bindings/python/test/api/testdata/bytecode_testfile.bc b/compiler/bindings/python/test/api/testdata/bytecode_testfile.bc new file mode 100644 index 0000000000000000000000000000000000000000..0ee6e0115747d0393fb6f8fed1e6388d8b5cabc7 GIT binary patch literal 90 zcmebEc^|~<>*E<@XsKtQXONy*!obMO%m{?+%)G3OjI6@kOuXXEtjw&eEaI$;0gR07 qEbPpTjBLz|tc(oW?5zBPNu`-NC7F2)x%nxjIjIb~3>=)yK%D?(NDR>c literal 0 HcmV?d00001 diff --git a/compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc b/compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc new file mode 100644 index 0000000000000000000000000000000000000000..500820fe9cca00d6a6a2d2fa31dd4a134fa2f84f GIT binary patch literal 507 zcmZXO?@A*v6vcD*_HG!$5Qd=?DWw!CQp8#*BJLuxh#v$6(FX|8$<)C7Bgsg8%6*i5 zh0Vx*a`Pqkp5&bKd$|Aegb(+3PiL3w&3d!5fx(OeC4d|W6GDN{*9ly65s%CT-wHVu z@^hOGmlU~>5|YV;i*LW@l#-;3@>xtknbBCH7*7_5%cIleY)ccsgP1#|865!5DGn*p 
z9B>s?PRD>Jw4g&;0-nH%DP<`U0)l{Fqg#RuQjS2Fi9Di;sB?aG5fuO$ z#q#gN$X-Ur>EbHS$1_@uw_