Skip to content

Commit 5bdfc8a

Browse files
authored
Merge pull request jooby-project#2328 from codeborne/flexible-coroutine-context
Flexible coroutine context
2 parents 9e1581f + 86804e5 commit 5bdfc8a

File tree

3 files changed

+89
-49
lines changed

3 files changed

+89
-49
lines changed

jooby/src/main/kotlin/io/jooby/CoroutineRouter.kt

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,75 +5,67 @@
55
*/
66
package io.jooby
77

8-
import kotlinx.coroutines.CoroutineExceptionHandler
9-
import kotlinx.coroutines.CoroutineScope
10-
import kotlinx.coroutines.CoroutineStart
11-
import kotlinx.coroutines.asCoroutineDispatcher
12-
import kotlinx.coroutines.launch
8+
import io.jooby.Router.*
9+
import kotlinx.coroutines.*
1310
import kotlin.coroutines.CoroutineContext
1411

15-
internal class RouterCoroutineScope(coroutineContext: CoroutineContext) : CoroutineScope {
16-
override val coroutineContext = coroutineContext
17-
}
12+
internal class RouterCoroutineScope(override val coroutineContext: CoroutineContext) : CoroutineScope
1813

1914
class CoroutineRouter(val coroutineStart: CoroutineStart, val router: Router) {
2015

2116
val coroutineScope: CoroutineScope by lazy {
2217
RouterCoroutineScope(router.worker.asCoroutineDispatcher())
2318
}
2419

25-
@RouterDsl
26-
fun get(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
27-
return route(Router.GET, pattern, handler)
20+
private var extendCoroutineContext: (CoroutineContext) -> CoroutineContext = { it }
21+
fun launchContext(block: (CoroutineContext) -> CoroutineContext) {
22+
extendCoroutineContext = block
2823
}
2924

3025
@RouterDsl
31-
fun post(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
32-
return route(Router.POST, pattern, handler)
33-
}
26+
fun get(pattern: String, handler: suspend HandlerContext.() -> Any) =
27+
route(GET, pattern, handler)
3428

3529
@RouterDsl
36-
fun put(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
37-
return route(Router.PUT, pattern, handler)
38-
}
30+
fun post(pattern: String, handler: suspend HandlerContext.() -> Any) =
31+
route(POST, pattern, handler)
3932

4033
@RouterDsl
41-
fun delete(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
42-
return route(Router.DELETE, pattern, handler)
43-
}
34+
fun put(pattern: String, handler: suspend HandlerContext.() -> Any) =
35+
route(PUT, pattern, handler)
4436

4537
@RouterDsl
46-
fun patch(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
47-
return route(Router.PATCH, pattern, handler)
48-
}
38+
fun delete(pattern: String, handler: suspend HandlerContext.() -> Any) =
39+
route(DELETE, pattern, handler)
4940

5041
@RouterDsl
51-
fun head(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
52-
return route(Router.HEAD, pattern, handler)
53-
}
42+
fun patch(pattern: String, handler: suspend HandlerContext.() -> Any) =
43+
route(PATCH, pattern, handler)
5444

5545
@RouterDsl
56-
fun trace(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
57-
return route(Router.TRACE, pattern, handler)
58-
}
46+
fun head(pattern: String, handler: suspend HandlerContext.() -> Any) =
47+
route(HEAD, pattern, handler)
5948

6049
@RouterDsl
61-
fun options(pattern: String, handler: suspend HandlerContext.() -> Any): Route {
62-
return route(Router.OPTIONS, pattern, handler)
63-
}
50+
fun trace(pattern: String, handler: suspend HandlerContext.() -> Any) =
51+
route(TRACE, pattern, handler)
6452

65-
fun route(method: String, pattern: String, handler: suspend HandlerContext.() -> Any): Route {
66-
return router.route(method, pattern) { ctx ->
67-
val xhandler = CoroutineExceptionHandler { _, x ->
68-
ctx.sendError(x)
69-
}
70-
coroutineScope.launch(xhandler, coroutineStart) {
53+
@RouterDsl
54+
fun options(pattern: String, handler: suspend HandlerContext.() -> Any) =
55+
route(OPTIONS, pattern, handler)
56+
57+
fun route(method: String, pattern: String, handler: suspend HandlerContext.() -> Any): Route =
58+
router.route(method, pattern) { ctx ->
59+
launch(ctx) {
7160
val result = handler(HandlerContext(ctx))
7261
if (result != ctx) {
7362
ctx.render(result)
7463
}
7564
}
76-
}.setHandle(handler)
77-
.attribute("coroutine", true)
65+
}.setHandle(handler).attribute("coroutine", true)
66+
67+
internal fun launch(ctx: Context, block: suspend CoroutineScope.() -> Unit) {
68+
val exceptionHandler = CoroutineExceptionHandler { _, x -> ctx.sendError(x) }
69+
coroutineScope.launch(extendCoroutineContext(exceptionHandler), coroutineStart, block)
7870
}
7971
}

jooby/src/main/kotlin/io/jooby/internal/mvc/CoroutineLauncher.kt

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,22 @@ package io.jooby.internal.mvc
88
import io.jooby.Context
99
import io.jooby.CoroutineRouter
1010
import io.jooby.Route
11-
import kotlinx.coroutines.CoroutineExceptionHandler
12-
import kotlinx.coroutines.launch
1311
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
1412

13+
/**
14+
* Used by compiled MVC-style routes with suspend functions
15+
*/
1516
class CoroutineLauncher(val next: Route.Handler) : Route.Handler {
16-
override fun apply(ctx: Context): Any {
17+
override fun apply(ctx: Context) = ctx.also {
1718
val router = ctx.router.attribute<CoroutineRouter>("coroutineRouter")
18-
val exceptionHandler = CoroutineExceptionHandler { _, x ->
19-
ctx.sendError(x)
20-
}
21-
router.coroutineScope.launch(exceptionHandler, router.coroutineStart) {
19+
router.launch(ctx) {
2220
val result = suspendCoroutineUninterceptedOrReturn<Any> {
2321
ctx.attribute("___continuation", it)
2422
next.apply(ctx)
2523
}
2624
if (!ctx.isResponseStarted) {
27-
ctx.render(result!!)
25+
ctx.render(result)
2826
}
2927
}
30-
return ctx
3128
}
3229
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package io.jooby
2+
3+
import io.jooby.Router.GET
4+
import kotlinx.coroutines.CoroutineExceptionHandler
5+
import kotlinx.coroutines.CoroutineStart
6+
import org.junit.jupiter.api.Test
7+
import org.mockito.ArgumentCaptor
8+
import org.mockito.Mockito.*
9+
import kotlin.coroutines.AbstractCoroutineContextElement
10+
import kotlin.coroutines.CoroutineContext
11+
12+
class CoroutineRouterTest {
13+
private val router = mock(Router::class.java, RETURNS_DEEP_STUBS)
14+
private val ctx = mock(Context::class.java)
15+
16+
@Test
17+
fun withoutLaunchContext() {
18+
CoroutineRouter(CoroutineStart.UNDISPATCHED, router).apply {
19+
get("/path") { "Result" }
20+
}
21+
22+
val handlerCaptor = ArgumentCaptor.forClass(Route.Handler::class.java)
23+
verify(router).route(eq(GET), eq("/path"), handlerCaptor.capture())
24+
handlerCaptor.value.apply(ctx)
25+
26+
verify(ctx).render("Result")
27+
}
28+
29+
@Test
30+
fun launchContext_isRunEveryTime() {
31+
val mockCoroutineContext = mock(CoroutineContext::class.java)
32+
`when`(mockCoroutineContext.plus(any() ?: mockCoroutineContext)).thenReturn(mockCoroutineContext, ExtraContext())
33+
34+
CoroutineRouter(CoroutineStart.DEFAULT, router).apply {
35+
launchContext { mockCoroutineContext + it + ExtraContext() }
36+
get("/path") { "Result" }
37+
}
38+
39+
val handlerCaptor = ArgumentCaptor.forClass(Route.Handler::class.java)
40+
verify(router).route(eq(GET), eq("/path"), handlerCaptor.capture())
41+
verifyNoInteractions(mockCoroutineContext)
42+
43+
handlerCaptor.value.apply(ctx)
44+
verify(mockCoroutineContext).plus(argThat { it is CoroutineExceptionHandler } ?: mockCoroutineContext)
45+
verify(mockCoroutineContext).plus(argThat { it is ExtraContext } ?: mockCoroutineContext)
46+
}
47+
48+
class ExtraContext : AbstractCoroutineContextElement(Key) {
49+
companion object Key : CoroutineContext.Key<ExtraContext>
50+
}
51+
}

0 commit comments

Comments
 (0)