// sled/subscriber.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        atomic::{AtomicBool, Ordering::Relaxed},
6        mpsc::{sync_channel, Receiver, SyncSender, TryRecvError},
7    },
8    task::{Context, Poll, Waker},
9    time::{Duration, Instant},
10};
11
12#[cfg(not(feature = "testing"))]
13use std::collections::HashMap as Map;
14
15// we avoid HashMap while testing because
16// it makes tests non-deterministic
17#[cfg(feature = "testing")]
18use std::collections::BTreeMap as Map;
19
20use crate::*;
21
/// Source of subscriber ids: every `Subscribers::register` call takes the
/// next value, which keys the subscriber's entry in its prefix's sender map.
static ID_GEN: AtomicUsize = AtomicUsize::new(0);
23
/// An event that happened to a key that a subscriber is interested in.
///
/// The derived `PartialOrd`/`Ord` compare the variant first (in declaration
/// order, so `Insert` sorts before `Remove`) and then the fields in
/// declaration order, so the layout below is observable through ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Event {
    /// A new complete (key, value) pair
    Insert {
        /// The key that has been set
        key: IVec,
        /// The value that has been set
        value: IVec,
    },
    /// A deleted key
    Remove {
        /// The key that has been removed
        key: IVec,
    },
}
40
41impl Event {
42    /// Return the key associated with the `Event`
43    pub fn key(&self) -> &IVec {
44        match self {
45            Event::Insert { key, .. } | Event::Remove { key } => key,
46        }
47    }
48}
49
/// Per-prefix registry of live subscribers, keyed by subscriber id.
///
/// Each entry holds the waker most recently stored by the subscriber's
/// `Future` impl (or `None` if it has never been polled into waiting) and the
/// sending half of the subscriber's bounded channel, over which `reserve`
/// sends one `OneShot` per matching event.
type Senders = Map<usize, (Option<Waker>, SyncSender<OneShot<Option<Event>>>)>;
51
/// A subscriber listening on a specified prefix
///
/// `Subscriber` implements both `Iterator<Item = Event>`
/// and `Future<Output=Option<Event>>`
///
/// Dropping a `Subscriber` removes it from its prefix's registry.
///
/// # Examples
///
/// Synchronous, blocking subscriber:
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use sled::{Config, Event};
/// let config = Config::new().temporary(true);
///
/// let tree = config.open()?;
///
/// // watch all events by subscribing to the empty prefix
/// let mut subscriber = tree.watch_prefix(vec![]);
///
/// let tree_2 = tree.clone();
/// let thread = std::thread::spawn(move || {
///     tree.insert(vec![0], vec![1])
/// });
///
/// // `Subscriber` implements `Iterator<Item=Event>`
/// for event in subscriber.take(1) {
///     match event {
///         Event::Insert{ key, value } => assert_eq!(key.as_ref(), &[0]),
///         Event::Remove {key } => {}
///     }
/// }
///
/// # thread.join().unwrap();
/// # Ok(())
/// # }
/// ```
/// Asynchronous, non-blocking subscriber:
///
/// `Subscriber` implements `Future<Output=Option<Event>>`.
///
/// `while let Some(event) = (&mut subscriber).await { /* use it */ }`
pub struct Subscriber {
    // Key of this subscriber's entry in `home`.
    id: usize,
    // Receives one `OneShot` per matching event reserved by a writer.
    rx: Receiver<OneShot<Option<Event>>>,
    // A oneshot already taken off `rx` whose value was not available yet;
    // `next_timeout` and the `Future` impl resume it before reading `rx`.
    existing: Option<OneShot<Option<Event>>>,
    // The shared sender map of the prefix this subscriber registered on.
    home: Arc<RwLock<Senders>>,
}
98
99impl Drop for Subscriber {
100    fn drop(&mut self) {
101        let mut w_senders = self.home.write();
102        w_senders.remove(&self.id);
103    }
104}
105
106impl Subscriber {
107    /// Attempts to wait for a value on this `Subscriber`, returning
108    /// an error if no event arrives within the provided `Duration`
109    /// or if the backing `Db` shuts down.
110    pub fn next_timeout(
111        &mut self,
112        mut timeout: Duration,
113    ) -> std::result::Result<Event, std::sync::mpsc::RecvTimeoutError> {
114        loop {
115            let start = Instant::now();
116            let mut future_rx = if let Some(future_rx) = self.existing.take() {
117                future_rx
118            } else {
119                self.rx.recv_timeout(timeout)?
120            };
121            timeout =
122                if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
123                    timeout
124                } else {
125                    Duration::from_nanos(0)
126                };
127
128            let start = Instant::now();
129            match future_rx.wait_timeout(timeout) {
130                Ok(Some(event)) => return Ok(event),
131                Ok(None) => (),
132                Err(timeout_error) => {
133                    self.existing = Some(future_rx);
134                    return Err(timeout_error);
135                }
136            }
137            timeout =
138                if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
139                    timeout
140                } else {
141                    Duration::from_nanos(0)
142                };
143        }
144    }
145}
146
147impl Future for Subscriber {
148    type Output = Option<Event>;
149
150    fn poll(
151        mut self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153    ) -> Poll<Self::Output> {
154        loop {
155            let mut future_rx = if let Some(future_rx) = self.existing.take() {
156                future_rx
157            } else {
158                match self.rx.try_recv() {
159                    Ok(future_rx) => future_rx,
160                    Err(TryRecvError::Empty) => break,
161                    Err(TryRecvError::Disconnected) => {
162                        return Poll::Ready(None)
163                    }
164                }
165            };
166
167            match Future::poll(Pin::new(&mut future_rx), cx) {
168                Poll::Ready(Some(event)) => return Poll::Ready(event),
169                Poll::Ready(None) => continue,
170                Poll::Pending => {
171                    self.existing = Some(future_rx);
172                    return Poll::Pending;
173                }
174            }
175        }
176        let mut home = self.home.write();
177        let entry = home.get_mut(&self.id).unwrap();
178        entry.0 = Some(cx.waker().clone());
179        Poll::Pending
180    }
181}
182
183impl Iterator for Subscriber {
184    type Item = Event;
185
186    fn next(&mut self) -> Option<Event> {
187        loop {
188            let future_rx = self.rx.recv().ok()?;
189            match future_rx.wait() {
190                Some(Some(event)) => return Some(event),
191                Some(None) => return None,
192                None => continue,
193            }
194        }
195    }
196}
197
/// Registry of prefix subscriptions: maps each watched prefix to the shared
/// sender map of the subscribers registered on it.
#[derive(Debug, Default)]
pub(crate) struct Subscribers {
    // Prefix -> sender map. `register` inserts entries; nothing in this
    // file removes them.
    watched: RwLock<BTreeMap<Vec<u8>, Arc<RwLock<Senders>>>>,
    // Set by the first `register`; until then `reserve` returns early
    // without touching any lock.
    ever_used: AtomicBool,
}
203
204impl Drop for Subscribers {
205    fn drop(&mut self) {
206        let watched = self.watched.read();
207
208        for senders in watched.values() {
209            let senders =
210                std::mem::replace(&mut *senders.write(), Map::default());
211            for (_, (waker, sender)) in senders {
212                drop(sender);
213                if let Some(waker) = waker {
214                    waker.wake();
215                }
216            }
217        }
218    }
219}
220
221impl Subscribers {
222    pub(crate) fn register(&self, prefix: &[u8]) -> Subscriber {
223        self.ever_used.store(true, Relaxed);
224        let r_mu = {
225            let r_mu = self.watched.read();
226            if r_mu.contains_key(prefix) {
227                r_mu
228            } else {
229                drop(r_mu);
230                let mut w_mu = self.watched.write();
231                if !w_mu.contains_key(prefix) {
232                    let old = w_mu.insert(
233                        prefix.to_vec(),
234                        Arc::new(RwLock::new(Map::default())),
235                    );
236                    assert!(old.is_none());
237                }
238                drop(w_mu);
239                self.watched.read()
240            }
241        };
242
243        let (tx, rx) = sync_channel(1024);
244
245        let arc_senders = &r_mu[prefix];
246        let mut w_senders = arc_senders.write();
247
248        let id = ID_GEN.fetch_add(1, Relaxed);
249
250        w_senders.insert(id, (None, tx));
251
252        Subscriber { id, rx, existing: None, home: arc_senders.clone() }
253    }
254
255    pub(crate) fn reserve<R: AsRef<[u8]>>(
256        &self,
257        key: R,
258    ) -> Option<ReservedBroadcast> {
259        if !self.ever_used.load(Relaxed) {
260            return None;
261        }
262
263        let r_mu = self.watched.read();
264        let prefixes = r_mu.iter().filter(|(k, _)| key.as_ref().starts_with(k));
265
266        let mut subscribers = vec![];
267
268        for (_, subs_rwl) in prefixes {
269            let subs = subs_rwl.read();
270
271            for (_id, (waker, sender)) in subs.iter() {
272                let (tx, rx) = OneShot::pair();
273                if sender.send(rx).is_err() {
274                    continue;
275                }
276                subscribers.push((waker.clone(), tx));
277            }
278        }
279
280        if subscribers.is_empty() {
281            None
282        } else {
283            Some(ReservedBroadcast { subscribers })
284        }
285    }
286}
287
/// Notifications reserved by `Subscribers::reserve` for a single key.
///
/// Each element pairs the waker that the matching subscriber had stored at
/// reservation time with the filler for the `OneShot` already sent to it.
/// Dropping this value without calling `complete` drops every filler
/// unfilled.
pub(crate) struct ReservedBroadcast {
    subscribers: Vec<(Option<Waker>, OneShotFiller<Option<Event>>)>,
}
291
292impl ReservedBroadcast {
293    pub fn complete(self, event: &Event) {
294        let iter = self.subscribers.into_iter();
295
296        for (waker, tx) in iter {
297            tx.fill(Some(event.clone()));
298            if let Some(waker) = waker {
299                waker.wake();
300            }
301        }
302    }
303}
304
// Exercises prefix matching, skipping of abandoned reservations, and
// per-subscriber delivery order in `Subscribers`.
#[test]
fn basic_subscriber() {
    let subs = Subscribers::default();

    let mut s2 = subs.register(&[0]);
    let mut s3 = subs.register(&[0, 1]);
    let mut s4 = subs.register(&[1, 2]);

    // None of the registered prefixes is a prefix of b"awft", so nothing is
    // reserved.
    let r1 = subs.reserve(b"awft");
    assert!(r1.is_none());

    // The empty prefix matches every key reserved from here on.
    let mut s1 = subs.register(&[]);

    let k2: IVec = vec![].into();
    let r2 = subs.reserve(&k2).unwrap();
    r2.complete(&Event::Insert { key: k2.clone(), value: k2.clone() });

    let k3: IVec = vec![0].into();
    let r3 = subs.reserve(&k3).unwrap();
    r3.complete(&Event::Insert { key: k3.clone(), value: k3.clone() });

    let k4: IVec = vec![0, 1].into();
    let r4 = subs.reserve(&k4).unwrap();
    r4.complete(&Event::Remove { key: k4.clone() });

    let k5: IVec = vec![0, 1, 2].into();
    let r5 = subs.reserve(&k5).unwrap();
    r5.complete(&Event::Insert { key: k5.clone(), value: k5.clone() });

    let k6: IVec = vec![1, 1, 2].into();
    let r6 = subs.reserve(&k6).unwrap();
    r6.complete(&Event::Remove { key: k6.clone() });

    // Reserved but never completed: subscribers must skip it, so k7 does
    // not appear in any of the expectations below.
    let k7: IVec = vec![1, 1, 2].into();
    let r7 = subs.reserve(&k7).unwrap();
    drop(r7);

    let k8: IVec = vec![1, 2, 2].into();
    let r8 = subs.reserve(&k8).unwrap();
    r8.complete(&Event::Insert { key: k8.clone(), value: k8.clone() });

    // s1 (empty prefix) sees every completed event, in order.
    assert_eq!(s1.next().unwrap().key(), &*k2);
    assert_eq!(s1.next().unwrap().key(), &*k3);
    assert_eq!(s1.next().unwrap().key(), &*k4);
    assert_eq!(s1.next().unwrap().key(), &*k5);
    assert_eq!(s1.next().unwrap().key(), &*k6);
    assert_eq!(s1.next().unwrap().key(), &*k8);

    // s2 ([0]) sees only keys starting with 0.
    assert_eq!(s2.next().unwrap().key(), &*k3);
    assert_eq!(s2.next().unwrap().key(), &*k4);
    assert_eq!(s2.next().unwrap().key(), &*k5);

    // s3 ([0, 1]) sees only keys starting with 0, 1.
    assert_eq!(s3.next().unwrap().key(), &*k4);
    assert_eq!(s3.next().unwrap().key(), &*k5);

    // s4 ([1, 2]) sees only k8; k6 = [1, 1, 2] does not match.
    assert_eq!(s4.next().unwrap().key(), &*k8);
}