diff --git a/crates/test-programs/src/bin/preview2_tcp_streams.rs b/crates/test-programs/src/bin/preview2_tcp_streams.rs index c756958968d4..712b234bbfaf 100644 --- a/crates/test-programs/src/bin/preview2_tcp_streams.rs +++ b/crates/test-programs/src/bin/preview2_tcp_streams.rs @@ -109,6 +109,29 @@ fn test_tcp_shutdown_should_not_lose_data(net: &Network, family: IpAddressFamily }); } +/// InputStream::subscribe should not wake up if there is no data to read. +fn test_tcp_input_stream_should_not_wake_on_empty_data(net: &Network, family: IpAddressFamily) { + setup(net, family, |server, client| { + use test_programs::wasi::clocks::monotonic_clock::subscribe_duration; + let timeout_100ms = 100_000_000; + + // Send some data to the server + client.output.blocking_write_and_flush(b"Hi!").unwrap(); + + server.input.subscribe().block(); + let res = server.input.read(512).unwrap(); + assert_eq!(res, b"Hi!", "Expected to receive data"); + + // Don't send any data + + let res = server + .input + .subscribe() + .block_until(&subscribe_duration(timeout_100ms)); + assert!(res.is_err(), "Expected to time out cause no data was sent"); + }); +} + fn main() { let net = Network::default(); @@ -123,6 +146,9 @@ fn main() { test_tcp_shutdown_should_not_lose_data(&net, IpAddressFamily::Ipv4); test_tcp_shutdown_should_not_lose_data(&net, IpAddressFamily::Ipv6); + + test_tcp_input_stream_should_not_wake_on_empty_data(&net, IpAddressFamily::Ipv4); + test_tcp_input_stream_should_not_wake_on_empty_data(&net, IpAddressFamily::Ipv6); } struct Connection { diff --git a/crates/wasi/src/bindings.rs b/crates/wasi/src/bindings.rs index cefeec980ce3..452edbe9d4f5 100644 --- a/crates/wasi/src/bindings.rs +++ b/crates/wasi/src/bindings.rs @@ -385,6 +385,7 @@ mod async_io { "[method]pollable.ready", "[method]tcp-socket.start-bind", "[method]tcp-socket.start-connect", + "[method]tcp-socket.finish-connect", "[method]udp-socket.start-bind", "[method]udp-socket.stream", "[method]outgoing-datagram-stream.send", diff --git a/crates/wasi/src/host/tcp.rs b/crates/wasi/src/host/tcp.rs index 160cb8a757c0..f93270a37493 100644 --- a/crates/wasi/src/host/tcp.rs +++ b/crates/wasi/src/host/tcp.rs @@ -69,14 +69,14 @@ where Ok(()) } - fn finish_connect( + async fn finish_connect( &mut self, this: Resource, ) -> SocketResult<(Resource, Resource)> { let table = self.table(); let socket = table.get_mut(&this)?; - let (input, output) = socket.finish_connect()?; + let (input, output) = socket.finish_connect().await?; let input_stream = self.table().push_child(input, &this)?; let output_stream = self.table().push_child(output, &this)?; @@ -366,7 +366,7 @@ pub mod sync { &mut self, self_: Resource, ) -> Result<(Resource, Resource), SocketError> { - AsyncHostTcpSocket::finish_connect(self, self_) + in_tokio(async { AsyncHostTcpSocket::finish_connect(self, self_).await }) } fn start_listen(&mut self, self_: Resource) -> Result<(), SocketError> { diff --git a/crates/wasi/src/tcp.rs b/crates/wasi/src/tcp.rs index 5b05ebb25339..62cf27f3c028 100644 --- a/crates/wasi/src/tcp.rs +++ b/crates/wasi/src/tcp.rs @@ -263,20 +263,11 @@ impl TcpSocket { Ok(()) } - pub fn finish_connect(&mut self) -> SocketResult<(InputStream, OutputStream)> { + pub async fn finish_connect(&mut self) -> SocketResult<(InputStream, OutputStream)> { let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed); let result = match previous_state { TcpState::ConnectReady(result) => result, - TcpState::Connecting(mut future) => { - let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); - match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) { - Poll::Ready(result) => result, - Poll::Pending => { - self.tcp_state = TcpState::Connecting(future); - return Err(ErrorCode::WouldBlock.into()); - } - } - } + TcpState::Connecting(future) => future.await, previous_state => { self.tcp_state = previous_state; return Err(ErrorCode::NotInProgress.into()); @@ -360,7 +351,7 @@ impl TcpSocket { } } - pub fn accept(&mut self) -> SocketResult<(Self, InputStream, OutputStream)> { + pub async fn accept(&mut self) -> SocketResult<(Self, InputStream, OutputStream)> { let TcpState::Listening { listener, pending_accept,