From 0067f4f379222d42f5764440d9a40c4de23aaf49 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Sun, 29 Oct 2023 23:09:45 -0700 Subject: [PATCH] [ntcore] Fix NT write_impl The previous fix didn't handle all cases correctly. Instead, add a new function to raw_ostream (SetNumBytesInBuffer) to allow always using the full buffer size, and revamp write_impl to more cleanly handle all cases. --- .../native/cpp/net/WebSocketConnection.cpp | 56 ++++++++++--------- ...-raw_ostream-Add-SetNumBytesInBuffer.patch | 25 +++++++++ upstream_utils/update_llvm.py | 1 + .../thirdparty/llvm/include/wpi/raw_ostream.h | 5 ++ 4 files changed, 61 insertions(+), 26 deletions(-) create mode 100644 upstream_utils/llvm_patches/0033-raw_ostream-Add-SetNumBytesInBuffer.patch diff --git a/ntcore/src/main/native/cpp/net/WebSocketConnection.cpp b/ntcore/src/main/native/cpp/net/WebSocketConnection.cpp index c5052195145..3664497293d 100644 --- a/ntcore/src/main/native/cpp/net/WebSocketConnection.cpp +++ b/ntcore/src/main/native/cpp/net/WebSocketConnection.cpp @@ -30,7 +30,8 @@ class WebSocketConnection::Stream final : public wpi::raw_ostream { public: explicit Stream(WebSocketConnection& conn) : m_conn{conn} { auto& buf = conn.m_bufs.back(); - SetBuffer(buf.base + buf.len, kAllocSize - buf.len); + SetBuffer(buf.base, kAllocSize); + SetNumBytesInBuffer(buf.len); } ~Stream() final { @@ -48,36 +49,39 @@ class WebSocketConnection::Stream final : public wpi::raw_ostream { }; void WebSocketConnection::Stream::write_impl(const char* data, size_t len) { - if (len >= kAllocSize) { - // only called by raw_ostream::write() when the buffer is empty and a large - // thing is being written; called with a length that's a multiple of the - // alloc size - assert((len % kAllocSize) == 0); - assert(m_conn.m_bufs.back().len == 0); - while (len > 0) { - auto& buf = m_conn.m_bufs.back(); - std::memcpy(buf.base, data, kAllocSize); - buf.len = kAllocSize; - m_conn.m_framePos += kAllocSize; - m_conn.m_written += kAllocSize; - data += kAllocSize; - len -= kAllocSize; + if (data == m_conn.m_bufs.back().base) { + // flush_nonempty() case + m_conn.m_bufs.back().len = len; + if (!m_disableAlloc) { + m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin; + m_conn.StartFrame(wpi::WebSocket::Frame::kFragment); + SetBuffer(m_conn.m_bufs.back().base, kAllocSize); + } + return; + } + + bool updateBuffer = false; + while (len > 0) { + auto& buf = m_conn.m_bufs.back(); + assert(buf.len <= kAllocSize); + size_t amt = (std::min)(kAllocSize - buf.len, len); + if (amt > 0) { + std::memcpy(buf.base + buf.len, data, amt); + buf.len += amt; + m_conn.m_framePos += amt; + m_conn.m_written += amt; + data += amt; + len -= amt; + } + if (buf.len >= kAllocSize && (len > 0 || !m_disableAlloc)) { // fragment the current frame and start a new one m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin; m_conn.StartFrame(wpi::WebSocket::Frame::kFragment); + updateBuffer = true; } - SetBuffer(m_conn.m_bufs.back().base, kAllocSize); - [[unlikely]] return; } - auto& buf = m_conn.m_bufs.back(); - buf.len += len; - m_conn.m_framePos += len; - m_conn.m_written += len; - if (!m_disableAlloc && buf.len >= kAllocSize) { - // fragment the current frame and start a new one - [[unlikely]] m_conn.m_frames.back().opcode &= ~wpi::WebSocket::kFlagFin; - m_conn.StartFrame(wpi::WebSocket::Frame::kFragment); + if (updateBuffer) { SetBuffer(m_conn.m_bufs.back().base, kAllocSize); } } @@ -132,7 +136,7 @@ void WebSocketConnection::StartFrame(uint8_t opcode) { void WebSocketConnection::FinishText() { assert(!m_bufs.empty()); auto& buf = m_bufs.back(); - assert(buf.len < kAllocSize + 1); // safe because we alloc one more byte + assert(buf.len < (kAllocSize + 1)); // safe because we alloc one more byte buf.base[buf.len++] = ']'; } diff --git a/upstream_utils/llvm_patches/0033-raw_ostream-Add-SetNumBytesInBuffer.patch b/upstream_utils/llvm_patches/0033-raw_ostream-Add-SetNumBytesInBuffer.patch new file mode 100644 index 00000000000..a36cfc66530 --- /dev/null +++ b/upstream_utils/llvm_patches/0033-raw_ostream-Add-SetNumBytesInBuffer.patch @@ -0,0 +1,25 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Peter Johnson +Date: Sun, 29 Oct 2023 23:00:08 -0700 +Subject: [PATCH 33/33] raw_ostream: Add SetNumBytesInBuffer + +--- + llvm/include/llvm/Support/raw_ostream.h | 5 +++++ + 1 file changed, 5 insertions(+) + +diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h +index 9a9a1f688313a5784a58a70f2cb4cc0d6ec70e79..39f98e4e7696e28587779e90a03995463767b02d 100644 +--- a/llvm/include/llvm/Support/raw_ostream.h ++++ b/llvm/include/llvm/Support/raw_ostream.h +@@ -356,6 +356,11 @@ protected: + SetBufferAndMode(BufferStart, Size, BufferKind::ExternalBuffer); + } + ++ /// Force-set the number of bytes in the raw_ostream buffer. ++ void SetNumBytesInBuffer(size_t Size) { ++ OutBufCur = OutBufStart + Size; ++ } ++ + /// Return an efficient buffer size for the underlying output mechanism. + virtual size_t preferred_buffer_size() const; + diff --git a/upstream_utils/update_llvm.py b/upstream_utils/update_llvm.py index edbd3df716f..75e926cfc97 100755 --- a/upstream_utils/update_llvm.py +++ b/upstream_utils/update_llvm.py @@ -210,6 +210,7 @@ def main(): "0030-Remove-DenseMap-GTest-printer-test.patch", "0031-Replace-deprecated-std-aligned_storage_t.patch", "0032-Fix-compilation-of-MathExtras.h-on-Windows-with-sdl.patch", + "0033-raw_ostream-Add-SetNumBytesInBuffer.patch", ]: git_am( os.path.join(wpilib_root, "upstream_utils/llvm_patches", f), diff --git a/wpiutil/src/main/native/thirdparty/llvm/include/wpi/raw_ostream.h b/wpiutil/src/main/native/thirdparty/llvm/include/wpi/raw_ostream.h index 7a804d53fe2..06ff2063fdb 100644 --- a/wpiutil/src/main/native/thirdparty/llvm/include/wpi/raw_ostream.h +++ b/wpiutil/src/main/native/thirdparty/llvm/include/wpi/raw_ostream.h @@ -356,6 +356,11 @@ class raw_ostream { SetBufferAndMode(BufferStart, Size, BufferKind::ExternalBuffer); } + /// Force-set the number of bytes in the raw_ostream buffer. + void SetNumBytesInBuffer(size_t Size) { + OutBufCur = OutBufStart + Size; + } + /// Return an efficient buffer size for the underlying output mechanism. virtual size_t preferred_buffer_size() const;