Skip to content

Commit

Permalink
Added initial write-side API. Added write smoke tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnischan committed Feb 17, 2020
1 parent e614dd2 commit 5e64028
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 8 deletions.
207 changes: 207 additions & 0 deletions src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageWriter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Bedrock.Framework.Protocols.WebSockets
{
/// <summary>
/// A writer-like construct for writing WebSocket messages.
/// </summary>
public class WebSocketMessageWriter
{
/// <summary>
/// True if a message is in progress, false otherwise.
/// </summary>
internal bool _messageInProgress;

/// <summary>
/// The current protocol type for this writer, client or server.
/// </summary>
private WebSocketProtocolType _protocolType;

/// <summary>
/// The transport to write to.
/// </summary>
private PipeWriter _transport;

/// <summary>
/// True if the current message is text, false otherwise.
/// </summary>
public bool _isText;

/// <summary>
/// Creates an instance of a WebSocketMessageWriter.
/// </summary>
/// <param name="transport">The transport to write to.</param>
/// <param name="protocolType">The protocol type for this writer.</param>
public WebSocketMessageWriter(PipeWriter transport, WebSocketProtocolType protocolType)
{
_transport = transport;
_protocolType = protocolType;
}

/// <summary>
/// Starts a message in the writer.
/// </summary>
/// <param name="payload">The payload to write.</param>
/// <param name="isText">Whether the payload is text or not.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
internal ValueTask StartMessageAsync(ReadOnlySequence<byte> payload, bool isText, CancellationToken cancellationToken = default)
{
if(_messageInProgress)
{
ThrowMessageAlreadyStarted();
}

_messageInProgress = true;
_isText = isText;
return DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, false, payload, cancellationToken);
}

/// <summary>
/// Writes a single frame message with the writer.
/// </summary>
/// <param name="payload">The payload to write.</param>
/// <param name="isText">Whether the payload is text or not.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
internal ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence<byte> payload, bool isText, CancellationToken cancellationToken = default)
{
if (_messageInProgress)
{
ThrowMessageAlreadyStarted();
}

var result = DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, true, payload, cancellationToken);
_messageInProgress = false;

return result;
}

/// <summary>
/// Writes a message payload portion with the writer.
/// </summary>
/// <param name="payload">The payload to write.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
public ValueTask WriteAsync(ReadOnlySequence<byte> payload, CancellationToken cancellationToken = default)
{
if (!_messageInProgress)
{
ThrowMessageNotStarted();
}

return DoWriteAsync(WebSocketOpcode.Continuation, false, payload, cancellationToken);
}

/// <summary>
/// Ends a message in progress.
/// </summary>
/// <param name="payload">The payload to write.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
public ValueTask EndMessageAsync(ReadOnlySequence<byte> payload, CancellationToken cancellationToken = default)
{
if(!_messageInProgress)
{
ThrowMessageNotStarted();
}

var result = DoWriteAsync(WebSocketOpcode.Continuation, true, payload, cancellationToken);
_messageInProgress = false;

return result;
}

/// <summary>
/// Sends a message payload portion.
/// </summary>
/// <param name="opcode">The WebSocket opcode to send.</param>
/// <param name="endOfMessage">Whether or not this payload portion represents the end of the message.</param>
/// <param name="payload">The payload to send.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
private ValueTask DoWriteAsync(WebSocketOpcode opcode, bool endOfMessage, ReadOnlySequence<byte> payload, CancellationToken cancellationToken)
{
var masked = _protocolType == WebSocketProtocolType.Client;
var header = new WebSocketHeader(endOfMessage, opcode, masked, (ulong)payload.Length, masked ? WebSocketHeader.GenerateMaskingKey() : 0);

var frame = new WebSocketWriteFrame(header, payload);
var writer = new WebSocketFrameWriter();

writer.WriteMessage(frame, _transport);
var flushTask = _transport.FlushAsync(cancellationToken);
if (flushTask.IsCompletedSuccessfully)
{
var result = flushTask.Result;
if(result.IsCanceled)
{
ThrowMessageCanceled();
}

if (result.IsCompleted && !endOfMessage)
{
ThrowTransportClosed();
}

return new ValueTask();
}
else
{
return PerformFlushAsync(flushTask, endOfMessage);
}
}

/// <summary>
/// Performs a flush of the writer asynchronously.
/// </summary>
/// <param name="flushTask">The active writer flush task.</param>
/// <param name="endOfMessage">Whether or not this flush will send an end-of-message.</param>
/// <returns></returns>
private async ValueTask PerformFlushAsync(ValueTask<FlushResult> flushTask, bool endOfMessage)
{
var result = await flushTask.ConfigureAwait(false);
if (result.IsCanceled)
{
ThrowMessageCanceled();
}

if (result.IsCompleted && !endOfMessage)
{
ThrowTransportClosed();
}
}

/// <summary>
/// Throws that a message was canceled unexpectedly.
/// </summary>
private void ThrowMessageCanceled()
{
throw new OperationCanceledException("Flush was canceled while a write was still in progress.");
}

