sled/
concurrency_control.rs

1#[cfg(feature = "testing")]
2use std::cell::RefCell;
3use std::sync::atomic::AtomicBool;
4
5use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
6
7use super::*;
8
#[cfg(feature = "testing")]
thread_local! {
    /// Testing-only, per-thread number of live `Protector`s. The protector
    /// constructors and its `Drop` impl assert it only ever moves between 0
    /// and 1, i.e. protectors are never nested on one thread.
    pub static COUNT: RefCell<u32> = RefCell::new(0_u32);
}
13
/// Flag bit (bit 31) stored in `ConcurrencyControl::active`; once set, readers
/// stop using the lock-free path and go through the `RwLock` instead. The
/// lower bits count outstanding lock-free readers.
const RW_REQUIRED_BIT: usize = 1_usize << 31;
15
/// Coordinates cheap lock-free readers with (rare) exclusive writers.
///
/// Until the first writer shows up, readers only bump a counter. The first
/// writer flips the controller, permanently, into a mode where every reader
/// and writer goes through `rw`.
#[derive(Default)]
pub(crate) struct ConcurrencyControl {
    /// Low bits: outstanding lock-free readers. `RW_REQUIRED_BIT`: set once a
    /// writer has requested exclusion.
    active: AtomicUsize,
    /// Becomes `true` once the lock-free readers present when the bit was set
    /// have all drained.
    upgrade_complete: AtomicBool,
    /// Lock used by writers, and by readers after `RW_REQUIRED_BIT` is set.
    rw: RwLock<()>,
}
22
23static CONCURRENCY_CONTROL: Lazy<
24    ConcurrencyControl,
25    fn() -> ConcurrencyControl,
26> = Lazy::new(init_cc);
27
28fn init_cc() -> ConcurrencyControl {
29    ConcurrencyControl::default()
30}
31
/// Proof of being inside a protected section; the protection ends when this
/// value is dropped, so it must not be discarded.
#[must_use]
#[derive(Debug)]
pub(crate) enum Protector<'a> {
    /// Exclusive access on the controller's `RwLock` (writers).
    Write(RwLockWriteGuard<'a, ()>),
    /// Shared access on the controller's `RwLock` (readers after the switch).
    Read(RwLockReadGuard<'a, ()>),
    /// Lock-free read: borrows the reader counter so the slot can be
    /// returned on drop.
    None(&'a AtomicUsize),
}
39
40impl<'a> Drop for Protector<'a> {
41    fn drop(&mut self) {
42        if let Protector::None(active) = self {
43            active.fetch_sub(1, Release);
44        }
45        #[cfg(feature = "testing")]
46        COUNT.with(|c| {
47            let mut c = c.borrow_mut();
48            *c -= 1;
49            assert_eq!(*c, 0);
50        });
51    }
52}
53
54pub(crate) fn read<'a>() -> Protector<'a> {
55    CONCURRENCY_CONTROL.read()
56}
57
58pub(crate) fn write<'a>() -> Protector<'a> {
59    CONCURRENCY_CONTROL.write()
60}
61
62impl ConcurrencyControl {
63    fn enable(&self) {
64        if self.active.fetch_or(RW_REQUIRED_BIT, SeqCst) < RW_REQUIRED_BIT {
65            // we are the first to set this bit
66            while self.active.load(Acquire) != RW_REQUIRED_BIT {
67                std::sync::atomic::spin_loop_hint()
68            }
69            self.upgrade_complete.store(true, Release);
70        }
71    }
72
73    fn read(&self) -> Protector<'_> {
74        #[cfg(feature = "testing")]
75        COUNT.with(|c| {
76            let mut c = c.borrow_mut();
77            *c += 1;
78            assert_eq!(*c, 1);
79        });
80
81        let active = self.active.fetch_add(1, Release);
82
83        if active >= RW_REQUIRED_BIT {
84            self.active.fetch_sub(1, Release);
85            Protector::Read(self.rw.read())
86        } else {
87            Protector::None(&self.active)
88        }
89    }
90
91    fn write(&self) -> Protector<'_> {
92        #[cfg(feature = "testing")]
93        COUNT.with(|c| {
94            let mut c = c.borrow_mut();
95            *c += 1;
96            assert_eq!(*c, 1);
97        });
98        self.enable();
99        while !self.upgrade_complete.load(Acquire) {
100            std::sync::atomic::spin_loop_hint()
101        }
102        Protector::Write(self.rw.write())
103    }
104}