From 97d2f361a5265fe890a619af0c8e524ab73f8ea2 Mon Sep 17 00:00:00 2001 From: Jonathan Coates Date: Thu, 4 Apr 2024 18:22:14 +0100 Subject: [PATCH] Poll interrupted state when reading from InputStreams This prevents hard aborts when reading incredibly long (1GB+ strings). Ideally we wouldn't even be able to create strings that long, but that's a whole 'nother issue! --- .../cc/tweaked/cobalt/build/ClassEmitter.kt | 2 + .../squiddev/cobalt/compiler/InputReader.java | 31 ++++++++++-- .../squiddev/cobalt/compiler/LoadState.java | 31 +----------- .../org/squiddev/cobalt/compiler/LuaC.java | 30 ++++++++++-- .../java/org/squiddev/cobalt/lib/BaseLib.java | 49 +++++++++++-------- .../cobalt/lib/system/SystemBaseLib.java | 17 +++++-- .../org/squiddev/cobalt/ProtectionTest.java | 2 +- .../squiddev/cobalt/compiler/SimpleTests.java | 4 +- src/test/resources/protection/load.lua | 13 +++++ 9 files changed, 112 insertions(+), 67 deletions(-) create mode 100644 src/test/resources/protection/load.lua diff --git a/build-tools/src/main/kotlin/cc/tweaked/cobalt/build/ClassEmitter.kt b/build-tools/src/main/kotlin/cc/tweaked/cobalt/build/ClassEmitter.kt index b9ba3ac6..159c44b5 100644 --- a/build-tools/src/main/kotlin/cc/tweaked/cobalt/build/ClassEmitter.kt +++ b/build-tools/src/main/kotlin/cc/tweaked/cobalt/build/ClassEmitter.kt @@ -68,6 +68,8 @@ private val subclassRelations = mapOf( UnorderedPair("org/squiddev/cobalt/LuaValue", "org/squiddev/cobalt/LuaString") to "org/squiddev/cobalt/LuaValue", UnorderedPair("org/squiddev/cobalt/Varargs", "org/squiddev/cobalt/LuaValue") to "org/squiddev/cobalt/Varargs", UnorderedPair("org/squiddev/cobalt/LuaError", "org/squiddev/cobalt/compiler/CompileException") to "java/lang/Exception", + UnorderedPair("org/squiddev/cobalt/LuaError", "java/lang/Exception") to "java/lang/Exception", + UnorderedPair("org/squiddev/cobalt/compiler/CompileException", "java/lang/Exception") to "java/lang/Exception" ) /** A [ClassWriter] extension which avoids loading classes when computing frames. */ diff --git a/src/main/java/org/squiddev/cobalt/compiler/InputReader.java b/src/main/java/org/squiddev/cobalt/compiler/InputReader.java index b468fe11..b4884670 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/InputReader.java +++ b/src/main/java/org/squiddev/cobalt/compiler/InputReader.java @@ -7,10 +7,31 @@ /** * A basic byte-by-byte input stream, which can yield when reading. */ -public interface InputReader { - int read() throws CompileException, LuaError, UnwindThrowable; - - default int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable { - throw new IllegalStateException("Cannot resume a non-yielding InputReader."); +public abstract class InputReader { + protected InputReader() { } + + /** + * Read a single byte from this input. + * + * @return The read byte. + * @throws LuaError If the underlying reader threw a Lua error. + * @throws CompileException If reading failed for some other reason. Unlike a {@link LuaError}, this will not be + * passed to the {@code xpcall} error handler. + * @throws UnwindThrowable If the reader yielded. {@link #resume(Varargs)} will be called when the coroutine is + * resumed. + */ + public abstract int read() throws CompileException, LuaError, UnwindThrowable; + + /** + * Resume this reader after yielding. + * + * @param varargs The value returned from the function above this in the stack + * @return The read byte. + * @throws LuaError If the underlying reader threw a Lua error. + * @throws CompileException If reading failed for some other reason. Unlike a {@link LuaError}, this will not be + * passed to the {@code xpcall} error handler. + * @throws UnwindThrowable If the reader yielded. + */ + public abstract int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable; } diff --git a/src/main/java/org/squiddev/cobalt/compiler/LoadState.java b/src/main/java/org/squiddev/cobalt/compiler/LoadState.java index ff689cb5..8a0ce031 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/LoadState.java +++ b/src/main/java/org/squiddev/cobalt/compiler/LoadState.java @@ -38,7 +38,7 @@ * Class to manage loading of {@link Prototype} instances. *

