// sled — lru.rs

1#![allow(unsafe_code)]
2
3use std::convert::TryFrom;
4use std::mem::MaybeUninit;
5use std::ptr;
6use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
7
8use crate::{
9    atomic_shim::AtomicU64,
10    debug_delay,
11    dll::{DoublyLinkedList, Node},
12    fastlock::FastLock,
13    Guard, PageId,
14};
15
// Sizing constants. Under `test` or the `lock_free_delays` feature the
// values are deliberately tiny so block-rollover and cross-shard paths
// are hit quickly; release builds use full-size values.

// Number of accesses buffered in one `AccessBlock` before it is sealed
// and published for consumption.
#[cfg(any(test, feature = "lock_free_delays"))]
const MAX_QUEUE_ITEMS: usize = 4;

#[cfg(not(any(test, feature = "lock_free_delays")))]
const MAX_QUEUE_ITEMS: usize = 64;

// Number of independent LRU shards; accesses are spread across shards
// to reduce contention on any single lock.
#[cfg(any(test, feature = "lock_free_delays"))]
const N_SHARDS: usize = 2;

#[cfg(not(any(test, feature = "lock_free_delays")))]
const N_SHARDS: usize = 256;
27
// A fixed-size buffer of encoded cache accesses, chained into a singly
// linked list once full.
struct AccessBlock {
    // Next free slot index; may overshoot `MAX_QUEUE_ITEMS` under
    // contention (writers detect that and install a fresh block).
    len: AtomicUsize,
    // Slots hold packed `CacheAccess` values; 0 means "not yet written".
    block: [AtomicU64; MAX_QUEUE_ITEMS],
    // Link to the previously sealed block on the full list.
    next: AtomicPtr<AccessBlock>,
}
33
impl Default for AccessBlock {
    fn default() -> AccessBlock {
        AccessBlock {
            len: AtomicUsize::new(0),
            // SAFETY: an all-zero bit pattern is a valid `AtomicU64`
            // holding 0, so a zeroed array of them is fully initialized.
            // NOTE(review): this assumes `atomic_shim::AtomicU64` is
            // zero-initializable on targets without native 64-bit
            // atomics (where the shim may not be a plain integer) —
            // confirm against the shim's definition.
            block: unsafe { MaybeUninit::zeroed().assume_init() },
            next: AtomicPtr::default(),
        }
    }
}
43
// A lock-free multi-producer access log: writers fill `writing`, and
// sealed blocks are pushed onto the `full_list` stack for a single
// consumer to drain via `take`.
struct AccessQueue {
    // The block currently accepting new accesses.
    writing: AtomicPtr<AccessBlock>,
    // Treiber-style stack of sealed (full) blocks; null when empty.
    full_list: AtomicPtr<AccessBlock>,
}
48
impl Default for AccessQueue {
    fn default() -> AccessQueue {
        AccessQueue {
            // Start with one empty, heap-allocated write block; freed in
            // `Drop`.
            writing: AtomicPtr::new(Box::into_raw(Box::new(
                AccessBlock::default(),
            ))),
            // Empty full list (null pointer).
            full_list: AtomicPtr::default(),
        }
    }
}
59
60impl AccessQueue {
61    fn push(&self, item: CacheAccess) -> bool {
62        let mut filled = false;
63        loop {
64            debug_delay();
65            let head = self.writing.load(Ordering::Acquire);
66            let block = unsafe { &*head };
67
68            debug_delay();
69            let offset = block.len.fetch_add(1, Ordering::Release);
70
71            if offset < MAX_QUEUE_ITEMS {
72                debug_delay();
73                unsafe {
74                    block
75                        .block
76                        .get_unchecked(offset)
77                        .store(item.0, Ordering::Release);
78                }
79                return filled;
80            } else {
81                // install new writer
82                let new = Box::into_raw(Box::new(AccessBlock::default()));
83                debug_delay();
84                let prev =
85                    self.writing.compare_and_swap(head, new, Ordering::Release);
86                if prev != head {
87                    // we lost the CAS, free the new item that was
88                    // never published to other threads
89                    unsafe {
90                        drop(Box::from_raw(new));
91                    }
92                    continue;
93                }
94
95                // push the now-full item to the full list for future
96                // consumption
97                let mut ret;
98                let mut full_list_ptr = self.full_list.load(Ordering::Acquire);
99                while {
100                    // we loop because maybe other threads are pushing stuff too
101                    block.next.store(full_list_ptr, Ordering::Release);
102                    debug_delay();
103                    ret = self.full_list.compare_and_swap(
104                        full_list_ptr,
105                        head,
106                        Ordering::Release,
107                    );
108                    ret != full_list_ptr
109                } {
110                    full_list_ptr = ret;
111                }
112                filled = true;
113            }
114        }
115    }
116
117    fn take<'a>(&self, guard: &'a Guard) -> CacheAccessIter<'a> {
118        debug_delay();
119        let ptr = self.full_list.swap(std::ptr::null_mut(), Ordering::AcqRel);
120
121        CacheAccessIter { guard, current_offset: 0, current_block: ptr }
122    }
123}
124
125impl Drop for AccessQueue {
126    fn drop(&mut self) {
127        debug_delay();
128        let writing = self.writing.load(Ordering::Acquire);
129        unsafe {
130            Box::from_raw(writing);
131        }
132        debug_delay();
133        let mut head = self.full_list.load(Ordering::Acquire);
134        while !head.is_null() {
135            unsafe {
136                debug_delay();
137                let next =
138                    (*head).next.swap(std::ptr::null_mut(), Ordering::Release);
139                Box::from_raw(head);
140                head = next;
141            }
142        }
143    }
144}
145
// Iterator over the accesses buffered in a detached chain of
// `AccessBlock`s; frees each block (deferred through the epoch guard)
// once it has been fully consumed.
struct CacheAccessIter<'a> {
    // Epoch guard used to defer destruction of exhausted blocks.
    guard: &'a Guard,
    // Next slot to read within `current_block`.
    current_offset: usize,
    // Block currently being drained; null when the chain is exhausted.
    current_block: *mut AccessBlock,
}
151
impl<'a> Iterator for CacheAccessIter<'a> {
    type Item = CacheAccess;

