diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs
new file mode 100644
index 00000000..96576fa2
--- /dev/null
+++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs
@@ -0,0 +1,207 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.IO.Pipelines;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Bedrock.Framework.Protocols.WebSockets
+{
+ ///
+ /// A writer-like construct for writing WebSocket messages.
+ ///
+ public class WebSocketMessageWriter
+ {
+ ///
+ /// True if a message is in progress, false otherwise.
+ ///
+ internal bool _messageInProgress;
+
+ ///
+ /// The current protocol type for this writer, client or server.
+ ///
+ private WebSocketProtocolType _protocolType;
+
+ ///
+ /// The transport to write to.
+ ///
+ private PipeWriter _transport;
+
+ ///
+ /// True if the current message is text, false otherwise.
+ ///
+ public bool _isText;
+
+ ///
+ /// Creates an instance of a WebSocketMessageWriter.
+ ///
+ /// The transport to write to.
+ /// The protocol type for this writer.
+ public WebSocketMessageWriter(PipeWriter transport, WebSocketProtocolType protocolType)
+ {
+ _transport = transport;
+ _protocolType = protocolType;
+ }
+
+ ///
+ /// Starts a message in the writer.
+ ///
+ /// The payload to write.
+ /// Whether the payload is text or not.
+ /// A cancellation token, if any.
+ internal ValueTask StartMessageAsync(ReadOnlySequence payload, bool isText, CancellationToken cancellationToken = default)
+ {
+ if(_messageInProgress)
+ {
+ ThrowMessageAlreadyStarted();
+ }
+
+ _messageInProgress = true;
+ _isText = isText;
+ return DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, false, payload, cancellationToken);
+ }
+
+ ///
+ /// Writes a single frame message with the writer.
+ ///
+ /// The payload to write.
+ /// Whether the payload is text or not.
+ /// A cancellation token, if any.
+ internal ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence payload, bool isText, CancellationToken cancellationToken = default)
+ {
+ if (_messageInProgress)
+ {
+ ThrowMessageAlreadyStarted();
+ }
+
+ var result = DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, true, payload, cancellationToken);
+ _messageInProgress = false;
+
+ return result;
+ }
+
+ ///
+ /// Writes a message payload portion with the writer.
+ ///
+ /// The payload to write.
+ /// A cancellation token, if any.
+ public ValueTask WriteAsync(ReadOnlySequence payload, CancellationToken cancellationToken = default)
+ {
+ if (!_messageInProgress)
+ {
+ ThrowMessageNotStarted();
+ }
+
+ return DoWriteAsync(WebSocketOpcode.Continuation, false, payload, cancellationToken);
+ }
+
+ ///
+ /// Ends a message in progress.
+ ///
+ /// The payload to write.
+ /// A cancellation token, if any.
+ public ValueTask EndMessageAsync(ReadOnlySequence payload, CancellationToken cancellationToken = default)
+ {
+ if(!_messageInProgress)
+ {
+ ThrowMessageNotStarted();
+ }
+
+ var result = DoWriteAsync(WebSocketOpcode.Continuation, true, payload, cancellationToken);
+ _messageInProgress = false;
+
+ return result;
+ }
+
+ ///
+ /// Sends a message payload portion.
+ ///
+ /// The WebSocket opcode to send.
+ /// Whether or not this payload portion represents the end of the message.
+ /// The payload to send.
+ /// A cancellation token, if any.
+ private ValueTask DoWriteAsync(WebSocketOpcode opcode, bool endOfMessage, ReadOnlySequence payload, CancellationToken cancellationToken)
+ {
+ var masked = _protocolType == WebSocketProtocolType.Client;
+ var header = new WebSocketHeader(endOfMessage, opcode, masked, (ulong)payload.Length, masked ? WebSocketHeader.GenerateMaskingKey() : 0);
+
+ var frame = new WebSocketWriteFrame(header, payload);
+ var writer = new WebSocketFrameWriter();
+
+ writer.WriteMessage(frame, _transport);
+ var flushTask = _transport.FlushAsync(cancellationToken);
+ if (flushTask.IsCompletedSuccessfully)
+ {
+ var result = flushTask.Result;
+ if(result.IsCanceled)
+ {
+ ThrowMessageCanceled();
+ }
+
+ if (result.IsCompleted && !endOfMessage)
+ {
+ ThrowTransportClosed();
+ }
+
+ return new ValueTask();
+ }
+ else
+ {
+ return PerformFlushAsync(flushTask, endOfMessage);
+ }
+ }
+
+ ///
+ /// Performs a flush of the writer asynchronously.
+ ///
+ /// The active writer flush task.
+ /// Whether or not this flush will send an end-of-message.
+ ///
+ private async ValueTask PerformFlushAsync(ValueTask flushTask, bool endOfMessage)
+ {
+ var result = await flushTask.ConfigureAwait(false);
+ if (result.IsCanceled)
+ {
+ ThrowMessageCanceled();
+ }
+
+ if (result.IsCompleted && !endOfMessage)
+ {
+ ThrowTransportClosed();
+ }
+ }
+
+ ///
+ /// Throws that a message was canceled unexpectedly.
+ ///
+ private void ThrowMessageCanceled()
+ {
+ throw new OperationCanceledException("Flush was canceled while a write was still in progress.");
+ }
+
+ ///
+ /// Throws that the underlying transport closed unexpectedly.
+ ///
+ private void ThrowTransportClosed()
+ {
+ throw new InvalidOperationException("Transport closed unexpectedly while a message is still in progress.");
+ }
+
+ ///
+ /// Throws if a message has not yet been started.
+ ///
+ private void ThrowMessageNotStarted()
+ {
+ throw new InvalidOperationException("Cannot end a message if a message has not been started.");
+ }
+
+ ///
+ /// Throws if a message has already been started.
+ ///
+ private void ThrowMessageAlreadyStarted()
+ {
+ throw new InvalidOperationException("Cannot start a message when a message is already in progress.");
+ }
+ }
+}
diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs
index 9268b7e0..9add9c9d 100644
--- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs
+++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs
@@ -32,6 +32,11 @@ public class WebSocketProtocol : IControlFrameHandler
///
private WebSocketMessageReader _messageReader;
+ ///
+ /// The shared WebSocket message writer.
+ ///
+ private WebSocketMessageWriter _messageWriter;
+
///
/// The type of WebSocket protocol, server or client.
///
@@ -46,6 +51,7 @@ public WebSocketProtocol(ConnectionContext connection, WebSocketProtocolType pro
{
_transport = connection.Transport;
_messageReader = new WebSocketMessageReader(_transport.Input, this);
+ _messageWriter = new WebSocketMessageWriter(_transport.Output, protocolType);
_protocolType = protocolType;
}
@@ -85,23 +91,61 @@ private async ValueTask DoReadAsync(ValueTask readTas
/// The message to write.
/// True if the message is a text type message, false otherwise.
/// A cancellation token, if any.
- public async ValueTask WriteMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default)
+ /// The WebSocket message writer.
+ public ValueTask StartMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default)
{
if (IsClosed)
{
throw new InvalidOperationException("A close message was already received from the remote endpoint.");
}
- var opcode = isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary;
- var masked = _protocolType == WebSocketProtocolType.Client;
+ var startMessageTask = _messageWriter.StartMessageAsync(message, isText, cancellationToken);
+ if(startMessageTask.IsCompletedSuccessfully)
+ {
+ return new ValueTask(_messageWriter);
+ }
+ else
+ {
+ return DoStartMessageAsync(startMessageTask);
+ }
+ }
- var header = new WebSocketHeader(true, opcode, masked, (ulong)message.Length, WebSocketHeader.GenerateMaskingKey());
+ public ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default)
+ {
+ if (IsClosed)
+ {
+ throw new InvalidOperationException("A close message was already received from the remote endpoint.");
+ }
- var frame = new WebSocketWriteFrame(header, message);
- var writer = new WebSocketFrameWriter();
+ var writerTask = _messageWriter.WriteSingleFrameMessageAsync(message, isText, cancellationToken);
+ if (writerTask.IsCompletedSuccessfully)
+ {
+ return writerTask;
+ }
+ else
+ {
+ return DoWriteSingleFrameAsync(writerTask);
+ }
+ }
+
+ ///
+ /// Completes a start message task.
+ ///
+ /// The active start message task.
+ /// The WebSocket message writer.
+ private async ValueTask DoStartMessageAsync(ValueTask startMessageTask)
+ {
+ await startMessageTask;
+ return _messageWriter;
+ }
- writer.WriteMessage(frame, _transport.Output);
- await _transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false);
+ ///
+ /// Completes a single frame writer task.
+ ///
+ /// The active writer task.
+ private async ValueTask DoWriteSingleFrameAsync(ValueTask writeFrameTask)
+ {
+ await writeFrameTask;
}
///
diff --git a/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs b/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs
index 41276709..b918ccf3 100644
--- a/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs
+++ b/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs
@@ -15,6 +15,8 @@ namespace Bedrock.Framework.Tests.Protocols
{
public class WebSocketProtocolTests
{
+ private byte[] _buffer = new byte[4096];
+
[Fact]
public async Task SingleMessageWorks()
{
@@ -91,5 +93,36 @@ public async Task MessageWithMultipleFramesWorks()
Assert.Equal(payloadString, Encoding.UTF8.GetString(buffer.WrittenSpan));
}
+
+ [Fact]
+ public async Task WriteSingleMessageWorks()
+ {
+ var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
+ var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);
+
+ var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
+ var payloadString = "This is a test payload.";
+ await protocol.WriteSingleFrameMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString))), false, default);
+
+ var result = await webSocket.ReceiveAsync(new ArraySegment(_buffer), default);
+ Assert.Equal(payloadString, Encoding.UTF8.GetString(_buffer, 0, result.Count));
+ }
+
+ [Fact]
+ public async Task WriteMultipleFramesWorks()
+ {
+ var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
+ var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);
+
+ var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
+ var payloadString = "This is a test payload.";
+ var writer = await protocol.StartMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString))), false, default);
+ await writer.EndMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString))));
+
+ var result = await webSocket.ReceiveAsync(new ArraySegment(_buffer, 0, _buffer.Length), default);
+ result = await webSocket.ReceiveAsync(new ArraySegment(_buffer, result.Count, _buffer.Length - result.Count), default);
+
+ Assert.Equal($"{payloadString}{payloadString}", Encoding.UTF8.GetString(_buffer, 0, result.Count * 2));
+ }
}
}