diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs new file mode 100644 index 00000000..96576fa2 --- /dev/null +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs @@ -0,0 +1,207 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Bedrock.Framework.Protocols.WebSockets +{ + /// + /// A writer-like construct for writing WebSocket messages. + /// + public class WebSocketMessageWriter + { + /// + /// True if a message is in progress, false otherwise. + /// + internal bool _messageInProgress; + + /// + /// The current protocol type for this writer, client or server. + /// + private WebSocketProtocolType _protocolType; + + /// + /// The transport to write to. + /// + private PipeWriter _transport; + + /// + /// True if the current message is text, false otherwise. + /// + public bool _isText; + + /// + /// Creates an instance of a WebSocketMessageWriter. + /// + /// The transport to write to. + /// The protocol type for this writer. + public WebSocketMessageWriter(PipeWriter transport, WebSocketProtocolType protocolType) + { + _transport = transport; + _protocolType = protocolType; + } + + /// + /// Starts a message in the writer. + /// + /// The payload to write. + /// Whether the payload is text or not. + /// A cancellation token, if any. + internal ValueTask StartMessageAsync(ReadOnlySequence payload, bool isText, CancellationToken cancellationToken = default) + { + if(_messageInProgress) + { + ThrowMessageAlreadyStarted(); + } + + _messageInProgress = true; + _isText = isText; + return DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, false, payload, cancellationToken); + } + + /// + /// Writes a single frame message with the writer. + /// + /// The payload to write. + /// Whether the payload is text or not. + /// A cancellation token, if any. + internal ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence payload, bool isText, CancellationToken cancellationToken = default) + { + if (_messageInProgress) + { + ThrowMessageAlreadyStarted(); + } + + var result = DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, true, payload, cancellationToken); + _messageInProgress = false; + + return result; + } + + /// + /// Writes a message payload portion with the writer. + /// + /// The payload to write. + /// A cancellation token, if any. + public ValueTask WriteAsync(ReadOnlySequence payload, CancellationToken cancellationToken = default) + { + if (!_messageInProgress) + { + ThrowMessageNotStarted(); + } + + return DoWriteAsync(WebSocketOpcode.Continuation, false, payload, cancellationToken); + } + + /// + /// Ends a message in progress. + /// + /// The payload to write. + /// A cancellation token, if any. + public ValueTask EndMessageAsync(ReadOnlySequence payload, CancellationToken cancellationToken = default) + { + if(!_messageInProgress) + { + ThrowMessageNotStarted(); + } + + var result = DoWriteAsync(WebSocketOpcode.Continuation, true, payload, cancellationToken); + _messageInProgress = false; + + return result; + } + + /// + /// Sends a message payload portion. + /// + /// The WebSocket opcode to send. + /// Whether or not this payload portion represents the end of the message. + /// The payload to send. + /// A cancellation token, if any. + private ValueTask DoWriteAsync(WebSocketOpcode opcode, bool endOfMessage, ReadOnlySequence payload, CancellationToken cancellationToken) + { + var masked = _protocolType == WebSocketProtocolType.Client; + var header = new WebSocketHeader(endOfMessage, opcode, masked, (ulong)payload.Length, masked ? WebSocketHeader.GenerateMaskingKey() : 0); + + var frame = new WebSocketWriteFrame(header, payload); + var writer = new WebSocketFrameWriter(); + + writer.WriteMessage(frame, _transport); + var flushTask = _transport.FlushAsync(cancellationToken); + if (flushTask.IsCompletedSuccessfully) + { + var result = flushTask.Result; + if(result.IsCanceled) + { + ThrowMessageCanceled(); + } + + if (result.IsCompleted && !endOfMessage) + { + ThrowTransportClosed(); + } + + return new ValueTask(); + } + else + { + return PerformFlushAsync(flushTask, endOfMessage); + } + } + + /// + /// Performs a flush of the writer asynchronously. + /// + /// The active writer flush task. + /// Whether or not this flush will send an end-of-message. + /// + private async ValueTask PerformFlushAsync(ValueTask flushTask, bool endOfMessage) + { + var result = await flushTask.ConfigureAwait(false); + if (result.IsCanceled) + { + ThrowMessageCanceled(); + } + + if (result.IsCompleted && !endOfMessage) + { + ThrowTransportClosed(); + } + } + + /// + /// Throws that a message was canceled unexpectedly. + /// + private void ThrowMessageCanceled() + { + throw new OperationCanceledException("Flush was canceled while a write was still in progress."); + } + + /// + /// Throws that the underlying transport closed unexpectedly. + /// + private void ThrowTransportClosed() + { + throw new InvalidOperationException("Transport closed unexpectedly while a message is still in progress."); + } + + /// + /// Throws if a message has not yet been started. + /// + private void ThrowMessageNotStarted() + { + throw new InvalidOperationException("Cannot end a message if a message has not been started."); + } + + /// + /// Throws if a message has already been started. + /// + private void ThrowMessageAlreadyStarted() + { + throw new InvalidOperationException("Cannot start a message when a message is already in progress."); + } + } +} diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs index 9268b7e0..9add9c9d 100644 --- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs @@ -32,6 +32,11 @@ public class WebSocketProtocol : IControlFrameHandler /// private WebSocketMessageReader _messageReader; + /// + /// The shared WebSocket message writer. + /// + private WebSocketMessageWriter _messageWriter; + /// /// The type of WebSocket protocol, server or client. /// @@ -46,6 +51,7 @@ public WebSocketProtocol(ConnectionContext connection, WebSocketProtocolType pro { _transport = connection.Transport; _messageReader = new WebSocketMessageReader(_transport.Input, this); + _messageWriter = new WebSocketMessageWriter(_transport.Output, protocolType); _protocolType = protocolType; } @@ -85,23 +91,61 @@ private async ValueTask DoReadAsync(ValueTask readTas /// The message to write. /// True if the message is a text type message, false otherwise. /// A cancellation token, if any. - public async ValueTask WriteMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default) + /// The WebSocket message writer. + public ValueTask StartMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default) { if (IsClosed) { throw new InvalidOperationException("A close message was already received from the remote endpoint."); } - var opcode = isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary; - var masked = _protocolType == WebSocketProtocolType.Client; + var startMessageTask = _messageWriter.StartMessageAsync(message, isText, cancellationToken); + if(startMessageTask.IsCompletedSuccessfully) + { + return new ValueTask(_messageWriter); + } + else + { + return DoStartMessageAsync(startMessageTask); + } + } - var header = new WebSocketHeader(true, opcode, masked, (ulong)message.Length, WebSocketHeader.GenerateMaskingKey()); + public ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence message, bool isText, CancellationToken cancellationToken = default) + { + if (IsClosed) + { + throw new InvalidOperationException("A close message was already received from the remote endpoint."); + } - var frame = new WebSocketWriteFrame(header, message); - var writer = new WebSocketFrameWriter(); + var writerTask = _messageWriter.WriteSingleFrameMessageAsync(message, isText, cancellationToken); + if (writerTask.IsCompletedSuccessfully) + { + return writerTask; + } + else + { + return DoWriteSingleFrameAsync(writerTask); + } + } + + /// + /// Completes a start message task. + /// + /// The active start message task. + /// The WebSocket message writer. + private async ValueTask DoStartMessageAsync(ValueTask startMessageTask) + { + await startMessageTask; + return _messageWriter; + } - writer.WriteMessage(frame, _transport.Output); - await _transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + /// + /// Completes a single frame writer task. + /// + /// The active writer task. + private async ValueTask DoWriteSingleFrameAsync(ValueTask writeFrameTask) + { + await writeFrameTask; } /// diff --git a/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs b/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs index 41276709..b918ccf3 100644 --- a/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs +++ b/tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs @@ -15,6 +15,8 @@ namespace Bedrock.Framework.Tests.Protocols { public class WebSocketProtocolTests { + private byte[] _buffer = new byte[4096]; + [Fact] public async Task SingleMessageWorks() { @@ -91,5 +93,36 @@ public async Task MessageWithMultipleFramesWorks() Assert.Equal(payloadString, Encoding.UTF8.GetString(buffer.WrittenSpan)); } + + [Fact] + public async Task WriteSingleMessageWorks() + { + var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false)); + var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server); + + var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30)); + var payloadString = "This is a test payload."; + await protocol.WriteSingleFrameMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString))), false, default); + + var result = await webSocket.ReceiveAsync(new ArraySegment(_buffer), default); + Assert.Equal(payloadString, Encoding.UTF8.GetString(_buffer, 0, result.Count)); + } + + [Fact] + public async Task WriteMultipleFramesWorks() + { + var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false)); + var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server); + + var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30)); + var payloadString = "This is a test payload."; + var writer = await protocol.StartMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString))), false, default); + await writer.EndMessageAsync(new ReadOnlySequence(new ReadOnlyMemory(Encoding.UTF8.GetBytes(payloadString)))); + + var result = await webSocket.ReceiveAsync(new ArraySegment(_buffer, 0, _buffer.Length), default); + result = await webSocket.ReceiveAsync(new ArraySegment(_buffer, result.Count, _buffer.Length - result.Count), default); + + Assert.Equal($"{payloadString}{payloadString}", Encoding.UTF8.GetString(_buffer, 0, result.Count * 2)); + } } }