Skip to content

Commit

Permalink
fix(apple): properly detect output timeout across stdout and stderr
Browse files Browse the repository at this point in the history
  • Loading branch information
Malinskiy committed Feb 26, 2024
1 parent fee8c92 commit 7a75392
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@ import kotlinx.coroutines.CompletableJob
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.cancel
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.onClosed
import kotlinx.coroutines.channels.onFailure
import kotlinx.coroutines.channels.onSuccess
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.supervisorScope
import kotlinx.coroutines.withContext
import kotlin.coroutines.cancellation.CancellationException

abstract class BaseCommand(
override val stdout: ReceiveChannel<String>,
Expand All @@ -19,38 +25,50 @@ abstract class BaseCommand(
override suspend fun await(): CommandResult = withContext(Dispatchers.IO) {
val deferredStdout = supervisorScope {
async(job) {
val stdoutBuffer = mutableListOf<String>()
for (line in stdout) {
stdoutBuffer.add(line)
val buffer = mutableListOf<String>()
while (true) {
val channelResult = stdout.receiveCatching()
channelResult.onSuccess { buffer.add(it) }
channelResult.onClosed { if (it != null) cancel(CancellationException("Channel closed", it)) }
channelResult.onFailure { if (it != null) cancel(CancellationException("Channel failed", it)) }

if (!channelResult.isSuccess) break
}
stdoutBuffer
buffer
}
}

val deferredStderr = supervisorScope {
async(job) {
val stderrBuffer = mutableListOf<String>()
for (line in stderr) {
stderrBuffer.add(line)
val buffer = mutableListOf<String>()
while (true) {
val channelResult = stderr.receiveCatching()
channelResult.onSuccess { buffer.add(it) }
channelResult.onClosed { if (it != null) cancel(CancellationException("Channel closed", it)) }
channelResult.onFailure { if (it != null) cancel(CancellationException("Channel failed", it)) }

if (!channelResult.isSuccess) break
}
stderrBuffer
buffer
}
}

val out = deferredStdout.await()
val err = deferredStderr.await()
val exitCode = exitCode.await()

CommandResult(out, err, exitCode)

val (out, err, exitCode) = awaitAll(deferredStdout, deferredStderr, exitCode)

CommandResult(out as List<String>, err as List<String>, exitCode as Int)
}

override suspend fun drain() {
return supervisorScope {
async(job) {
for (line in stdout) {}
for (line in stdout) {
}
}
async(job) {
for (line in stderr) {}
for (line in stderr) {
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlinx.coroutines.delay
import java.io.File
import java.nio.charset.Charset
import java.time.Duration
import java.util.concurrent.atomic.AtomicLong

/**
* Note: doesn't support idle timeout currently
Expand Down Expand Up @@ -64,11 +65,11 @@ class KotlinProcessCommandExecutor(
process.suspendFor()
}
}

val stdout = produceLinesManually(job, process.inputStream, idleTimeout, charset, channelCapacity) { process.isAlive && !exitCode.isCompleted }
val stderr = produceLinesManually(job, process.errorStream, idleTimeout, charset, channelCapacity) { process.isAlive && !exitCode.isCompleted }


val lastOutputTimeMillis = AtomicLong(System.currentTimeMillis())
val stdout = produceLinesManually(job, process.inputStream, lastOutputTimeMillis, idleTimeout, charset, channelCapacity) { process.isAlive && !exitCode.isCompleted }
val stderr = produceLinesManually(job, process.errorStream, lastOutputTimeMillis, idleTimeout, charset, channelCapacity) { process.isAlive && !exitCode.isCompleted }

return KotlinProcessCommand(
process, job, stdout, stderr, exitCode, destroyForcibly
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.malinskiy.marathon.apple.cmd.remote.ssh.sshj

import com.malinskiy.marathon.apple.cmd.CommandExecutor
import com.malinskiy.marathon.apple.cmd.CommandSession
import com.malinskiy.marathon.apple.extensions.Durations
import com.malinskiy.marathon.apple.extensions.produceLinesManually
import com.malinskiy.marathon.extension.withTimeoutOrNull
import com.malinskiy.marathon.log.MarathonLogging
Expand All @@ -20,6 +21,7 @@ import java.io.IOException
import java.nio.charset.Charset
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.coroutineContext

class SshjCommandExecutor(
Expand Down Expand Up @@ -53,9 +55,10 @@ class SshjCommandExecutor(
}
}
val cmd = session.exec(escapedCmd)

val stdout = produceLinesManually(job, cmd.inputStream, idleTimeout, charset, channelCapacity) { cmd.isOpen && !cmd.isEOF }
val stderr = produceLinesManually(job, cmd.errorStream, idleTimeout, charset, channelCapacity) { cmd.isOpen && !cmd.isEOF }

val lastOutputTimeMillis = AtomicLong(System.currentTimeMillis())
val stdout = produceLinesManually(job, cmd.inputStream, lastOutputTimeMillis, idleTimeout, charset, channelCapacity) { cmd.isOpen && !cmd.isEOF }
val stderr = produceLinesManually(job, cmd.errorStream, lastOutputTimeMillis, idleTimeout, charset, channelCapacity) { cmd.isOpen && !cmd.isEOF }
val exitCode: Deferred<Int?> = async(job) {
val result = withTimeoutOrNull(timeout) {
cmd.suspendFor()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class SshjCommandSession(

override fun close() {
if (!closed.getAndSet(true)) {
if (command.isOpen) {
terminate()
}
command.close()

command.join()
super.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import java.io.InputStream
import java.nio.charset.Charset
import java.time.Duration
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicLong
import kotlin.math.min

fun CoroutineScope.produceLines(
Expand All @@ -38,15 +39,14 @@ fun CoroutineScope.produceLines(
fun CoroutineScope.produceLinesManually(
job: Job,
inputStream: InputStream,
lastOutputTimeMillis: AtomicLong,
idleTimeout: Duration,
charset: Charset,
channelCapacity: Int,
canRead: () -> Boolean,
): ReceiveChannel<String> {
return produce(capacity = channelCapacity, context = job) {
inputStream.buffered().use { inputStream ->

var lastOutputTimeMillis = System.currentTimeMillis()
LineBuffer(charset, onLine = { send(it) }).use { lineBuffer ->
val byteArray = ByteArray(16384)
while (coroutineContext.isActive && !channel.isClosedForSend && !job.isCancelled) {
Expand All @@ -66,6 +66,15 @@ fun CoroutineScope.produceLinesManually(
available > 0 -> inputStream.read(byteArray, 0, min(available, byteArray.size))
else -> 0
}

//Check we didn't go over idle timeout
val lastOutput = lastOutputTimeMillis.get()
val timeSinceLastOutputMillis = System.currentTimeMillis() - lastOutput
if (timeSinceLastOutputMillis > idleTimeout.toMillis()) {
close(TimeoutException("idle timeout $idleTimeout reached"))
break
}

// if there was nothing to read
if (count == 0) {
// if session received EOF or has been closed, reading stops
Expand All @@ -77,13 +86,9 @@ fun CoroutineScope.produceLinesManually(
} else if (count == -1) {
break
} else {
val timeSinceLastOutputMillis = System.currentTimeMillis() - lastOutputTimeMillis
if (timeSinceLastOutputMillis > idleTimeout.toMillis()) {
close(TimeoutException("idle timeout $idleTimeout reached"))
break
}
lineBuffer.append(byteArray, count)
lastOutputTimeMillis = System.currentTimeMillis()
//Check we didn't go over idle timeout
lastOutputTimeMillis.set(System.currentTimeMillis())
}
// immediately send any full lines for parsing
lineBuffer.flush()
Expand Down

0 comments on commit 7a75392

Please sign in to comment.