sled/
stack.rs

1#![allow(unsafe_code)]
2
3use std::{
4    fmt::{self, Debug},
5    ops::Deref,
6    sync::atomic::Ordering::{Acquire, Release},
7};
8
9use crossbeam_epoch::{unprotected, Atomic, Guard, Owned, Shared};
10
11use crate::debug_delay;
12
13/// A node in the lock-free `Stack`.
14#[derive(Debug)]
15pub struct Node<T: Send + 'static> {
16    pub(crate) inner: T,
17    pub(crate) next: Atomic<Node<T>>,
18}
19
20impl<T: Send + 'static> Drop for Node<T> {
21    fn drop(&mut self) {
22        unsafe {
23            let mut cursor = self.next.load(Acquire, unprotected());
24
25            while !cursor.is_null() {
26                // we carefully unset the next pointer here to avoid
27                // a stack overflow when freeing long lists.
28                let node = cursor.into_owned();
29                cursor = node.next.swap(Shared::null(), Acquire, unprotected());
30                drop(node);
31            }
32        }
33    }
34}
35
36/// A simple lock-free stack, with the ability to atomically
37/// append or entirely swap-out entries.
38pub struct Stack<T: Send + 'static> {
39    head: Atomic<Node<T>>,
40}
41
42impl<T: Send + 'static> Default for Stack<T> {
43    fn default() -> Self {
44        Self { head: Atomic::null() }
45    }
46}
47
48impl<T: Send + 'static> Drop for Stack<T> {
49    fn drop(&mut self) {
50        unsafe {
51            let curr = self.head.load(Acquire, unprotected());
52            if !curr.as_raw().is_null() {
53                drop(curr.into_owned());
54            }
55        }
56    }
57}
58
59impl<T> Debug for Stack<T>
60where
61    T: Clone + Debug + Send + 'static + Sync,
62{
63    fn fmt(
64        &self,
65        formatter: &mut fmt::Formatter<'_>,
66    ) -> Result<(), fmt::Error> {
67        let guard = crossbeam_epoch::pin();
68        let head = self.head(&guard);
69        let iter = Iter::from_ptr(head, &guard);
70
71        formatter.write_str("Stack [")?;
72        let mut written = false;
73        for node in iter {
74            if written {
75                formatter.write_str(", ")?;
76            }
77            formatter.write_str(&*format!("({:?}) ", &node))?;
78            node.fmt(formatter)?;
79            written = true;
80        }
81        formatter.write_str("]")?;
82        Ok(())
83    }
84}
85
86impl<T: Send + 'static> Deref for Node<T> {
87    type Target = T;
88    fn deref(&self) -> &T {
89        &self.inner
90    }
91}
92
93impl<T: Send + Sync + 'static> Stack<T> {
94    /// Add an item to the stack, spinning until successful.
95    pub(crate) fn push(&self, inner: T, guard: &Guard) {
96        debug_delay();
97        let node = Owned::new(Node { inner, next: Atomic::null() });
98
99        unsafe {
100            let node = node.into_shared(guard);
101
102            loop {
103                let head = self.head(guard);
104                node.deref().next.store(head, Release);
105                if self.head.compare_and_set(head, node, Release, guard).is_ok()
106                {
107                    return;
108                }
109            }
110        }
111    }
112
113    /// Clears the stack and returns all items
114    pub(crate) fn take_iter<'a>(
115        &self,
116        guard: &'a Guard,
117    ) -> impl Iterator<Item = &'a T> {
118        debug_delay();
119        let node = self.head.swap(Shared::null(), Release, guard);
120
121        let iter = Iter { inner: node, guard };
122
123        if !node.is_null() {
124            unsafe {
125                guard.defer_destroy(node);
126            }
127        }
128
129        iter
130    }
131
132    /// Pop the next item off the stack. Returns None if nothing is there.
133    #[cfg(any(test, feature = "event_log"))]
134    pub(crate) fn pop(&self, guard: &Guard) -> Option<T> {
135        use std::ptr;
136        use std::sync::atomic::Ordering::SeqCst;
137        debug_delay();
138        let mut head = self.head(guard);
139        loop {
140            match unsafe { head.as_ref() } {
141                Some(h) => {
142                    let next = h.next.load(Acquire, guard);
143                    match self.head.compare_and_set(head, next, Release, guard)
144                    {
145                        Ok(_) => unsafe {
146                            // we unset the next pointer before destruction
147                            // to avoid double-frees.
148                            h.next.store(Shared::default(), SeqCst);
149                            guard.defer_destroy(head);
150                            return Some(ptr::read(&h.inner));
151                        },
152                        Err(h) => head = h.current,
153                    }
154                }
155                None => return None,
156            }
157        }
158    }
159
160    /// Returns the current head pointer of the stack, which can
161    /// later be used as the key for cas and cap operations.
162    pub(crate) fn head<'g>(&self, guard: &'g Guard) -> Shared<'g, Node<T>> {
163        self.head.load(Acquire, guard)
164    }
165}
166
167/// An iterator over nodes in a lock-free stack.
168pub struct Iter<'a, T>
169where
170    T: Send + 'static + Sync,
171{
172    inner: Shared<'a, Node<T>>,
173    guard: &'a Guard,
174}
175
176impl<'a, T> Iter<'a, T>
177where
178    T: 'a + Send + 'static + Sync,
179{
180    /// Creates a `Iter` from a pointer to one.
181    pub(crate) fn from_ptr<'b>(
182        ptr: Shared<'b, Node<T>>,
183        guard: &'b Guard,
184    ) -> Iter<'b, T> {
185        Iter { inner: ptr, guard }
186    }
187}
188
189impl<'a, T> Iterator for Iter<'a, T>
190where
191    T: Send + 'static + Sync,
192{
193    type Item = &'a T;
194
195    fn next(&mut self) -> Option<Self::Item> {
196        debug_delay();
197        if self.inner.is_null() {
198            None
199        } else {
200            unsafe {
201                let ret = &self.inner.deref().inner;
202                self.inner = self.inner.deref().next.load(Acquire, self.guard);
203                Some(ret)
204            }
205        }
206    }
207
208    fn size_hint(&self) -> (usize, Option<usize>) {
209        let mut size = 0;
210        let mut cursor = self.inner;
211
212        while !cursor.is_null() {
213            unsafe {
214                cursor = cursor.deref().next.load(Acquire, self.guard);
215            }
216            size += 1;
217        }
218
219        (size, Some(size))
220    }
221}
222
#[test]
#[cfg(not(miri))] // can't create threads
fn basic_functionality() {
    use crossbeam_epoch::pin;
    use crossbeam_utils::CachePadded;
    use std::sync::Arc;
    use std::thread;

    let guard = pin();
    let stack = Arc::new(Stack::default());

    // Popping from an empty stack yields nothing.
    assert_eq!(stack.pop(&guard), None);
    stack.push(CachePadded::new(1), &guard);

    // Push 2, 3 and 4 from another thread.
    let pusher = {
        let stack = Arc::clone(&stack);
        thread::spawn(move || {
            let guard = pin();
            for value in 2..=4 {
                stack.push(CachePadded::new(value), &guard);
            }
            guard.flush();
        })
    };
    pusher.join().unwrap();

    // Items come back in last-in, first-out order.
    stack.push(CachePadded::new(5), &guard);
    assert_eq!(stack.pop(&guard), Some(CachePadded::new(5)));
    assert_eq!(stack.pop(&guard), Some(CachePadded::new(4)));

    // Pop 3 and 2 from another thread.
    let popper = {
        let stack = Arc::clone(&stack);
        thread::spawn(move || {
            let guard = pin();
            assert_eq!(stack.pop(&guard), Some(CachePadded::new(3)));
            assert_eq!(stack.pop(&guard), Some(CachePadded::new(2)));
            guard.flush();
        })
    };
    popper.join().unwrap();

    assert_eq!(stack.pop(&guard), Some(CachePadded::new(1)));

    // The stack is empty again, as seen from a third thread.
    let checker = {
        let stack = Arc::clone(&stack);
        thread::spawn(move || {
            let guard = pin();
            assert_eq!(stack.pop(&guard), None);
            guard.flush();
        })
    };
    checker.join().unwrap();

    drop(stack);
    guard.flush();
    drop(guard);
}