* The {@link LoadState} class exposes one main function, - * namely {@link #load(LuaState, InputStream, LuaString, LuaTable)}, + * namely {@link #load(LuaState, InputStream, LuaString, LuaValue)}, * to be used to load code from a particular input stream. *

* A simple pattern for loading and executing code is @@ -60,11 +60,6 @@ * @see LuaC */ public final class LoadState { - /** - * Name for compiled chunks - */ - private static final LuaString SOURCE_BINARY_STRING = valueOf("=?"); - /** * Construct our standard Lua function. */ @@ -108,29 +103,7 @@ public static LuaClosure load(LuaState state, InputStream stream, String name, L * @throws CompileException If the stream cannot be loaded. */ public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaValue env) throws CompileException, LuaError { - return load(state, stream, name, null, env); - } - - public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaString mode, LuaValue env) throws CompileException, LuaError { - return state.compiler.load(LuaC.compile(state, stream, name, mode), env); - } - - /** - * Construct a source name from a supplied chunk name - * - * @param name String name that appears in the chunk - * @return source file name - */ - static LuaString getSourceName(LuaString name) { - if (name.length() > 0) { - return switch (name.charAt(0)) { - case '@', '=' -> name.substring(1); - case 27 -> SOURCE_BINARY_STRING; - default -> name; - }; - } - - return name; + return state.compiler.load(LuaC.compile(state, stream, name), env); } private static final int NAME_LENGTH = 30; diff --git a/src/main/java/org/squiddev/cobalt/compiler/LuaC.java b/src/main/java/org/squiddev/cobalt/compiler/LuaC.java index e375ea62..3aab107d 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/LuaC.java +++ b/src/main/java/org/squiddev/cobalt/compiler/LuaC.java @@ -27,6 +27,7 @@ import cc.tweaked.cobalt.internal.unwind.AutoUnwind; import cc.tweaked.cobalt.internal.unwind.SuspendedAction; +import org.checkerframework.checker.nullness.qual.Nullable; import org.squiddev.cobalt.*; import org.squiddev.cobalt.compiler.LoadState.FunctionFactory; import org.squiddev.cobalt.function.LuaInterpretedFunction; @@ -145,13 +146,13 @@ private LuaC() { * @throws CompileException If there is a syntax error. */ public static Prototype compile(LuaState state, InputStream stream, String name) throws CompileException, LuaError { - return compile(state, stream, valueOf(name), null); + return compile(state, stream, valueOf(name)); } - public static Prototype compile(LuaState state, InputStream stream, LuaString name, LuaString mode) throws CompileException, LuaError { + public static Prototype compile(LuaState state, InputStream stream, LuaString name) throws CompileException, LuaError { Object result = SuspendedAction.noYield(() -> { try { - return compile(state, new InputStreamReader(stream), name, mode); + return compile(state, new InputStreamReader(stream), name, null); } catch (CompileException e) { return e; } @@ -186,9 +187,23 @@ private static Prototype loadTextChunk(int firstByte, InputReader stream, LuaStr return parser.mainFunction(); } - public record InputStreamReader(InputStream stream) implements InputReader { + public static final class InputStreamReader extends InputReader { + private final @Nullable LuaState state; + private final InputStream stream; + + public InputStreamReader(InputStream stream) { + this(null, stream); + } + + public InputStreamReader(@Nullable LuaState state, InputStream stream) { + this.state = state; + this.stream = stream; + } + @Override - public int read() throws CompileException { + public int read() throws CompileException, UnwindThrowable, LuaError { + if (state != null && state.isInterrupted()) state.handleInterrupt(); + try { return stream.read(); } catch (IOException e) { @@ -196,5 +211,10 @@ public int read() throws CompileException { throw new CompileException("io error: " + message); } } + + @Override + public int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable { + return read(); + } } } diff --git a/src/main/java/org/squiddev/cobalt/lib/BaseLib.java b/src/main/java/org/squiddev/cobalt/lib/BaseLib.java index d087b9ca..9b968fae 100644 --- a/src/main/java/org/squiddev/cobalt/lib/BaseLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/BaseLib.java @@ -29,10 +29,10 @@ import org.squiddev.cobalt.*; import org.squiddev.cobalt.compiler.CompileException; import org.squiddev.cobalt.compiler.InputReader; -import org.squiddev.cobalt.compiler.LoadState; import org.squiddev.cobalt.compiler.LuaC; import org.squiddev.cobalt.debug.DebugFrame; import org.squiddev.cobalt.function.*; +import org.squiddev.cobalt.unwind.SuspendedTask; import java.io.InputStream; import java.nio.ByteBuffer; @@ -69,7 +69,7 @@ public static void add(LuaTable env) { RegisteredFunction.ofV("assert", BaseLib::assert_), RegisteredFunction.of("getfenv", BaseLib::getfenv), RegisteredFunction.ofV("getmetatable", BaseLib::getmetatable), - RegisteredFunction.ofV("loadstring", BaseLib::loadstring), + RegisteredFunction.ofS("loadstring", BaseLib::loadstring), RegisteredFunction.ofV("select", BaseLib::select), RegisteredFunction.ofV("type", BaseLib::type), RegisteredFunction.ofV("rawequal", BaseLib::rawequal), @@ -138,10 +138,11 @@ private static LuaValue getmetatable(LuaState state, Varargs args) throws LuaErr return mt != null ? mt.rawget(Constants.METATABLE).optValue(mt) : Constants.NIL; } - private static Varargs loadstring(LuaState state, Varargs args) throws LuaError { + private static Varargs loadstring(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable { // loadstring( string [,chunkname] ) -> chunk | nil, msg LuaString script = args.arg(1).checkLuaString(); - return BaseLib.loadStream(state, script.toInputStream(), args.arg(2).optLuaString(script)); + InputStream is = script.toInputStream(); + return loadStream(state, di, is, args.arg(2).optLuaString(script), null, state.globals()); } private static Varargs select(LuaState state, Varargs args) throws LuaError { @@ -311,7 +312,7 @@ public Varargs resumeError(LuaState state, ProtectedCall call, LuaError error) t } // load( func|str [,chunkname[, mode[, env]]] ) -> chunk | nil, msg - static class Load extends ResumableVarArgFunction { + static class Load extends ResumableVarArgFunction { @Override protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable { LuaValue scriptGen = args.arg(1); @@ -322,7 +323,7 @@ protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws Lua // If we're a string, load as normal if (scriptGen.isString()) { LuaString contents = scriptGen.checkLuaString(); - return BaseLib.loadStream(state, contents.toInputStream(), chunkName == null ? contents : chunkName, mode, funcEnv); + return BaseLib.loadStream(state, di, contents.toInputStream(), chunkName == null ? contents : chunkName, mode, funcEnv); } LuaFunction function = scriptGen.checkFunction(); @@ -339,29 +340,35 @@ protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws Lua } @Override - public Varargs resume(LuaState state, ProtectedCall call, Varargs value) throws UnwindThrowable { - return call.resume(state, value).asResultOrFailure(); + public Varargs resume(LuaState state, Object funcState, Varargs value) throws UnwindThrowable, LuaError { + if (funcState instanceof ProtectedCall call) { + return call.resume(state, value).asResultOrFailure(); + } else { + return ((SuspendedTask) funcState).resume(value); + } } @Override - public Varargs resumeError(LuaState state, ProtectedCall call, LuaError error) throws UnwindThrowable { - return call.resumeError(state, error).asResultOrFailure(); - } - } - - public static Varargs loadStream(LuaState state, InputStream is, LuaString chunkName, LuaString mode, LuaValue env) { - try { - return LoadState.load(state, is, chunkName, mode, env); - } catch (LuaError | CompileException e) { - return varargsOf(Constants.NIL, valueOf(e.getMessage())); + public Varargs resumeError(LuaState state, Object funcState, LuaError error) throws UnwindThrowable, LuaError { + if (funcState instanceof ProtectedCall call) { + return call.resumeError(state, error).asResultOrFailure(); + } else { + return super.resumeError(state, funcState, error); + } } } - public static Varargs loadStream(LuaState state, InputStream is, LuaString chunkName) { - return loadStream(state, is, chunkName, null, state.globals()); + private static Varargs loadStream(LuaState state, DebugFrame frame, InputStream is, LuaString chunkName, LuaString mode, LuaValue env) throws UnwindThrowable, LuaError { + return SuspendedAction.run(frame, () -> { + try { + return state.compiler.load(LuaC.compile(state, new LuaC.InputStreamReader(state, is), chunkName, mode), env); + } catch (CompileException e) { + return varargsOf(Constants.NIL, valueOf(e.getMessage())); + } + }); } - private static class FunctionInputReader implements InputReader { + private static class FunctionInputReader extends InputReader { private static final ByteBuffer EMPTY = ByteBuffer.allocate(0); private final LuaState state; diff --git a/src/main/java/org/squiddev/cobalt/lib/system/SystemBaseLib.java b/src/main/java/org/squiddev/cobalt/lib/system/SystemBaseLib.java index a7e4ece1..0e88b437 100644 --- a/src/main/java/org/squiddev/cobalt/lib/system/SystemBaseLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/system/SystemBaseLib.java @@ -2,10 +2,11 @@ import cc.tweaked.cobalt.internal.unwind.SuspendedAction; import org.squiddev.cobalt.*; +import org.squiddev.cobalt.compiler.CompileException; +import org.squiddev.cobalt.compiler.LoadState; import org.squiddev.cobalt.debug.DebugFrame; import org.squiddev.cobalt.function.Dispatch; import org.squiddev.cobalt.function.RegisteredFunction; -import org.squiddev.cobalt.lib.BaseLib; import java.io.InputStream; import java.io.PrintStream; @@ -63,14 +64,14 @@ private static LuaValue collectgarbage(LuaState state, LuaValue arg1, LuaValue a private Varargs loadfile(LuaState state, Varargs args) throws LuaError { // loadfile( [filename] ) -> chunk | nil, msg return args.first().isNil() ? - BaseLib.loadStream(state, in, STDIN_STR) : + SystemBaseLib.loadBasicStream(state, in, STDIN_STR) : SystemBaseLib.loadFile(state, resources, args.arg(1).checkString()); } private Varargs dofile(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable { // dofile( filename ) -> result1, ... Varargs v = args.first().isNil() ? - BaseLib.loadStream(state, in, STDIN_STR) : + SystemBaseLib.loadBasicStream(state, in, STDIN_STR) : SystemBaseLib.loadFile(state, resources, args.arg(1).checkString()); if (v.first().isNil()) { throw new LuaError(v.arg(2).toString()); @@ -99,6 +100,14 @@ private Varargs print(LuaState state, DebugFrame frame, Varargs args) throws Lua }); } + private static Varargs loadBasicStream(LuaState state, InputStream is, LuaString chunkName) { + try { + return LoadState.load(state, is, chunkName, state.globals()); + } catch (LuaError | CompileException e) { + return varargsOf(Constants.NIL, valueOf(e.getMessage())); + } + } + /** * Load from a named file, returning the chunk or nil,error of can't load * @@ -112,7 +121,7 @@ public static Varargs loadFile(LuaState state, ResourceLoader resources, String return varargsOf(Constants.NIL, valueOf("cannot open " + filename + ": No such file or directory")); } try { - return BaseLib.loadStream(state, is, valueOf("@" + filename)); + return loadBasicStream(state, is, valueOf("@" + filename)); } finally { try { is.close(); diff --git a/src/test/java/org/squiddev/cobalt/ProtectionTest.java b/src/test/java/org/squiddev/cobalt/ProtectionTest.java index 75c9d632..3b17e6f4 100644 --- a/src/test/java/org/squiddev/cobalt/ProtectionTest.java +++ b/src/test/java/org/squiddev/cobalt/ProtectionTest.java @@ -74,7 +74,7 @@ public void tearDown() { @Timeout(3) @ParameterizedTest(name = ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER) - @ValueSource(strings = {"string", "loop"}) + @ValueSource(strings = {"string", "loop", "load"}) public void run(String name) throws IOException, CompileException, LuaError, InterruptedException { LuaThread.runMain(helpers.state, helpers.loadScript(name)); } diff --git a/src/test/java/org/squiddev/cobalt/compiler/SimpleTests.java b/src/test/java/org/squiddev/cobalt/compiler/SimpleTests.java index 191f75d2..aaeefe78 100644 --- a/src/test/java/org/squiddev/cobalt/compiler/SimpleTests.java +++ b/src/test/java/org/squiddev/cobalt/compiler/SimpleTests.java @@ -50,7 +50,7 @@ public void setup() throws LuaError { private void doTest(String script) { try { InputStream is = new ByteArrayInputStream(script.getBytes(StandardCharsets.UTF_8)); - LuaFunction c = LoadState.interpretedFunction(LuaC.compile(state, is, valueOf("script"), null), _G); + LuaFunction c = LoadState.load(state, is, valueOf("script"), _G); LuaThread.runMain(state, c); } catch (Exception e) { fail("i/o exception: " + e); @@ -127,7 +127,7 @@ public void testZap() { String s = "print('\\z"; assertThrows(CompileException.class, () -> { InputStream is = new ByteArrayInputStream(s.getBytes(StandardCharsets.UTF_8)); - LoadState.interpretedFunction(LuaC.compile(state, is, valueOf("script"), null), _G); + LoadState.load(state, is, valueOf("script"), _G); }); } diff --git a/src/test/resources/protection/load.lua b/src/test/resources/protection/load.lua new file mode 100644 index 00000000..2579e2d1 --- /dev/null +++ b/src/test/resources/protection/load.lua @@ -0,0 +1,13 @@ +--- Test loading long strings + +local function check(...) + local success, message = pcall(...) + + assert(not success, "Expected abort") + assert(message:find("Timed out"), "Got " .. message) +end + +check(function() + local fn, err = load("--[" .. ("="):rep(1e8) .. "[") + print(fn, err) +end)