1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll, Waker},
5 time::{Duration, Instant},
6};
7
8use parking_lot::{Condvar, Mutex};
9
10use crate::Arc;
11
/// Shared state between a `OneShot` receiver and its `OneShotFiller`,
/// protected by the pair's mutex.
#[derive(Debug)]
struct OneShotState<T> {
    /// Set by `fill`, or by dropping the filler unfilled; waiters stop once true.
    filled: bool,
    /// Set once the `OneShot` future has returned `Poll::Ready`; later polls
    /// return `Poll::Pending`.
    fused: bool,
    /// The delivered value, if any; taken by whichever receiver path reads it first.
    item: Option<T>,
    /// Waker recorded by the most recent `poll` that returned `Pending`.
    waker: Option<Waker>,
}
19
20impl<T> Default for OneShotState<T> {
21 fn default() -> OneShotState<T> {
22 OneShotState { filled: false, fused: false, item: None, waker: None }
23 }
24}
25
26#[derive(Debug)]
28pub struct OneShot<T> {
29 mu: Arc<Mutex<OneShotState<T>>>,
30 cv: Arc<Condvar>,
31}
32
33pub struct OneShotFiller<T> {
35 mu: Arc<Mutex<OneShotState<T>>>,
36 cv: Arc<Condvar>,
37}
38
39impl<T> OneShot<T> {
40 pub fn pair() -> (OneShotFiller<T>, Self) {
43 let mu = Arc::new(Mutex::new(OneShotState::default()));
44 let cv = Arc::new(Condvar::new());
45 let future = Self { mu: mu.clone(), cv: cv.clone() };
46 let filler = OneShotFiller { mu, cv };
47
48 (filler, future)
49 }
50
51 pub fn wait(self) -> Option<T> {
54 let mut inner = self.mu.lock();
55 while !inner.filled {
56 self.cv.wait(&mut inner);
57 }
58 inner.item.take()
59 }
60
61 pub fn wait_timeout(
71 &mut self,
72 mut timeout: Duration,
73 ) -> Result<T, std::sync::mpsc::RecvTimeoutError> {
74 let mut inner = self.mu.lock();
75 while !inner.filled {
76 let start = Instant::now();
77 let res = self.cv.wait_for(&mut inner, timeout);
78 if res.timed_out() {
79 return Err(std::sync::mpsc::RecvTimeoutError::Disconnected);
80 }
81 timeout =
82 if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
83 timeout
84 } else {
85 Duration::from_nanos(0)
86 };
87 }
88 if let Some(item) = inner.item.take() {
89 Ok(item)
90 } else {
91 Err(std::sync::mpsc::RecvTimeoutError::Disconnected)
92 }
93 }
94}
95
96impl<T> Future for OneShot<T> {
97 type Output = Option<T>;
98
99 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100 let mut state = self.mu.lock();
101 if state.fused {
102 return Poll::Pending;
103 }
104 if state.filled {
105 state.fused = true;
106 Poll::Ready(state.item.take())
107 } else {
108 state.waker = Some(cx.waker().clone());
109 Poll::Pending
110 }
111 }
112}
113
114impl<T> OneShotFiller<T> {
115 pub fn fill(self, inner: T) {
117 let mut state = self.mu.lock();
118
119 if let Some(waker) = state.waker.take() {
120 waker.wake();
121 }
122
123 state.filled = true;
124 state.item = Some(inner);
125
126 drop(state);
129
130 let _notified = self.cv.notify_all();
131 }
132}
133
134impl<T> Drop for OneShotFiller<T> {
135 fn drop(&mut self) {
136 let mut state = self.mu.lock();
137
138 if state.filled {
139 return;
140 }
141
142 if let Some(waker) = state.waker.take() {
143 waker.wake();
144 }
145
146 state.filled = true;
147
148 drop(state);
151
152 let _notified = self.cv.notify_all();
153 }
154}