diff --git a/Microsoft.Azure.Amqp.Uwp/Amqp/Transport/WebSocketTransportInitiator.UWP.cs b/Microsoft.Azure.Amqp.Uwp/Amqp/Transport/WebSocketTransportInitiator.UWP.cs index 7699f8f6..7c850038 100644 --- a/Microsoft.Azure.Amqp.Uwp/Amqp/Transport/WebSocketTransportInitiator.UWP.cs +++ b/Microsoft.Azure.Amqp.Uwp/Amqp/Transport/WebSocketTransportInitiator.UWP.cs @@ -4,6 +4,8 @@ namespace Microsoft.Azure.Amqp.Transport { using System; + using System.Threading; + using System.Threading.Tasks; using Windows.Networking.Sockets; sealed class WebSocketTransportInitiator : TransportInitiator @@ -20,31 +22,39 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c StreamWebSocket sws = new StreamWebSocket(); sws.Control.SupportedProtocols.Add(this.settings.SubProtocol); - var task = sws.ConnectAsync(this.settings.Uri).AsTask().WithTimeout(timeout, () => "timeout"); + var cts = new CancellationTokenSource(timeout); + var task = sws.ConnectAsync(this.settings.Uri).AsTask(cts.Token); if (task.IsCompleted) { - callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri); + this.OnConnect(callbackArgs, task, sws, cts); return false; } task.ContinueWith(t => { - if (t.IsFaulted) - { - callbackArgs.Exception = t.Exception.InnerException; - } - else if (t.IsCanceled) - { - callbackArgs.Exception = new OperationCanceledException(); - } - else - { - callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri); - } - + this.OnConnect(callbackArgs, t, sws, cts); callbackArgs.CompletedCallback(callbackArgs); }); return true; } + + void OnConnect(TransportAsyncCallbackArgs callbackArgs, Task t, StreamWebSocket sws, CancellationTokenSource cts) + { + cts.Dispose(); + if (t.IsFaulted) + { + sws.Dispose(); + callbackArgs.Exception = t.Exception.InnerException; + } + else if (t.IsCanceled) + { + sws.Dispose(); + callbackArgs.Exception = new OperationCanceledException(); + } + else + { + callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri); + } + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransport.cs b/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransport.cs index 46d7744c..53724d96 100644 --- a/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransport.cs +++ b/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransport.cs @@ -72,25 +72,13 @@ public sealed override bool WriteAsync(TransportAsyncCallbackArgs args) Task task = this.webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, CancellationToken.None); if (task.IsCompleted) { - this.OnWriteComplete(args, buffer, mergedBuffer, startTime); + this.OnWriteComplete(args, task, buffer, mergedBuffer, startTime); return false; } task.ContinueWith(t => { - if (t.IsFaulted) - { - args.Exception = t.Exception.InnerException; - } - else if (t.IsCanceled) - { - args.Exception = new OperationCanceledException(); - } - else - { - this.OnWriteComplete(args, buffer, mergedBuffer, startTime); - } - + this.OnWriteComplete(args, t, buffer, mergedBuffer, startTime); args.CompletedCallback(args); }); return true; @@ -103,25 +91,13 @@ public sealed override bool ReadAsync(TransportAsyncCallbackArgs args) Task task = this.webSocket.ReceiveAsync(buffer, CancellationToken.None); if (task.IsCompleted) { - this.OnReadComplete(args, task.Result.Count, startTime); + this.OnReadComplete(args, task, startTime); return false; } task.ContinueWith(t => { - if (t.IsFaulted) - { - args.Exception = t.Exception.InnerException; - } - else if (t.IsCanceled) - { - args.Exception = new OperationCanceledException(); - } - else - { - this.OnReadComplete(args, t.Result.Count, startTime); - } - + this.OnReadComplete(args, t, startTime); args.CompletedCallback(args); }); return true; @@ -137,6 +113,15 @@ protected override bool CloseInternal() Task task = webSocket.CloseAsync(WebSocketCloseStatus.Empty, string.Empty, CancellationToken.None); if (task.IsCompleted) { + if (task.IsFaulted) + { + ExceptionDispatcher.Throw(task.Exception.InnerException); + } + else if (task.IsCanceled) + { + throw new OperationCanceledException(); + } + return true; } @@ -165,26 +150,48 @@ internal static bool MatchScheme(string scheme) string.Equals(scheme, WebSocketTransportSettings.SecureWebSockets, StringComparison.OrdinalIgnoreCase); } - void OnWriteComplete(TransportAsyncCallbackArgs args, ArraySegment buffer, ByteBuffer byteBuffer, DateTime startTime) + void OnWriteComplete(TransportAsyncCallbackArgs args, Task t, ArraySegment buffer, ByteBuffer byteBuffer, DateTime startTime) { - args.BytesTransfered = buffer.Count; - if (byteBuffer != null) + if (t.IsFaulted) { - byteBuffer.Dispose(); + args.Exception = t.Exception.InnerException; + } + else if (t.IsCanceled) + { + args.Exception = new OperationCanceledException(); + } + else + { + args.BytesTransfered = buffer.Count; + if (this.usageMeter != null) + { + this.usageMeter.OnTransportWrite(0, buffer.Count, 0, DateTime.UtcNow.Subtract(startTime).Ticks); + } } - if (this.usageMeter != null) + if (byteBuffer != null) { - this.usageMeter.OnTransportWrite(0, buffer.Count, 0, DateTime.UtcNow.Subtract(startTime).Ticks); + byteBuffer.Dispose(); } } - void OnReadComplete(TransportAsyncCallbackArgs args, int count, DateTime startTime) + void OnReadComplete(TransportAsyncCallbackArgs args, Task t, DateTime startTime) { - args.BytesTransfered = count; - if (this.usageMeter != null) + if (t.IsFaulted) { - this.usageMeter.OnTransportRead(0, count, 0, DateTime.UtcNow.Subtract(startTime).Ticks); + args.Exception = t.Exception.InnerException; + } + else if (t.IsCanceled) + { + args.Exception = new OperationCanceledException(); + } + else + { + args.BytesTransfered = t.Result.Count; + if (this.usageMeter != null) + { + this.usageMeter.OnTransportRead(0, args.BytesTransfered, 0, DateTime.UtcNow.Subtract(startTime).Ticks); + } } } } diff --git a/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransportInitiator.cs b/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransportInitiator.cs index 12273a97..fdaebd97 100644 --- a/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransportInitiator.cs +++ b/Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransportInitiator.cs @@ -29,33 +29,39 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c cws.Options.SetBuffer(this.settings.ReceiveBufferSize, this.settings.SendBufferSize); #endif - Task task = cws.ConnectAsync(this.settings.Uri, CancellationToken.None).WithTimeout(timeout, () => "Client WebSocket connect timed out"); + var cts = new CancellationTokenSource(timeout); + Task task = cws.ConnectAsync(this.settings.Uri, cts.Token); if (task.IsCompleted) { - callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri); + this.OnConnect(callbackArgs, task, cws, cts); return false; } task.ContinueWith(t => { - if (t.IsFaulted) - { - cws.Abort(); - callbackArgs.Exception = t.Exception.InnerException; - } - else if (t.IsCanceled) - { - cws.Abort(); - callbackArgs.Exception = new OperationCanceledException(); - } - else - { - callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri); - } - + this.OnConnect(callbackArgs, t, cws, cts); callbackArgs.CompletedCallback(callbackArgs); }); return true; } + + void OnConnect(TransportAsyncCallbackArgs callbackArgs, Task t, ClientWebSocket cws, CancellationTokenSource cts) + { + cts.Dispose(); + if (t.IsFaulted) + { + cws.Dispose(); + callbackArgs.Exception = t.Exception.InnerException; + } + else if (t.IsCanceled) + { + cws.Dispose(); + callbackArgs.Exception = new OperationCanceledException(); + } + else + { + callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri); + } + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Amqp/TaskHelpers.cs b/Microsoft.Azure.Amqp/TaskHelpers.cs index 11a36a4f..c4b20b4a 100644 --- a/Microsoft.Azure.Amqp/TaskHelpers.cs +++ b/Microsoft.Azure.Amqp/TaskHelpers.cs @@ -5,7 +5,6 @@ namespace Microsoft.Azure.Amqp { using System; using System.Runtime.InteropServices; - using System.Threading; using System.Threading.Tasks; static class TaskHelpers @@ -58,17 +57,6 @@ public static Task CreateTask(Func be return retval; } - public static void Fork(this Task thisTask) - { - Fork(thisTask, "TaskExtensions.Fork"); - } - - public static void Fork(this Task thisTask, string tracingInfo) - { - Fx.Assert(thisTask != null, "task is required!"); - thisTask.ContinueWith(t => AmqpTrace.Provider.AmqpHandleException(t.Exception, tracingInfo), TaskContinuationOptions.OnlyOnFaulted); - } - public static IAsyncResult ToAsyncResult(this Task task, AsyncCallback callback, object state) { if (task.AsyncState == state) @@ -167,54 +155,6 @@ public static TResult EndAsyncResult(IAsyncResult asyncResult) return task.GetAwaiter().GetResult(); } - public static Task WithTimeout(this Task task, TimeSpan timeout, Func errorMessage) - { - return WithTimeout(task, timeout, errorMessage, CancellationToken.None); - } - - public static async Task WithTimeout(this Task task, TimeSpan timeout, Func errorMessage, CancellationToken token) - { - if (timeout == TimeSpan.MaxValue) - { - timeout = Timeout.InfiniteTimeSpan; - } - else if (timeout.TotalMilliseconds > Int32.MaxValue) - { - timeout = TimeSpan.FromMilliseconds(Int32.MaxValue); - } - - if (task.IsCompleted || (timeout == Timeout.InfiniteTimeSpan && token == CancellationToken.None)) - { - await task.ConfigureAwait(false); - return; - } - - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(token)) - { - if (task == await Task.WhenAny(task, CreateDelayTask(timeout, cts.Token)).ConfigureAwait(false)) - { - cts.Cancel(); - await task.ConfigureAwait(false); - return; - } - } - - throw new TimeoutException(errorMessage()); - } - - static async Task CreateDelayTask(TimeSpan timeout, CancellationToken token) - { - try - { - await Task.Delay(timeout, token).ConfigureAwait(false); - } - catch (TaskCanceledException) - { - // No need to throw. Caller is responsible for detecting - // which task completed and throwing appropriate Timeout Exception - } - } - [StructLayout(LayoutKind.Sequential, Size = 1)] internal struct VoidTaskResult {