diff --git a/src/commonMain/kotlin/ktor/WebsocketLoxoneClient.kt b/src/commonMain/kotlin/ktor/WebsocketLoxoneClient.kt index 8753512..0e10576 100644 --- a/src/commonMain/kotlin/ktor/WebsocketLoxoneClient.kt +++ b/src/commonMain/kotlin/ktor/WebsocketLoxoneClient.kt @@ -1,6 +1,5 @@ package cz.smarteon.loxone.ktor -import co.touchlab.stately.concurrency.AtomicReference import cz.smarteon.loxone.Codec import cz.smarteon.loxone.Codec.loxJson import cz.smarteon.loxone.Command @@ -22,6 +21,8 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withTimeout import kotlin.jvm.JvmOverloads @@ -38,7 +39,8 @@ class WebsocketLoxoneClient internal constructor( private val logger = KotlinLogging.logger {} - private val webSocketSession = AtomicReference(null) + private val sessionMutex = Mutex() + private var session: ClientWebSocketSession? = null private val scope = CoroutineScope(Dispatchers.Default) // TODO think more about correct dispacthers @@ -70,25 +72,27 @@ class WebsocketLoxoneClient internal constructor( override suspend fun close() { scope.cancel("Closing the connection") - webSocketSession.get()?.close(CloseReason(NORMAL, "LoxoneKotlin finished")) + session?.close(CloseReason(NORMAL, "LoxoneKotlin finished")) } private suspend fun ensureSession(): ClientWebSocketSession { - webSocketSession.compareAndSet( - null, - client.webSocketSession( - host = endpoint?.host, - port = endpoint?.port, - path = if (endpoint != null) endpoint.path + WS_PATH else WS_PATH, - block = { - url.protocol = if (endpoint?.useSsl == true) URLProtocol.WSS else URLProtocol.WS + sessionMutex.withLock { + if (session == null) { + client.webSocketSession( + host = endpoint?.host, + port = endpoint?.port, + path = if (endpoint != null) endpoint.path + WS_PATH else WS_PATH, + block = { + url.protocol = if (endpoint?.useSsl == true) URLProtocol.WSS else URLProtocol.WS + } + ).let { newSession -> + logger.debug { "WebSocketSession session created" } + session = newSession + newSession.incoming.receiveAsFlow().onEach(::processFrame).launchIn(scope) } - ) - ) - return checkNotNull(webSocketSession.get()) { "WebSocketSession should not be null right after init" } - .also { session -> - session.incoming.receiveAsFlow().onEach(::processFrame).launchIn(scope) } + } + return checkNotNull(session) { "WebSocketSession should not be null after init" } } private suspend fun processFrame(frame: Frame) {