Skip to main content

leak_playground_tokio/
task.rs

//! Possible [`tokio::task`](https://docs.rs/tokio/1.35.1/tokio/task/index.html) additions: scoped (non-`'static`) variants of `spawn`, `spawn_local` and `spawn_blocking`.
2
3use std::{future::Future, marker::PhantomData, pin::Pin, ptr::NonNull};
4
5use leak_playground_std::marker::Unforget;
6use leak_playground_std::mem::ManuallyDrop;
7use tokio::task::{AbortHandle, JoinError, JoinHandle};
8
9/// Spawns a non-static `Send` future, returning for non-static cases a `!Send` task handle.
10pub fn spawn_scoped<'a, F>(future: F) -> ScopedJoinHandle<'a, F::Output>
11where
12    F: Future + Send + 'a,
13    F::Output: Send + 'a,
14{
15    ScopedJoinHandle {
16        inner: unsafe {
17            ManuallyDrop::new_unchecked(tokio::task::spawn(erased_send_future(future)))
18        },
19        _unforget: Unforget::new(PhantomData),
20        _unsend: PhantomData,
21        _output: PhantomData,
22    }
23}
24
25/// Spawns a non-static `!Send` future.
26pub fn spawn_local_scoped<'a, F>(future: F) -> ScopedJoinHandle<'a, F::Output>
27where
28    F: Future + 'a,
29    F::Output: 'a,
30{
31    ScopedJoinHandle {
32        inner: unsafe {
33            ManuallyDrop::new_unchecked(tokio::task::spawn_local(erased_future(future)))
34        },
35        _unforget: Unforget::new(PhantomData),
36        _unsend: PhantomData,
37        _output: PhantomData,
38    }
39}
40
41/// Runs the provided non-static closure on a thread where blocking is acceptable.
42pub fn spawn_blocking_scoped<'a, F, T>(f: F) -> ScopedJoinHandle<'a, T>
43where
44    F: FnOnce() -> T + Send + 'a,
45    T: Send + 'a,
46{
47    ScopedJoinHandle {
48        inner: unsafe {
49            ManuallyDrop::new_unchecked(tokio::task::spawn_blocking(erased_send_fn_once(f)))
50        },
51        _unforget: Unforget::new(PhantomData),
52        _unsend: PhantomData,
53        _output: PhantomData,
54    }
55}
56
57/// Handle to a task, which cancels on drop.
58///
59/// This is made to ensure we won't put task into itself, thus forgetting it.
60///
61/// To spawn use [`spawn_scoped`], [`spawn_local_scoped`], or
62/// [`spawn_blocking_scoped`].
63pub struct ScopedJoinHandle<'a, T> {
64    inner: ManuallyDrop<JoinHandle<Payload>>,
65    _unforget: Unforget<'static, PhantomData<&'a ()>>,
66    // No need for Unforget since we put bound `T: 'a` on constructors
67    _output: PhantomData<T>,
68    _unsend: PhantomData<*mut ()>,
69}
70
71unsafe impl<T: Send> Send for ScopedJoinHandle<'static, T> {}
72unsafe impl<T: Send> Sync for ScopedJoinHandle<'_, T> {}
73impl<T> Unpin for ScopedJoinHandle<'_, T> {}
74
75impl<'a, T> Future for ScopedJoinHandle<'a, T> {
76    type Output = Result<T, JoinError>;
77
78    fn poll(
79        mut self: Pin<&mut Self>,
80        cx: &mut std::task::Context<'_>,
81    ) -> std::task::Poll<Self::Output> {
82        JoinHandle::poll(Pin::new(&mut self.inner), cx)
83            .map(|r| r.map(|r| unsafe { r.get_unchecked::<T>() }))
84    }
85}
86
87impl<'a, T> ScopedJoinHandle<'a, T> {
88    pub async fn cancel(mut self) -> Result<(), JoinError> {
89        self.inner.abort();
90        let task = unsafe { ManuallyDrop::take(&mut self.inner) };
91        match task.await {
92            Err(e) if e.is_cancelled() => Ok(()),
93            Ok(_) => Ok(()),
94            Err(e) => Err(e),
95        }
96    }
97
98    pub fn abort(&self) {
99        self.inner.abort();
100    }
101
102    pub fn abort_handle(&self) -> AbortHandle {
103        self.inner.abort_handle()
104    }
105}
106
107// TODO: `impl<T> From<ScopedJoinHandle<'static, T>> for JoinHandle<T>`
108//  is possible but requires internals to avoid hacky `Payload` return type
109
110impl<'a, T> Drop for ScopedJoinHandle<'a, T> {
111    fn drop(&mut self) {
112        self.inner.abort();
113        let task = unsafe { ManuallyDrop::take(&mut self.inner) };
114        // TODO: this is a hack-around without async drop
115        tokio::task::block_in_place(move || {
116            tokio::runtime::Handle::current().block_on(async move {
117                match task.await {
118                    Err(e) if e.is_cancelled() => (),
119                    Ok(_) => (),
120                    Err(e) => std::panic::resume_unwind(e.into_panic()),
121                }
122            })
123        });
124    }
125}
126
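// # Hack-around utilities
//
// Type- and lifetime-erasure helpers that let non-`'static` work be handed to Tokio's
// `'static` spawn functions.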
128
129unsafe fn erased_send_fn_once<F, R>(f: F) -> impl FnOnce() -> Payload + Send + 'static
130where
131    F: FnOnce() -> R + Send,
132{
133    let f = move || Payload::new_unchecked(f());
134    let f: Box<dyn FnOnce() -> Payload + Send + '_> = Box::new(f);
135    let f: Box<dyn FnOnce() -> Payload + Send> = std::mem::transmute(f);
136    f
137}
138
139unsafe fn erased_send_future<F>(f: F) -> impl Future<Output = Payload> + Send + 'static
140where
141    F: Future + Send,
142{
143    let f = async move { Payload::new_unchecked(f.await) };
144    let f: Pin<Box<dyn Future<Output = Payload> + Send + '_>> = Box::pin(f);
145    let f: Pin<Box<dyn Future<Output = Payload> + Send>> = std::mem::transmute(f);
146    f
147}
148
149unsafe fn erased_future<F>(f: F) -> impl Future<Output = Payload> + 'static
150where
151    F: Future,
152{
153    let f = async move { Payload::new_unchecked(f.await) };
154    let f: Pin<Box<dyn Future<Output = Payload> + '_>> = Box::pin(f);
155    let f: Pin<Box<dyn Future<Output = Payload>>> = std::mem::transmute(f);
156    f
157}
158
/// Type-erased, heap-allocated task output.
///
/// Created by [`Payload::new_unchecked`] and read back with [`Payload::get_unchecked`] for the
/// same `T`. A payload that is dropped without being read runs `T`'s destructor and frees the
/// box instead of leaking it.
struct Payload {
    ptr: NonNull<()>,
    // Monomorphised for the erased `T`; frees the `Box<T>` behind `ptr`.
    drop_fn: unsafe fn(NonNull<()>),
}

// SAFETY: the value is only handed out by value through `get_unchecked`. The spawn functions
// in this module require the erased output to be `Send` where it may cross threads
// (`spawn_scoped`, `spawn_blocking_scoped`). NOTE(review): the `spawn_local_scoped` case
// relies on the owning handle staying on the spawning thread — confirm.
unsafe impl Send for Payload {}
// SAFETY: `&Payload` exposes no access to the erased value.
unsafe impl Sync for Payload {}

impl Payload {
    /// Boxes `v` and erases its type.
    ///
    /// # Safety
    ///
    /// The payload may only be read with `get_unchecked::<T>` for this same `T`. Dropping it
    /// unread runs `T`'s destructor, so any borrows inside `v` must still be valid then.
    unsafe fn new_unchecked<T>(v: T) -> Payload {
        /// Reclaims and drops the `Box<T>` created by `new_unchecked`.
        unsafe fn drop_box<T>(ptr: NonNull<()>) {
            // SAFETY: `ptr` came from a `Box<T>` in `new_unchecked` and is freed only once.
            std::mem::drop(unsafe { Box::from_raw(ptr.cast::<T>().as_ptr()) });
        }

        Payload {
            ptr: NonNull::from(Box::leak(Box::new(v))).cast(),
            drop_fn: drop_box::<T>,
        }
    }

    /// Moves the value back out of the payload.
    ///
    /// # Safety
    ///
    /// `T` must be the type the payload was created with.
    unsafe fn get_unchecked<T>(self) -> T {
        let ptr = self.ptr.cast::<T>();
        // Ownership of the box moves to the caller, so `Payload`'s destructor must not run.
        std::mem::forget(self);
        // SAFETY: `ptr` points to a live `Box<T>` per the caller's contract.
        *unsafe { Box::from_raw(ptr.as_ptr()) }
    }
}

impl Drop for Payload {
    fn drop(&mut self) {
        // SAFETY: `get_unchecked` forgets the payload, so reaching this point means the box is
        // still owned here and is freed exactly once.
        unsafe { (self.drop_fn)(self.ptr) }
    }
}