From 0f60e9969dc115dffa19748f5023f8127b48186a Mon Sep 17 00:00:00 2001 From: Rustam Date: Mon, 27 Jun 2022 23:29:38 +0200 Subject: [PATCH] KTOR-4164 Fix ClassCastException when development mode is on (#3082) --- .../ApplicationEngineEnvironmentReloading.kt | 24 +++++++- .../jvm/test/TestApplicationTestJvm.kt | 60 ++++++++++++++++++- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/ktor-server/ktor-server-host-common/jvm/src/io/ktor/server/engine/ApplicationEngineEnvironmentReloading.kt b/ktor-server/ktor-server-host-common/jvm/src/io/ktor/server/engine/ApplicationEngineEnvironmentReloading.kt index ffde60f2e21..b3219cc8206 100644 --- a/ktor-server/ktor-server-host-common/jvm/src/io/ktor/server/engine/ApplicationEngineEnvironmentReloading.kt +++ b/ktor-server/ktor-server-host-common/jvm/src/io/ktor/server/engine/ApplicationEngineEnvironmentReloading.kt @@ -36,11 +36,16 @@ public class ApplicationEngineEnvironmentReloading( override val connectors: List, internal val modules: List Unit>, internal val watchPaths: List = emptyList(), - override val parentCoroutineContext: CoroutineContext = EmptyCoroutineContext, + parentCoroutineContext: CoroutineContext = EmptyCoroutineContext, override val rootPath: String = "", override val developmentMode: Boolean = true ) : ApplicationEngineEnvironment { + override val parentCoroutineContext: CoroutineContext = when { + developmentMode -> parentCoroutineContext + ClassLoaderAwareContinuationInterceptor + else -> parentCoroutineContext + } + public constructor( classLoader: ClassLoader, log: Logger, @@ -365,3 +370,20 @@ public class ApplicationEngineEnvironmentReloading( public companion object } + +private object ClassLoaderAwareContinuationInterceptor : ContinuationInterceptor { + override val key: CoroutineContext.Key<*> = + object : CoroutineContext.Key {} + + override fun interceptContinuation(continuation: Continuation): Continuation { + val classLoader = Thread.currentThread().contextClassLoader + return object : Continuation { + override val context: CoroutineContext = continuation.context + + override fun resumeWith(result: Result) { + Thread.currentThread().contextClassLoader = classLoader + continuation.resumeWith(result) + } + } + } +} diff --git a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt index e5b4f040a66..cffa401a44e 100644 --- a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt +++ b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt @@ -14,6 +14,9 @@ import io.ktor.server.routing.* import io.ktor.server.testing.* import io.ktor.server.websocket.* import io.ktor.websocket.* +import kotlinx.coroutines.* +import java.io.* +import kotlin.coroutines.* import kotlin.test.* import io.ktor.client.plugins.websocket.WebSockets as ClientWebSockets @@ -100,9 +103,9 @@ class TestApplicationTestJvm { @Test fun testExternalServicesCustomConfig() = testApplication { - environment { - config = ApplicationConfig("application-custom.conf") - } + environment { + config = ApplicationConfig("application-custom.conf") + } externalServices { hosts("http://www.google.com") { val config = environment.config @@ -119,9 +122,60 @@ class TestApplicationTestJvm { assertEquals("another_test_value", external.bodyAsText()) } + @Test + fun testModuleWithLaunch() = testApplication { + var error: Throwable? = null + val exceptionHandler: CoroutineContext = object : CoroutineExceptionHandler { + override val key: CoroutineContext.Key<*> = CoroutineExceptionHandler.Key + override fun handleException(context: CoroutineContext, exception: Throwable) { + error = exception + } + } + environment { + parentCoroutineContext = exceptionHandler + } + application { + launch { + val byteArrayInputStream = ByteArrayOutputStream() + val objectOutputStream = ObjectOutputStream(byteArrayInputStream) + objectOutputStream.writeObject(TestClass(123)) + objectOutputStream.flush() + objectOutputStream.close() + + val ois = TestObjectInputStream(ByteArrayInputStream(byteArrayInputStream.toByteArray())) + val test = ois.readObject() + test as TestClass + } + } + routing { + get("/") { + call.respond("OK") + } + } + + client.get("/") + Thread.sleep(3000) + assertNull(error) + } + public fun Application.module() { routing { get { call.respond("OK FROM MODULE") } } } } + +class TestClass(val value: Int) : Serializable + +class TestObjectInputStream(input: InputStream) : ObjectInputStream(input) { + override fun resolveClass(desc: ObjectStreamClass?): Class<*> { + val name = desc?.name + val loader = Thread.currentThread().contextClassLoader + + return try { + Class.forName(name, false, loader) + } catch (e: ClassNotFoundException) { + super.resolveClass(desc) + } + } +}