1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 atomic::{AtomicBool, Ordering::Relaxed},
6 mpsc::{sync_channel, Receiver, SyncSender, TryRecvError},
7 },
8 task::{Context, Poll, Waker},
9 time::{Duration, Instant},
10};
11
12#[cfg(not(feature = "testing"))]
13use std::collections::HashMap as Map;
14
15#[cfg(feature = "testing")]
18use std::collections::BTreeMap as Map;
19
20use crate::*;
21
/// Source of unique subscriber ids: every `Subscribers::register` call takes
/// the next value, which keys the new subscriber's entry in its prefix's
/// `Senders` map and is used again by `Subscriber`'s `Drop` to remove it.
static ID_GEN: AtomicUsize = AtomicUsize::new(0);
23
/// An event delivered to every `Subscriber` whose watched prefix is a prefix
/// of the event's key.
///
/// Note: variant order and field order determine the derived `Ord`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Event {
    /// Carries both the affected key and its value.
    Insert {
        /// The affected key.
        key: IVec,
        /// The value associated with `key` by this event.
        value: IVec,
    },
    /// Carries only the affected key.
    Remove {
        /// The affected key.
        key: IVec,
    },
}
40
41impl Event {
42 pub fn key(&self) -> &IVec {
44 match self {
45 Event::Insert { key, .. } | Event::Remove { key } => key,
46 }
47 }
48}
49
/// Per-prefix registry of live subscribers, keyed by subscriber id.
///
/// Each entry holds:
/// - the waker stored by the subscriber's last `poll` that found its channel
///   empty (if any);
/// - the bounded sender through which `Subscribers::reserve` hands the
///   subscriber one `OneShot` per reserved event.
type Senders = Map<usize, (Option<Waker>, SyncSender<OneShot<Option<Event>>>)>;
51
/// A stream of `Event`s for keys under one watched prefix, created by
/// `Subscribers::register`.
///
/// Events can be consumed three ways: by blocking iteration (`Iterator`),
/// with a deadline (`next_timeout`), or asynchronously (`Future`).
/// Dropping the subscriber removes its entry from its prefix's sender map.
pub struct Subscriber {
    /// This subscriber's key in `home`.
    id: usize,
    /// Receives one `OneShot` per reserved broadcast; the event itself
    /// arrives when that reservation is completed.
    rx: Receiver<OneShot<Option<Event>>>,
    /// A reservation already taken from `rx` whose event had not arrived
    /// when `next_timeout` or `poll` last gave up. Both of them resume it
    /// before reading `rx` again.
    existing: Option<OneShot<Option<Event>>>,
    /// The sender map of the prefix this subscriber is registered under.
    home: Arc<RwLock<Senders>>,
}
98
99impl Drop for Subscriber {
100 fn drop(&mut self) {
101 let mut w_senders = self.home.write();
102 w_senders.remove(&self.id);
103 }
104}
105
106impl Subscriber {
107 pub fn next_timeout(
111 &mut self,
112 mut timeout: Duration,
113 ) -> std::result::Result<Event, std::sync::mpsc::RecvTimeoutError> {
114 loop {
115 let start = Instant::now();
116 let mut future_rx = if let Some(future_rx) = self.existing.take() {
117 future_rx
118 } else {
119 self.rx.recv_timeout(timeout)?
120 };
121 timeout =
122 if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
123 timeout
124 } else {
125 Duration::from_nanos(0)
126 };
127
128 let start = Instant::now();
129 match future_rx.wait_timeout(timeout) {
130 Ok(Some(event)) => return Ok(event),
131 Ok(None) => (),
132 Err(timeout_error) => {
133 self.existing = Some(future_rx);
134 return Err(timeout_error);
135 }
136 }
137 timeout =
138 if let Some(timeout) = timeout.checked_sub(start.elapsed()) {
139 timeout
140 } else {
141 Duration::from_nanos(0)
142 };
143 }
144 }
145}
146
147impl Future for Subscriber {
148 type Output = Option<Event>;
149
150 fn poll(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 ) -> Poll<Self::Output> {
154 loop {
155 let mut future_rx = if let Some(future_rx) = self.existing.take() {
156 future_rx
157 } else {
158 match self.rx.try_recv() {
159 Ok(future_rx) => future_rx,
160 Err(TryRecvError::Empty) => break,
161 Err(TryRecvError::Disconnected) => {
162 return Poll::Ready(None)
163 }
164 }
165 };
166
167 match Future::poll(Pin::new(&mut future_rx), cx) {
168 Poll::Ready(Some(event)) => return Poll::Ready(event),
169 Poll::Ready(None) => continue,
170 Poll::Pending => {
171 self.existing = Some(future_rx);
172 return Poll::Pending;
173 }
174 }
175 }
176 let mut home = self.home.write();
177 let entry = home.get_mut(&self.id).unwrap();
178 entry.0 = Some(cx.waker().clone());
179 Poll::Pending
180 }
181}
182
183impl Iterator for Subscriber {
184 type Item = Event;
185
186 fn next(&mut self) -> Option<Event> {
187 loop {
188 let future_rx = self.rx.recv().ok()?;
189 match future_rx.wait() {
190 Some(Some(event)) => return Some(event),
191 Some(None) => return None,
192 None => continue,
193 }
194 }
195 }
196}
197
/// Registry of prefix subscriptions: maps each watched prefix to the senders
/// of the subscribers registered under it.
#[derive(Debug, Default)]
pub(crate) struct Subscribers {
    /// Watched prefixes and their sender maps.
    /// Entries are added by `register`; nothing in this file removes them.
    watched: RwLock<BTreeMap<Vec<u8>, Arc<RwLock<Senders>>>>,
    /// Set by the first `register` call; while it is still false, `reserve`
    /// returns `None` without taking any lock.
    ever_used: AtomicBool,
}
203
204impl Drop for Subscribers {
205 fn drop(&mut self) {
206 let watched = self.watched.read();
207
208 for senders in watched.values() {
209 let senders =
210 std::mem::replace(&mut *senders.write(), Map::default());
211 for (_, (waker, sender)) in senders {
212 drop(sender);
213 if let Some(waker) = waker {
214 waker.wake();
215 }
216 }
217 }
218 }
219}
220
221impl Subscribers {
222 pub(crate) fn register(&self, prefix: &[u8]) -> Subscriber {
223 self.ever_used.store(true, Relaxed);
224 let r_mu = {
225 let r_mu = self.watched.read();
226 if r_mu.contains_key(prefix) {
227 r_mu
228 } else {
229 drop(r_mu);
230 let mut w_mu = self.watched.write();
231 if !w_mu.contains_key(prefix) {
232 let old = w_mu.insert(
233 prefix.to_vec(),
234 Arc::new(RwLock::new(Map::default())),
235 );
236 assert!(old.is_none());
237 }
238 drop(w_mu);
239 self.watched.read()
240 }
241 };
242
243 let (tx, rx) = sync_channel(1024);
244
245 let arc_senders = &r_mu[prefix];
246 let mut w_senders = arc_senders.write();
247
248 let id = ID_GEN.fetch_add(1, Relaxed);
249
250 w_senders.insert(id, (None, tx));
251
252 Subscriber { id, rx, existing: None, home: arc_senders.clone() }
253 }
254
255 pub(crate) fn reserve<R: AsRef<[u8]>>(
256 &self,
257 key: R,
258 ) -> Option<ReservedBroadcast> {
259 if !self.ever_used.load(Relaxed) {
260 return None;
261 }
262
263 let r_mu = self.watched.read();
264 let prefixes = r_mu.iter().filter(|(k, _)| key.as_ref().starts_with(k));
265
266 let mut subscribers = vec![];
267
268 for (_, subs_rwl) in prefixes {
269 let subs = subs_rwl.read();
270
271 for (_id, (waker, sender)) in subs.iter() {
272 let (tx, rx) = OneShot::pair();
273 if sender.send(rx).is_err() {
274 continue;
275 }
276 subscribers.push((waker.clone(), tx));
277 }
278 }
279
280 if subscribers.is_empty() {
281 None
282 } else {
283 Some(ReservedBroadcast { subscribers })
284 }
285 }
286}
287
/// The reservations for one event, produced by `Subscribers::reserve`.
///
/// `complete` fills every reservation with the event. Dropping this value
/// instead drops the fillers without filling them.
pub(crate) struct ReservedBroadcast {
    /// One entry per reserved subscriber: the waker captured when the
    /// reservation was made (if any), and the filler for the `OneShot`
    /// that was sent to that subscriber.
    subscribers: Vec<(Option<Waker>, OneShotFiller<Option<Event>>)>,
}
291
292impl ReservedBroadcast {
293 pub fn complete(self, event: &Event) {
294 let iter = self.subscribers.into_iter();
295
296 for (waker, tx) in iter {
297 tx.fill(Some(event.clone()));
298 if let Some(waker) = waker {
299 waker.wake();
300 }
301 }
302 }
303}
304
#[test]
fn basic_subscriber() {
    // Pulls one event per expected key from `subscriber` and checks that the
    // keys arrive in exactly this order.
    fn expect_keys(subscriber: &mut Subscriber, expected: &[&[u8]]) {
        for &key in expected {
            assert_eq!(subscriber.next().unwrap().key(), key);
        }
    }

    let subs = Subscribers::default();

    let mut s2 = subs.register(&[0]);
    let mut s3 = subs.register(&[0, 1]);
    let mut s4 = subs.register(&[1, 2]);

    // No registered prefix matches this key, so nothing is reserved.
    assert!(subs.reserve(b"awft").is_none());

    // The empty prefix matches every key.
    let mut s1 = subs.register(&[]);

    let k2: IVec = vec![].into();
    subs.reserve(&k2)
        .unwrap()
        .complete(&Event::Insert { key: k2.clone(), value: k2.clone() });

    let k3: IVec = vec![0].into();
    subs.reserve(&k3)
        .unwrap()
        .complete(&Event::Insert { key: k3.clone(), value: k3.clone() });

    let k4: IVec = vec![0, 1].into();
    subs.reserve(&k4).unwrap().complete(&Event::Remove { key: k4.clone() });

    let k5: IVec = vec![0, 1, 2].into();
    subs.reserve(&k5)
        .unwrap()
        .complete(&Event::Insert { key: k5.clone(), value: k5.clone() });

    let k6: IVec = vec![1, 1, 2].into();
    subs.reserve(&k6).unwrap().complete(&Event::Remove { key: k6.clone() });

    // A reservation dropped without `complete` must be skipped by readers.
    let k7: IVec = vec![1, 1, 2].into();
    drop(subs.reserve(&k7).unwrap());

    let k8: IVec = vec![1, 2, 2].into();
    subs.reserve(&k8)
        .unwrap()
        .complete(&Event::Insert { key: k8.clone(), value: k8.clone() });

    expect_keys(&mut s1, &[&*k2, &*k3, &*k4, &*k5, &*k6, &*k8]);
    expect_keys(&mut s2, &[&*k3, &*k4, &*k5]);
    expect_keys(&mut s3, &[&*k4, &*k5]);
    expect_keys(&mut s4, &[&*k8]);
}