sled/
oneshot.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll, Waker},
5    time::{Duration, Instant},
6};
7
8use parking_lot::{Condvar, Mutex};
9
10use crate::Arc;
11
/// State shared by a `OneShot` and its `OneShotFiller`; both sides only
/// ever reach it through the pair's `Mutex`.
#[derive(Debug)]
struct OneShotState<T> {
    /// Set when the filler either fills a value or is dropped unfilled.
    filled: bool,
    /// Set once the `Future` impl has returned `Poll::Ready`.
    fused: bool,
    /// The filled value, until a consumer takes it.
    item: Option<T>,
    /// Waker registered by the most recent pending poll of the future.
    waker: Option<Waker>,
}

impl<T> Default for OneShotState<T> {
    /// An empty state: unfilled, unfused, no value, no registered waker.
    ///
    /// Written by hand rather than derived, because `#[derive(Default)]`
    /// would needlessly require `T: Default`.
    fn default() -> Self {
        Self {
            filled: false,
            fused: false,
            item: None,
            waker: None,
        }
    }
}
25
26/// A Future value which may or may not be filled
27#[derive(Debug)]
28pub struct OneShot<T> {
29    mu: Arc<Mutex<OneShotState<T>>>,
30    cv: Arc<Condvar>,
31}
32
33/// The completer side of the Future
34pub struct OneShotFiller<T> {
35    mu: Arc<Mutex<OneShotState<T>>>,
36    cv: Arc<Condvar>,
37}
38
39impl<T> OneShot<T> {
40    /// Create a new `OneShotFiller` and the `OneShot`
41    /// that will be filled by its completion.
42    pub fn pair() -> (OneShotFiller<T>, Self) {
43        let mu = Arc::new(Mutex::new(OneShotState::default()));
44        let cv = Arc::new(Condvar::new());
45        let future = Self { mu: mu.clone(), cv: cv.clone() };
46        let filler = OneShotFiller { mu, cv };
47
48        (filler, future)
49    }
50
51    /// Block on the `OneShot`'s completion
52    /// or dropping of the `OneShotFiller`
53    pub fn wait(self) -> Option<T> {
54        let mut inner = self.mu.lock();
55        while !inner.filled {
56            self.cv.wait(&mut inner);
57        }
58        inner.item.take()
59    }
60
61    /// Block on the `OneShot`'s completion
62    /// or dropping of the `OneShotFiller`,
63    /// returning an error if not filled
64    /// before a given timeout or if the
65    /// system shuts down before then.
66    ///
67    /// Upon a successful receive, the
68    /// oneshot should be dropped, as it
69    /// will never yield that value again.
70    pub fn wait_timeout(
71        &mut self,
72        mut timeout: Duration,
73    ) -> Result<T, std::sync::mpsc::RecvTimeoutError> {
74        let mut inner = self.mu.lock();
75        while !inner.filled {
76            let start = Instant::now();
77            let res = self.cv.wait_for(&mut inner, timeout);
78            if res.timed_out() {
79                return Err(std::sync::mpsc::RecvTimeoutError::Disconnected);
80            }
81            timeout =
82                if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
83                    timeout
84                } else {
85                    Duration::from_nanos(0)
86                };
87        }
88        if let Some(item) = inner.item.take() {
89            Ok(item)
90        } else {
91            Err(std::sync::mpsc::RecvTimeoutError::Disconnected)
92        }
93    }
94}
95
96impl<T> Future for OneShot<T> {
97    type Output = Option<T>;
98
99    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100        let mut state = self.mu.lock();
101        if state.fused {
102            return Poll::Pending;
103        }
104        if state.filled {
105            state.fused = true;
106            Poll::Ready(state.item.take())
107        } else {
108            state.waker = Some(cx.waker().clone());
109            Poll::Pending
110        }
111    }
112}
113
114impl<T> OneShotFiller<T> {
115    /// Complete the `OneShot`
116    pub fn fill(self, inner: T) {
117        let mut state = self.mu.lock();
118
119        if let Some(waker) = state.waker.take() {
120            waker.wake();
121        }
122
123        state.filled = true;
124        state.item = Some(inner);
125
126        // having held the mutex makes this linearized
127        // with the notify below.
128        drop(state);
129
130        let _notified = self.cv.notify_all();
131    }
132}
133
134impl<T> Drop for OneShotFiller<T> {
135    fn drop(&mut self) {
136        let mut state = self.mu.lock();
137
138        if state.filled {
139            return;
140        }
141
142        if let Some(waker) = state.waker.take() {
143            waker.wake();
144        }
145
146        state.filled = true;
147
148        // having held the mutex makes this linearized
149        // with the notify below.
150        drop(state);
151
152        let _notified = self.cv.notify_all();
153    }
154}