/// <summary>
/// Throws that the underlying transport closed unexpectedly.
/// </summary>
private void ThrowTransportClosed()
{
throw new InvalidOperationException("Transport closed unexpectedly while a message is still in progress.");
}

/// <summary>
/// Throws if a message has not yet been started.
/// </summary>
private void ThrowMessageNotStarted()
{
throw new InvalidOperationException("Cannot end a message if a message has not been started.");
}

/// <summary>
/// Throws if a message has already been started.
/// </summary>
private void ThrowMessageAlreadyStarted()
{
throw new InvalidOperationException("Cannot start a message when a message is already in progress.");
}
}
}
60 changes: 52 additions & 8 deletions src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public class WebSocketProtocol : IControlFrameHandler
/// </summary>
private WebSocketMessageReader _messageReader;

/// <summary>
/// The shared WebSocket message writer.
/// </summary>
private WebSocketMessageWriter _messageWriter;

/// <summary>
/// The type of WebSocket protocol, server or client.
/// </summary>
Expand All @@ -46,6 +51,7 @@ public WebSocketProtocol(ConnectionContext connection, WebSocketProtocolType pro
{
_transport = connection.Transport;
_messageReader = new WebSocketMessageReader(_transport.Input, this);
_messageWriter = new WebSocketMessageWriter(_transport.Output, protocolType);
_protocolType = protocolType;
}

Expand Down Expand Up @@ -85,23 +91,61 @@ private async ValueTask<WebSocketReadResult> DoReadAsync(ValueTask<bool> readTas
/// <param name="message">The message to write.</param>
/// <param name="isText">True if the message is a text type message, false otherwise.</param>
/// <param name="cancellationToken">A cancellation token, if any.</param>
public async ValueTask WriteMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
/// <returns>The WebSocket message writer.</returns>
public ValueTask<WebSocketMessageWriter> StartMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
{
if (IsClosed)
{
throw new InvalidOperationException("A close message was already received from the remote endpoint.");
}

var opcode = isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary;
var masked = _protocolType == WebSocketProtocolType.Client;
var startMessageTask = _messageWriter.StartMessageAsync(message, isText, cancellationToken);
if(startMessageTask.IsCompletedSuccessfully)
{
return new ValueTask<WebSocketMessageWriter>(_messageWriter);
}
else
{
return DoStartMessageAsync(startMessageTask);
}
}

var header = new WebSocketHeader(true, opcode, masked, (ulong)message.Length, WebSocketHeader.GenerateMaskingKey());
public ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
{
if (IsClosed)
{
throw new InvalidOperationException("A close message was already received from the remote endpoint.");
}

var frame = new WebSocketWriteFrame(header, message);
var writer = new WebSocketFrameWriter();
var writerTask = _messageWriter.WriteSingleFrameMessageAsync(message, isText, cancellationToken);
if (writerTask.IsCompletedSuccessfully)
{
return writerTask;
}
else
{
return DoWriteSingleFrameAsync(writerTask);
}
}

/// <summary>
/// Completes a start message task.
/// </summary>
/// <param name="startMessageTask">The active start message task.</param>
/// <returns>The WebSocket message writer.</returns>
private async ValueTask<WebSocketMessageWriter> DoStartMessageAsync(ValueTask startMessageTask)
{
await startMessageTask;
return _messageWriter;
}

writer.WriteMessage(frame, _transport.Output);
await _transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false);
/// <summary>
/// Completes a single frame writer task.
/// </summary>
/// <param name="writeFrameTask">The active writer task.</param>
private async ValueTask DoWriteSingleFrameAsync(ValueTask writeFrameTask)
{
await writeFrameTask;
}

/// <summary>
Expand Down
33 changes: 33 additions & 0 deletions tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace Bedrock.Framework.Tests.Protocols
{
public class WebSocketProtocolTests
{
private byte[] _buffer = new byte[4096];

[Fact]
public async Task SingleMessageWorks()
{
Expand Down Expand Up @@ -91,5 +93,36 @@ public async Task MessageWithMultipleFramesWorks()

Assert.Equal(payloadString, Encoding.UTF8.GetString(buffer.WrittenSpan));
}

[Fact]
public async Task WriteSingleMessageWorks()
{
var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);

var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
var payloadString = "This is a test payload.";
await protocol.WriteSingleFrameMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))), false, default);

var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer), default);
Assert.Equal(payloadString, Encoding.UTF8.GetString(_buffer, 0, result.Count));
}

[Fact]
public async Task WriteMultipleFramesWorks()
{
var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);

var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
var payloadString = "This is a test payload.";
var writer = await protocol.StartMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))), false, default);
await writer.EndMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))));

var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer, 0, _buffer.Length), default);
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer, result.Count, _buffer.Length - result.Count), default);

Assert.Equal($"{payloadString}{payloadString}", Encoding.UTF8.GetString(_buffer, 0, result.Count * 2));
}
}
}

0 comments on commit 5e64028

Please sign in to comment.