// leak_playground_tokio/task.rs

use std::{
    future::Future,
    marker::PhantomData,
    pin::Pin,
    ptr::NonNull,
    task::{Context, Poll},
};

use leak_playground_std::marker::Unforget;
use leak_playground_std::mem::ManuallyDrop;
use tokio::task::{AbortHandle, JoinError, JoinHandle};

9pub fn spawn_scoped<'a, F>(future: F) -> ScopedJoinHandle<'a, F::Output>
11where
12 F: Future + Send + 'a,
13 F::Output: Send + 'a,
14{
15 ScopedJoinHandle {
16 inner: unsafe {
17 ManuallyDrop::new_unchecked(tokio::task::spawn(erased_send_future(future)))
18 },
19 _unforget: Unforget::new(PhantomData),
20 _unsend: PhantomData,
21 _output: PhantomData,
22 }
23}
24
25pub fn spawn_local_scoped<'a, F>(future: F) -> ScopedJoinHandle<'a, F::Output>
27where
28 F: Future + 'a,
29 F::Output: 'a,
30{
31 ScopedJoinHandle {
32 inner: unsafe {
33 ManuallyDrop::new_unchecked(tokio::task::spawn_local(erased_future(future)))
34 },
35 _unforget: Unforget::new(PhantomData),
36 _unsend: PhantomData,
37 _output: PhantomData,
38 }
39}
40
41pub fn spawn_blocking_scoped<'a, F, T>(f: F) -> ScopedJoinHandle<'a, T>
43where
44 F: FnOnce() -> T + Send + 'a,
45 T: Send + 'a,
46{
47 ScopedJoinHandle {
48 inner: unsafe {
49 ManuallyDrop::new_unchecked(tokio::task::spawn_blocking(erased_send_fn_once(f)))
50 },
51 _unforget: Unforget::new(PhantomData),
52 _unsend: PhantomData,
53 _output: PhantomData,
54 }
55}
56
57pub struct ScopedJoinHandle<'a, T> {
64 inner: ManuallyDrop<JoinHandle<Payload>>,
65 _unforget: Unforget<'static, PhantomData<&'a ()>>,
66 _output: PhantomData<T>,
68 _unsend: PhantomData<*mut ()>,
69}
70
71unsafe impl<T: Send> Send for ScopedJoinHandle<'static, T> {}
72unsafe impl<T: Send> Sync for ScopedJoinHandle<'_, T> {}
73impl<T> Unpin for ScopedJoinHandle<'_, T> {}
74
75impl<'a, T> Future for ScopedJoinHandle<'a, T> {
76 type Output = Result<T, JoinError>;
77
78 fn poll(
79 mut self: Pin<&mut Self>,
80 cx: &mut std::task::Context<'_>,
81 ) -> std::task::Poll<Self::Output> {
82 JoinHandle::poll(Pin::new(&mut self.inner), cx)
83 .map(|r| r.map(|r| unsafe { r.get_unchecked::<T>() }))
84 }
85}
86
87impl<'a, T> ScopedJoinHandle<'a, T> {
88 pub async fn cancel(mut self) -> Result<(), JoinError> {
89 self.inner.abort();
90 let task = unsafe { ManuallyDrop::take(&mut self.inner) };
91 match task.await {
92 Err(e) if e.is_cancelled() => Ok(()),
93 Ok(_) => Ok(()),
94 Err(e) => Err(e),
95 }
96 }
97
98 pub fn abort(&self) {
99 self.inner.abort();
100 }
101
102 pub fn abort_handle(&self) -> AbortHandle {
103 self.inner.abort_handle()
104 }
105}
106
107impl<'a, T> Drop for ScopedJoinHandle<'a, T> {
111 fn drop(&mut self) {
112 self.inner.abort();
113 let task = unsafe { ManuallyDrop::take(&mut self.inner) };
114 tokio::task::block_in_place(move || {
116 tokio::runtime::Handle::current().block_on(async move {
117 match task.await {
118 Err(e) if e.is_cancelled() => (),
119 Ok(_) => (),
120 Err(e) => std::panic::resume_unwind(e.into_panic()),
121 }
122 })
123 });
124 }
125}
126
127unsafe fn erased_send_fn_once<F, R>(f: F) -> impl FnOnce() -> Payload + Send + 'static
130where
131 F: FnOnce() -> R + Send,
132{
133 let f = move || Payload::new_unchecked(f());
134 let f: Box<dyn FnOnce() -> Payload + Send + '_> = Box::new(f);
135 let f: Box<dyn FnOnce() -> Payload + Send> = std::mem::transmute(f);
136 f
137}
138
139unsafe fn erased_send_future<F>(f: F) -> impl Future<Output = Payload> + Send + 'static
140where
141 F: Future + Send,
142{
143 let f = async move { Payload::new_unchecked(f.await) };
144 let f: Pin<Box<dyn Future<Output = Payload> + Send + '_>> = Box::pin(f);
145 let f: Pin<Box<dyn Future<Output = Payload> + Send>> = std::mem::transmute(f);
146 f
147}
148
149unsafe fn erased_future<F>(f: F) -> impl Future<Output = Payload> + 'static
150where
151 F: Future,
152{
153 let f = async move { Payload::new_unchecked(f.await) };
154 let f: Pin<Box<dyn Future<Output = Payload> + '_>> = Box::pin(f);
155 let f: Pin<Box<dyn Future<Output = Payload>>> = std::mem::transmute(f);
156 f
157}
158
/// A heap-allocated value whose type has been erased; only the code that
/// created it knows the concrete `T`.
///
/// `Payload` has no destructor: the boxed value is freed (and `T` dropped) only
/// when it is read back with [`Payload::get_unchecked`]. Dropping a `Payload`
/// without reading it leaks the value.
struct Payload {
    /// Pointer originally obtained from a `Box<T>`.
    ptr: NonNull<()>,
}

// SAFETY: the compiler cannot see the erased pointee, so thread-safety is part
// of the contract of `Payload::new_unchecked`: callers must only erase values
// that may move or be shared across the threads that handle the payload.
unsafe impl Send for Payload {}
unsafe impl Sync for Payload {}

impl Payload {
    /// Moves `v` to the heap and forgets its type.
    ///
    /// # Safety
    ///
    /// The payload must later be read back with `get_unchecked::<T>()` for this
    /// same `T`, and `T` must be safe to hand to whichever thread does so,
    /// because `Payload` is unconditionally `Send` and `Sync`.
    unsafe fn new_unchecked<T>(v: T) -> Payload {
        let typed: NonNull<T> = NonNull::from(Box::leak(Box::new(v)));
        Payload { ptr: typed.cast() }
    }

    /// Takes the boxed value back out, consuming the payload.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type given to `new_unchecked` for this payload.
    unsafe fn get_unchecked<T>(self) -> T {
        let raw: *mut T = self.ptr.as_ptr().cast();
        let boxed: Box<T> = unsafe { Box::from_raw(raw) };
        *boxed
    }
}
176}