    // Yields the next buffered access, advancing through the block
    // chain. Exhausted blocks are handed to `guard.defer` rather than
    // dropped immediately — NOTE(review): presumably because racing
    // producers that overshot `len` may still hold references into the
    // block; confirm against `AccessQueue::push`.
    fn next(&mut self) -> Option<CacheAccess> {
        while !self.current_block.is_null() {
            let current_block = unsafe { &*self.current_block };

            debug_delay();
            if self.current_offset >= MAX_QUEUE_ITEMS {
                // This block is spent: move to the next one and defer
                // freeing the current one via the epoch guard.
                let to_drop_ptr = self.current_block;
                debug_delay();
                self.current_block = current_block.next.load(Ordering::Acquire);
                self.current_offset = 0;
                debug_delay();
                let to_drop = unsafe { Box::from_raw(to_drop_ptr) };
                self.guard.defer(|| to_drop);
                continue;
            }

            let mut next = 0;
            while next == 0 {
                // we spin here because there's a race between bumping
                // the offset and setting the value to something other
                // than 0 (and 0 is an invalid value)
                debug_delay();
                next = current_block.block[self.current_offset]
                    .load(Ordering::Acquire);
            }
            self.current_offset += 1;
            return Some(CacheAccess(next));
        }

        None
    }
}
187
// A page access packed into one `u64`: the high 8 bits carry the
// size's power-of-two exponent, the low 56 bits carry the page id
// (see `CacheAccess::new` / `decompose`).
#[derive(Clone, Copy)]
struct CacheAccess(u64);
190
191impl CacheAccess {
192    fn new(pid: PageId, sz: u64) -> CacheAccess {
193        let rounded_up_power_of_2 =
194            u64::from(sz.next_power_of_two().trailing_zeros());
195
196        assert!(rounded_up_power_of_2 < 256);
197
198        CacheAccess(pid | (rounded_up_power_of_2 << 56))
199    }
200
201    const fn decompose(self) -> (PageId, u64) {
202        let sz = 1 << (self.0 >> 56);
203        let pid = self.0 << 8 >> 8;
204        (pid, sz)
205    }
206}
207
/// A simple LRU cache.
pub struct Lru {
    // One flat-combining access queue plus one lock-protected LRU
    // shard per slot; items map to shards by `id % N_SHARDS`.
    shards: Vec<(AccessQueue, FastLock<Shard>)>,
}

