diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index 3a0f582f50f8..48efd73d748e 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -44,7 +44,6 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.SynchronousSink; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -109,7 +108,7 @@ public static Publisher invokeSuspendingFunction(Method method, Object target * @throws IllegalArgumentException if {@code method} is not a suspending function * @since 6.0 */ - @SuppressWarnings({"deprecation", "DataFlowIssue", "NullAway"}) + @SuppressWarnings({"DataFlowIssue", "NullAway"}) public static Publisher invokeSuspendingFunction( CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) { @@ -146,7 +145,7 @@ public static Publisher invokeSuspendingFunction( } return KCallables.callSuspendBy(function, argMap, continuation); }) - .handle(CoroutinesUtils::handleResult) + .filter(result -> result != Unit.INSTANCE) .onErrorMap(InvocationTargetException.class, InvocationTargetException::getTargetException); KType returnType = function.getReturnType(); @@ -166,22 +165,4 @@ private static Flux asFlux(Object flow) { return ReactorFlowKt.asFlux(((Flow) flow)); } - private static void handleResult(Object result, SynchronousSink sink) { - if (result == Unit.INSTANCE) { - sink.complete(); - } - else if (KotlinDetector.isInlineClass(result.getClass())) { - try { - sink.next(result.getClass().getDeclaredMethod("unbox-impl").invoke(result)); - sink.complete(); - } - catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) { - sink.error(ex); - } - } - else { - sink.next(result); - sink.complete(); - } - } } diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index 143a2cad08c8..31ebb74927d7 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -192,7 +192,7 @@ class CoroutinesUtilsTests { @Test fun invokeSuspendingFunctionWithValueClassParameter() { - val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClass") } + val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassParameter") } val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, "foo", null) as Mono runBlocking { Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo") @@ -204,7 +204,16 @@ class CoroutinesUtilsTests { val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassReturnValue") } val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono runBlocking { - Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo") + Assertions.assertThat(mono.awaitSingle()).isEqualTo(ValueClass("foo")) + } + } + + @Test + fun invokeSuspendingFunctionWithResultOfUnitReturnValue() { + val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithResultOfUnitReturnValue") } + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingle()).isEqualTo(Result.success(Unit)) } } @@ -314,7 +323,7 @@ class CoroutinesUtilsTests { return null } - suspend fun suspendingFunctionWithValueClass(value: ValueClass): String { + suspend fun suspendingFunctionWithValueClassParameter(value: ValueClass): String { delay(1) return value.value } @@ -324,6 +333,11 @@ class CoroutinesUtilsTests { return ValueClass("foo") } + suspend fun suspendingFunctionWithResultOfUnitReturnValue(): Result { + delay(1) + return Result.success(Unit) + } + suspend fun suspendingFunctionWithValueClassWithInit(value: ValueClassWithInit): String { delay(1) return value.value diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index b1398f9a9b54..6f14f487aee8 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -30,6 +30,8 @@ import kotlin.reflect.full.KClasses; import kotlin.reflect.jvm.KCallablesJvm; import kotlin.reflect.jvm.ReflectJvmMapping; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SynchronousSink; import org.springframework.context.MessageSource; import org.springframework.core.CoroutinesUtils; @@ -288,7 +290,8 @@ else if (targetException instanceof Exception exception) { * @since 6.0 */ protected Object invokeSuspendingFunction(Method method, Object target, Object[] args) { - return CoroutinesUtils.invokeSuspendingFunction(method, target, args); + Object result = CoroutinesUtils.invokeSuspendingFunction(method, target, args); + return (result instanceof Mono mono ? mono.handle(KotlinDelegate::handleResult) : result); } @@ -298,7 +301,7 @@ protected Object invokeSuspendingFunction(Method method, Object target, Object[] private static class KotlinDelegate { @Nullable - @SuppressWarnings({"deprecation", "DataFlowIssue"}) + @SuppressWarnings("DataFlowIssue") public static Object invokeFunction(Method method, Object target, Object[] args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { KFunction function = ReflectJvmMapping.getKotlinFunction(method); // For property accessors @@ -333,10 +336,33 @@ public static Object invokeFunction(Method method, Object target, Object[] args) } Object result = function.callBy(argMap); if (result != null && KotlinDetector.isInlineClass(result.getClass())) { - return result.getClass().getDeclaredMethod("unbox-impl").invoke(result); + result = unbox(result); } return (result == Unit.INSTANCE ? null : result); } + + private static void handleResult(Object result, SynchronousSink sink) { + if (KotlinDetector.isInlineClass(result.getClass())) { + try { + Object unboxed = unbox(result); + if (unboxed != Unit.INSTANCE) { + sink.next(unboxed); + } + sink.complete(); + } + catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) { + sink.error(ex); + } + } + else { + sink.next(result); + sink.complete(); + } + } + + private static Object unbox(Object result) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { + return result.getClass().getDeclaredMethod("unbox-impl").invoke(result); + } } } diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index 60e52f4b1b5e..c1f63e8b30cc 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -16,14 +16,18 @@ package org.springframework.web.method.support +import kotlinx.coroutines.delay import org.assertj.core.api.Assertions import org.junit.jupiter.api.Test +import org.springframework.core.MethodParameter import org.springframework.util.ReflectionUtils +import org.springframework.web.bind.support.WebDataBinderFactory import org.springframework.web.context.request.NativeWebRequest import org.springframework.web.context.request.ServletWebRequest -import org.springframework.web.testfixture.method.ResolvableMethod import org.springframework.web.testfixture.servlet.MockHttpServletRequest import org.springframework.web.testfixture.servlet.MockHttpServletResponse +import reactor.core.publisher.Mono +import reactor.test.StepVerifier import java.lang.reflect.Method import kotlin.reflect.jvm.javaGetter import kotlin.reflect.jvm.javaMethod @@ -33,6 +37,7 @@ import kotlin.reflect.jvm.javaMethod * * @author Sebastien Deleuze */ +@Suppress("UNCHECKED_CAST") class InvocableHandlerMethodKotlinTests { private val request: NativeWebRequest = ServletWebRequest(MockHttpServletRequest(), MockHttpServletResponse()) @@ -110,6 +115,12 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo("foo") } + @Test + fun resultOfUnitReturnValue() { + val value = getInvocable(ValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null) + Assertions.assertThat(value).isNull() + } + @Test fun valueClassDefaultValue() { composite.addResolver(StubArgumentResolver(Double::class.java)) @@ -138,6 +149,60 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo('a') } + @Test + fun suspendingValueClass() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(Long::class.java, 1L)) + val value = getInvocable(SuspendingValueClassHandler::longValueClass.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).expectNext(1L).verifyComplete() + } + + @Test + fun suspendingValueClassReturnValue() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + val value = getInvocable(SuspendingValueClassHandler::valueClassReturnValue.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).expectNext("foo").verifyComplete() + } + + @Test + fun suspendingResultOfUnitReturnValue() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + val value = getInvocable(SuspendingValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).verifyComplete() + } + + @Test + fun suspendingValueClassDefaultValue() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(Double::class.java)) + val value = getInvocable(SuspendingValueClassHandler::doubleValueClass.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).expectNext(3.1).verifyComplete() + } + + @Test + fun suspendingValueClassWithInit() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(String::class.java, "")) + val value = getInvocable(SuspendingValueClassHandler::valueClassWithInit.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).verifyError(IllegalArgumentException::class.java) + } + + @Test + fun suspendingValueClassWithNullable() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(LongValueClass::class.java, null)) + val value = getInvocable(SuspendingValueClassHandler::valueClassWithNullable.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).verifyComplete() + } + + @Test + fun suspendingValueClassWithPrivateConstructor() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(Char::class.java, 'a')) + val value = getInvocable(SuspendingValueClassHandler::valueClassWithPrivateConstructor.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).expectNext('a').verifyComplete() + } + @Test fun propertyAccessor() { val value = getInvocable(PropertyAccessorHandler::prop.javaGetter!!).invokeForRequest(request, null) @@ -206,23 +271,58 @@ class InvocableHandlerMethodKotlinTests { private class ValueClassHandler { - fun valueClassReturnValue() = - StringValueClass("foo") + fun valueClassReturnValue() = StringValueClass("foo") + + fun resultOfUnitReturnValue() = Result.success(Unit) + + fun longValueClass(limit: LongValueClass) = limit.value + + fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) = limit.value + + fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass + + fun valueClassWithNullable(limit: LongValueClass?) = limit?.value + + fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = limit.value + } + + private class SuspendingValueClassHandler { + + suspend fun valueClassReturnValue(): StringValueClass { + delay(1) + return StringValueClass("foo") + } + + suspend fun resultOfUnitReturnValue(): Result { + delay(1) + return Result.success(Unit) + } - fun longValueClass(limit: LongValueClass) = - limit.value + suspend fun longValueClass(limit: LongValueClass): Long { + delay(1) + return limit.value + } - fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) = - limit.value - fun valueClassWithInit(valueClass: ValueClassWithInit) = - valueClass + suspend fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)): Double { + delay(1) + return limit.value + } - fun valueClassWithNullable(limit: LongValueClass?) = - limit?.value + suspend fun valueClassWithInit(valueClass: ValueClassWithInit): ValueClassWithInit { + delay(1) + return valueClass + } + + suspend fun valueClassWithNullable(limit: LongValueClass?): Long? { + delay(1) + return limit?.value + } - fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = - limit.value + suspend fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor): Char { + delay(1) + return limit.value + } } private class PropertyAccessorHandler { @@ -282,4 +382,19 @@ class InvocableHandlerMethodKotlinTests { class CustomException(message: String) : Throwable(message) + // Avoid adding a spring-webmvc dependency + class ContinuationHandlerMethodArgumentResolver : HandlerMethodArgumentResolver { + + override fun supportsParameter(parameter: MethodParameter) = + "kotlin.coroutines.Continuation" == parameter.getParameterType().getName() + + override fun resolveArgument( + parameter: MethodParameter, + mavContainer: ModelAndViewContainer?, + webRequest: NativeWebRequest, + binderFactory: WebDataBinderFactory? + ) = null + + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index cc9bd393d00c..836a77530887 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -36,6 +36,7 @@ import kotlin.reflect.jvm.KCallablesJvm; import kotlin.reflect.jvm.ReflectJvmMapping; import reactor.core.publisher.Mono; +import reactor.core.publisher.SynchronousSink; import reactor.core.scheduler.Scheduler; import org.springframework.core.CoroutinesUtils; @@ -323,18 +324,15 @@ private static class KotlinDelegate { private static final String COROUTINE_CONTEXT_ATTRIBUTE = "org.springframework.web.server.CoWebFilter.context"; @Nullable - @SuppressWarnings({"deprecation", "DataFlowIssue"}) + @SuppressWarnings("DataFlowIssue") public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction, ServerWebExchange exchange) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { if (isSuspendingFunction) { Object coroutineContext = exchange.getAttribute(COROUTINE_CONTEXT_ATTRIBUTE); - if (coroutineContext == null) { - return CoroutinesUtils.invokeSuspendingFunction(method, target, args); - } - else { - return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args); - } + Object result = (coroutineContext == null ? CoroutinesUtils.invokeSuspendingFunction(method, target, args) : + CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args)); + return (result instanceof Mono mono ? mono.handle(KotlinDelegate::handleResult) : result); } else { KFunction function = ReflectJvmMapping.getKotlinFunction(method); @@ -370,11 +368,35 @@ public static Object invokeFunction(Method method, Object target, Object[] args, } Object result = function.callBy(argMap); if (result != null && KotlinDetector.isInlineClass(result.getClass())) { - return result.getClass().getDeclaredMethod("unbox-impl").invoke(result); + result = unbox(result); } return (result == Unit.INSTANCE ? null : result); } } + + private static void handleResult(Object result, SynchronousSink sink) { + if (KotlinDetector.isInlineClass(result.getClass())) { + try { + Object unboxed = unbox(result); + if (unboxed != Unit.INSTANCE) { + sink.next(unboxed); + } + sink.complete(); + } + catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) { + sink.error(ex); + } + } + else { + sink.next(result); + sink.complete(); + } + } + + private static Object unbox(Object result) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { + return result.getClass().getDeclaredMethod("unbox-impl").invoke(result); + } + } } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index 6f254d473b7a..595928ec134e 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -208,10 +208,17 @@ class InvocableHandlerMethodKotlinTests { @Test fun valueClassReturnValue() { val method = ValueClassController::valueClassReturnValue.javaMethod!! - val result = invoke(ValueClassController(), method,) + val result = invoke(ValueClassController(), method) assertHandlerResultValue(result, "foo") } + @Test + fun resultOfUnitReturnValue() { + val method = ValueClassController::resultOfUnitReturnValue.javaMethod!! + val result = invoke(ValueClassController(), method) + assertHandlerResultValue(result, null) + } + @Test fun valueClassWithDefaultValue() { this.resolvers.add(stubResolver(null, Double::class.java)) @@ -244,6 +251,60 @@ class InvocableHandlerMethodKotlinTests { assertHandlerResultValue(result, "1") } + @Test + fun suspendingValueClass() { + this.resolvers.add(stubResolver(1L, Long::class.java)) + val method = SuspendingValueClassController::valueClass.javaMethod!! + val result = invoke(SuspendingValueClassController(), method,1L) + assertHandlerResultValue(result, "1") + } + + @Test + fun suspendingValueClassReturnValue() { + val method = SuspendingValueClassController::valueClassReturnValue.javaMethod!! + val result = invoke(SuspendingValueClassController(), method) + assertHandlerResultValue(result, "foo") + } + + @Test + fun suspendingResultOfUnitReturnValue() { + val method = SuspendingValueClassController::resultOfUnitReturnValue.javaMethod!! + val result = invoke(SuspendingValueClassController(), method) + assertComplete(result) + } + + @Test + fun suspendingValueClassWithDefaultValue() { + this.resolvers.add(stubResolver(null, Double::class.java)) + val method = SuspendingValueClassController::valueClassWithDefault.javaMethod!! + val result = invoke(SuspendingValueClassController(), method) + assertHandlerResultValue(result, "3.1") + } + + @Test + fun suspendingValueClassWithInit() { + this.resolvers.add(stubResolver("", String::class.java)) + val method = SuspendingValueClassController::valueClassWithInit.javaMethod!! + val result = invoke(SuspendingValueClassController(), method) + assertExceptionThrown(result, IllegalArgumentException::class) + } + + @Test + fun suspendingValueClassWithNullable() { + this.resolvers.add(stubResolver(null, LongValueClass::class.java)) + val method = SuspendingValueClassController::valueClassWithNullable.javaMethod!! + val result = invoke(SuspendingValueClassController(), method, null) + assertHandlerResultValue(result, "null") + } + + @Test + fun suspendingValueClassWithPrivateConstructor() { + this.resolvers.add(stubResolver(1L, Long::class.java)) + val method = SuspendingValueClassController::valueClassWithPrivateConstructor.javaMethod!! + val result = invoke(SuspendingValueClassController(), method, 1L) + assertHandlerResultValue(result, "1") + } + @Test fun propertyAccessor() { this.resolvers.add(stubResolver(null, String::class.java)) @@ -313,9 +374,14 @@ class InvocableHandlerMethodKotlinTests { } private fun assertExceptionThrown(mono: Mono, exceptionClass: KClass) { - StepVerifier.create(mono).verifyError(exceptionClass.java) + StepVerifier.create(mono.flatMap { t -> t.returnValue as Mono<*> }).verifyError(exceptionClass.java) + } + + private fun assertComplete(mono: Mono) { + StepVerifier.create(mono.flatMap { t -> t.returnValue as Mono<*> }).verifyComplete() } + class CoroutinesController { suspend fun singleArg(q: String?): String { @@ -380,23 +446,57 @@ class InvocableHandlerMethodKotlinTests { class ValueClassController { - fun valueClass(limit: LongValueClass) = - "${limit.value}" + fun valueClass(limit: LongValueClass) = "${limit.value}" + + fun valueClassReturnValue() = StringValueClass("foo") + + fun resultOfUnitReturnValue() = Result.success(Unit) + + fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)) = "${limit.value}" + + fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass + + fun valueClassWithNullable(limit: LongValueClass?) = "${limit?.value}" + + fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = "${limit.value}" + } + + class SuspendingValueClassController { - fun valueClassReturnValue() = - StringValueClass("foo") + suspend fun valueClass(limit: LongValueClass): String { + delay(1) + return "${limit.value}" + } + + suspend fun valueClassReturnValue(): StringValueClass { + delay(1) + return StringValueClass("foo") + } + + suspend fun resultOfUnitReturnValue(): Result { + delay(1) + return Result.success(Unit) + } - fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)) = - "${limit.value}" + suspend fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)): String { + delay(1) + return "${limit.value}" + } - fun valueClassWithInit(valueClass: ValueClassWithInit) = - valueClass + suspend fun valueClassWithInit(valueClass: ValueClassWithInit): ValueClassWithInit { + delay(1) + return valueClass + } - fun valueClassWithNullable(limit: LongValueClass?) = - "${limit?.value}" + suspend fun valueClassWithNullable(limit: LongValueClass?): String { + delay(1) + return "${limit?.value}" + } - fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = - "${limit.value}" + suspend fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor): String { + delay(1) + return "${limit.value}" + } } class PropertyAccessorController {