// SAFETY: `Shard` holds raw `*mut Node` pointers (inside `Entry`),
// which suppresses the auto `Sync` impl. Every `Shard` is only
// reached through its `FastLock`, and the `AccessQueue` side is built
// from atomics. NOTE(review): soundness relies on `FastLock` actually
// providing mutual exclusion — confirm its implementation.
unsafe impl Sync for Lru {}
214
215impl Lru {
216    /// Instantiates a new `Lru` cache.
217    pub(crate) fn new(cache_capacity: u64) -> Self {
218        assert!(
219            cache_capacity >= 256,
220            "Please configure the cache \
221             capacity to be at least 256 bytes"
222        );
223        let shard_capacity = cache_capacity / N_SHARDS as u64;
224
225        let mut shards = Vec::with_capacity(N_SHARDS);
226        shards.resize_with(N_SHARDS, || {
227            (AccessQueue::default(), FastLock::new(Shard::new(shard_capacity)))
228        });
229
230        Self { shards }
231    }
232
233    /// Called when an item is accessed. Returns a Vec of items to be
234    /// evicted. Uses flat-combining to avoid blocking on what can
235    /// be an asynchronous operation.
236    ///
237    /// layout:
238    ///   items:   1 2 3 4 5 6 7 8 9 10
239    ///   shards:  1 0 1 0 1 0 1 0 1 0
240    ///   shard 0:   2   4   6   8   10
241    ///   shard 1: 1   3   5   7   9
242    pub(crate) fn accessed(
243        &self,
244        id: PageId,
245        item_size: u64,
246        guard: &Guard,
247    ) -> Vec<PageId> {
248        let mut ret = vec![];
249        let shards = self.shards.len() as u64;
250        let (shard_idx, item_pos) = (id % shards, id / shards);
251        let (stack, shard_mu) = &self.shards[safe_usize(shard_idx)];
252
253        let filled = stack.push(CacheAccess::new(item_pos, item_size));
254
255        if filled {
256            // only try to acquire this if
257            if let Some(mut shard) = shard_mu.try_lock() {
258                let accesses = stack.take(guard);
259                for item in accesses {
260                    let (item_pos, item_size) = item.decompose();
261                    let to_evict =
262                        shard.accessed(safe_usize(item_pos), item_size);
263                    // map shard internal offsets to global items ids
264                    for pos in to_evict {
265                        let item = (pos * shards) + shard_idx;
266                        ret.push(item);
267                    }
268                }
269            }
270        }
271        ret
272    }
273}
274
// Per-position bookkeeping within a `Shard`: a pointer to the item's
// node in the recency list, and the size last recorded for it.
#[derive(Clone)]
struct Entry {
    // Node in the shard's `DoublyLinkedList`; null when the item is
    // not currently resident in the list.
    ptr: *mut Node,
    size: u64,
}
280
// Manual impl because raw pointers do not implement `Default`.
impl Default for Entry {
    fn default() -> Self {
        Self { ptr: ptr::null_mut(), size: 0 }
    }
}
286
// One independent LRU: a recency-ordered list of positions plus a
// dense position → `Entry` table, bounded by `capacity` bytes.
struct Shard {
    // Most-recently-used at the head, eviction candidates at the tail.
    list: DoublyLinkedList,
    // Indexed by shard-local position; grown on demand.
    entries: Vec<Entry>,
    // Byte budget for this shard.
    capacity: u64,
    // Current total of all entry sizes.
    size: u64,
}
293
294impl Shard {
295    fn new(capacity: u64) -> Self {
296        assert!(capacity > 0, "shard capacity must be non-zero");
297
298        Self {
299            list: DoublyLinkedList::default(),
300            entries: vec![],
301            capacity,
302            size: 0,
303        }
304    }
305
306    /// `PageId`s in the shard list are indexes of the entries.
307    fn accessed(&mut self, pos: usize, size: u64) -> Vec<PageId> {
308        if pos >= self.entries.len() {
309            self.entries.resize(pos + 1, Entry::default());
310        }
311
312        {
313            let entry = &mut self.entries[pos];
314
315            self.size -= entry.size;
316            entry.size = size;
317            self.size += size;
318
319            if entry.ptr.is_null() {
320                entry.ptr = self.list.push_head(PageId::try_from(pos).unwrap());
321            } else {
322                entry.ptr = self.list.promote(entry.ptr);
323            }
324        }
325
326        let mut to_evict = vec![];
327        while self.size > self.capacity {
328            if self.list.len() == 1 {
329                // don't evict what we just added
330                break;
331            }
332
333            let min_pid = self.list.pop_tail().unwrap();
334            let min_pid_idx = safe_usize(min_pid);
335
336            self.entries[min_pid_idx].ptr = ptr::null_mut();
337
338            to_evict.push(min_pid);
339
340            self.size -= self.entries[min_pid_idx].size;
341            self.entries[min_pid_idx].size = 0;
342        }
343
344        to_evict
345    }
346}
347
348#[inline]
349fn safe_usize(value: PageId) -> usize {
350    usize::try_from(value).unwrap()
351}
352
#[test]
fn lru_smoke_test() {
    use crate::pin;

    // Touch many distinct ids against a minimally sized cache to
    // exercise shard selection, queue sealing, and eviction together.
    let lru = Lru::new(256);
    for pid in 0..1000 {
        let guard = pin();
        lru.accessed(pid, 16, &guard);
    }
}
362}