p2p_chat/storage/
history.rs

//! This module defines the storage interface and implementation for managing
//! the message history.
3use crate::crypto::StorageEncryption;
4use crate::types::Message;
5use anyhow::Result;
6use async_trait::async_trait;
7use libp2p::PeerId;
8use sled::Db;
9
10/// A trait for storing and retrieving messages.
11#[async_trait]
12pub trait MessageStore {
13    /// Stores a message in the history.
14    ///
15    /// # Arguments
16    ///
17    /// * `msg` - The `Message` to store.
18    ///
19    /// # Errors
20    ///
21    /// This function will return an error if the message cannot be stored.
22    async fn store_message(&self, msg: Message) -> Result<()>;
23
24    /// Retrieves a message by its ID.
25    ///
26    /// # Arguments
27    ///
28    /// * `msg_id` - The `Uuid` of the message to retrieve.
29    ///
30    /// # Returns
31    ///
32    /// An `Option` containing the `Message` if found, otherwise `None`.
33    ///
34    /// # Errors
35    ///
36    /// This function will return an error if the message cannot be retrieved.
37    async fn get_message_by_id(&self, msg_id: &uuid::Uuid) -> Result<Option<Message>>;
38
39    /// Retrieves the message history for a conversation.
40    ///
41    /// Messages are returned in chronological order.
42    ///
43    /// # Arguments
44    ///
45    /// * `own_id` - The `PeerId` of the local user.
46    /// * `peer` - The `PeerId` of the other participant in the conversation.
47    /// * `limit` - The maximum number of messages to retrieve.
48    ///
49    /// # Returns
50    ///
51    /// A `Vec` of `Message`s representing the conversation history.
52    ///
53    /// # Errors
54    ///
55    /// This function will return an error if the history cannot be retrieved.
56    async fn get_history(
57        &self,
58        own_id: &PeerId,
59        peer: &PeerId,
60        limit: usize,
61    ) -> Result<Vec<Message>>;
62
63    /// Retrieves a limited number of the most recent messages from all conversations.
64    ///
65    /// Messages are returned in chronological order, up to the specified limit.
66    ///
67    /// # Arguments
68    ///
69    /// * `own_id` - The `PeerId` of the local user (used for filtering relevant messages).
70    /// * `limit` - The maximum number of recent messages to retrieve.
71    ///
72    /// # Returns
73    ///
74    /// A `Vec` of `Message`s, sorted chronologically.
75    ///
76    /// # Errors
77    ///
78    /// This function will return an error if the messages cannot be retrieved.
79    async fn get_recent_messages(&self, own_id: &PeerId, limit: usize) -> Result<Vec<Message>>;
80
81    /// Retrieves messages before a specific message in a conversation.
82    ///
83    /// Messages are returned in chronological order.
84    ///
85    /// # Arguments
86    ///
87    /// * `own_id` - The `PeerId` of the local user.
88    /// * `peer` - The `PeerId` of the other participant in the conversation.
89    /// * `before_id` - The `Uuid` of the message to retrieve messages before.
90    /// * `limit` - The maximum number of messages to retrieve.
91    ///
92    /// # Returns
93    ///
94    /// A `Vec` of `Message`s.
95    ///
96    /// # Errors
97    ///
98    /// This function will return an error if the messages cannot be retrieved.
99    async fn get_messages_before(
100        &self,
101        own_id: &PeerId,
102        peer: &PeerId,
103        before_id: &uuid::Uuid,
104        limit: usize,
105    ) -> Result<Vec<Message>>;
106
107    /// Retrieves messages after a specific message in a conversation.
108    ///
109    /// Messages are returned in chronological order.
110    ///
111    /// # Arguments
112    ///
113    /// * `own_id` - The `PeerId` of the local user.
114    /// * `peer` - The `PeerId` of the other participant in the conversation.
115    /// * `after_id` - The `Uuid` of the message to retrieve messages after.
116    /// * `limit` - The maximum number of messages to retrieve.
117    ///
118    /// # Returns
119    ///
120    /// A `Vec` of `Message`s.
121    ///
122    /// # Errors
123    ///
124    /// This function will return an error if the messages cannot be retrieved.
125    async fn get_messages_after(
126        &self,
127        own_id: &PeerId,
128        peer: &PeerId,
129        after_id: &uuid::Uuid,
130        limit: usize,
131    ) -> Result<Vec<Message>>;
132
133    /// Updates the delivery status of a message.
134    ///
135    /// # Arguments
136    ///
137    /// * `msg_id` - The `Uuid` of the message to update.
138    /// * `status` - The new `DeliveryStatus` for the message.
139    ///
140    /// # Errors
141    ///
142    /// This function will return an error if the status cannot be updated.
143    async fn update_delivery_status(
144        &self,
145        msg_id: &uuid::Uuid,
146        status: crate::types::DeliveryStatus,
147    ) -> Result<()>;
148}
149
150/// A `MessageStore` implementation using `sled` for storage.
151pub struct MessageHistory {
152    tree: sled::Tree,
153    encryption: Option<StorageEncryption>,
154}
155
156impl MessageHistory {
157    /// Creates a new `MessageHistory` store.
158    ///
159    /// # Arguments
160    ///
161    /// * `db` - The `sled::Db` instance to use for storage.
162    /// * `encryption` - Optional `StorageEncryption` for encrypting message data.
163    ///
164    /// # Errors
165    ///
166    /// Returns an error if the underlying `sled` tree cannot be opened.
167    pub fn new(db: Db, encryption: Option<StorageEncryption>) -> Result<Self> {
168        let tree = db.open_tree("history")?;
169        Ok(Self { tree, encryption })
170    }
171
172    /// Creates a canonical, ordered conversation ID from two `PeerId`s.
173    ///
174    /// This ensures that the conversation ID is always the same regardless of
175    /// the order of the `PeerId`s.
176    fn get_conversation_id(p1: &PeerId, p2: &PeerId) -> Vec<u8> {
177        let mut p1_bytes = p1.to_bytes();
178        let mut p2_bytes = p2.to_bytes();
179
180        if p1_bytes > p2_bytes {
181            std::mem::swap(&mut p1_bytes, &mut p2_bytes);
182        }
183
184        [p1_bytes, p2_bytes].concat()
185    }
186
187    /// Creates a composite key for storing a message, based on conversation ID, timestamp, and nonce.
188    fn make_composite_key(conversation_id: &[u8], timestamp: i64, nonce: u64) -> Vec<u8> {
189        let mut key = Vec::new();
190        key.extend_from_slice(conversation_id);
191        key.extend_from_slice(&timestamp.to_be_bytes());
192        key.extend_from_slice(&nonce.to_be_bytes());
193        key
194    }
195
196    /// Serializes a `Message` and encrypts it if encryption is enabled.
197    fn serialize_message(&self, msg: &Message) -> Result<Vec<u8>> {
198        let serialized = serde_json::to_vec(msg)?;
199
200        if let Some(ref encryption) = self.encryption {
201            encryption.encrypt_value(&serialized)
202        } else {
203            Ok(serialized)
204        }
205    }
206
207    /// Decrypts and deserializes a `Message`.
208    fn deserialize_message(&self, data: &[u8]) -> Result<Message> {
209        let decrypted = if let Some(ref encryption) = self.encryption {
210            encryption.decrypt_value(data)?
211        } else {
212            data.to_vec()
213        };
214
215        Ok(serde_json::from_slice(&decrypted)?)
216    }
217}
218
219#[async_trait]
220impl MessageStore for MessageHistory {
221    async fn store_message(&self, msg: Message) -> Result<()> {
222        let conversation_id = Self::get_conversation_id(&msg.sender, &msg.recipient);
223        let key = Self::make_composite_key(&conversation_id, msg.timestamp, msg.nonce);
224        let value = self.serialize_message(&msg)?;
225
226        self.tree.insert(key, value)?;
227        self.tree.flush_async().await?;
228        Ok(())
229    }
230
231    async fn get_message_by_id(&self, msg_id: &uuid::Uuid) -> Result<Option<Message>> {
232        // Scan all messages to find the one with the given ID.
233        for result in self.tree.iter() {
234            let (_key, value) = result?;
235            let msg = self.deserialize_message(&value)?;
236
237            if msg.id == *msg_id {
238                return Ok(Some(msg));
239            }
240        }
241
242        Ok(None)
243    }
244
245    async fn get_history(
246        &self,
247        own_id: &PeerId,
248        peer: &PeerId,
249        limit: usize,
250    ) -> Result<Vec<Message>> {
251        let conversation_id = Self::get_conversation_id(own_id, peer);
252        let mut messages = Vec::new();
253
254        // Iterate in reverse to get most recent messages first.
255        for result in self.tree.scan_prefix(&conversation_id).rev().take(limit) {
256            let (_key, value) = result?;
257            messages.push(self.deserialize_message(&value)?);
258        }
259
260        // Reverse again to get chronological order.
261        messages.reverse();
262        Ok(messages)
263    }
264
265    async fn get_recent_messages(&self, own_id: &PeerId, limit: usize) -> Result<Vec<Message>> {
266        let tree = self.tree.clone();
267        let encryption = self.encryption.clone();
268        let own_id = *own_id;
269
270        let mut messages: Vec<Message> =
271            tokio::task::spawn_blocking(move || -> Result<Vec<Message>> {
272                let mut collected = Vec::new();
273                for result in tree.iter() {
274                    let (_key, value) = result?;
275                    let decrypted = if let Some(ref enc) = encryption {
276                        enc.decrypt_value(&value)?
277                    } else {
278                        value.to_vec()
279                    };
280                    let message: Message = serde_json::from_slice(&decrypted)?;
281                    if message.sender == own_id || message.recipient == own_id {
282                        collected.push(message);
283                    }
284                }
285                Ok(collected)
286            })
287            .await??;
288
289        messages.sort_by_key(|msg| (msg.timestamp, msg.nonce));
290
291        if messages.len() > limit {
292            let drop_count = messages.len() - limit;
293            messages.drain(0..drop_count);
294        }
295
296        Ok(messages)
297    }
298
299    async fn get_messages_before(
300        &self,
301        own_id: &PeerId,
302        peer: &PeerId,
303        before_id: &uuid::Uuid,
304        limit: usize,
305    ) -> Result<Vec<Message>> {
306        let conversation_id = Self::get_conversation_id(own_id, peer);
307
308        // First, find the message with before_id to get its timestamp.
309        let mut before_timestamp = None;
310        let mut before_nonce = None;
311
312        for result in self.tree.scan_prefix(&conversation_id) {
313            let (_key, value) = result?;
314            let msg = self.deserialize_message(&value)?;
315            if msg.id == *before_id {
316                before_timestamp = Some(msg.timestamp);
317                before_nonce = Some(msg.nonce);
318                break;
319            }
320        }
321
322        let (before_ts, before_n) = match (before_timestamp, before_nonce) {
323            (Some(ts), Some(n)) => (ts, n),
324            _ => return Ok(Vec::new()), // Message not found.
325        };
326
327        // Collect all messages before this timestamp+nonce.
328        let mut messages = Vec::new();
329        for result in self.tree.scan_prefix(&conversation_id) {
330            let (_key, value) = result?;
331            let msg = self.deserialize_message(&value)?;
332
333            // Include messages that are strictly before (timestamp, nonce).
334            if msg.timestamp < before_ts || (msg.timestamp == before_ts && msg.nonce < before_n) {
335                messages.push(msg);
336            }
337        }
338
339        // Sort by timestamp and nonce, take last N (most recent before the target).
340        messages.sort_by_key(|msg| (msg.timestamp, msg.nonce));
341        if messages.len() > limit {
342            let start_idx = messages.len() - limit;
343            messages.drain(0..start_idx);
344        }
345
346        Ok(messages)
347    }
348
349    async fn get_messages_after(
350        &self,
351        own_id: &PeerId,
352        peer: &PeerId,
353        after_id: &uuid::Uuid,
354        limit: usize,
355    ) -> Result<Vec<Message>> {
356        let conversation_id = Self::get_conversation_id(own_id, peer);
357
358        // First, find the message with after_id to get its timestamp.
359        let mut after_timestamp = None;
360        let mut after_nonce = None;
361
362        for result in self.tree.scan_prefix(&conversation_id) {
363            let (_key, value) = result?;
364            let msg = self.deserialize_message(&value)?;
365            if msg.id == *after_id {
366                after_timestamp = Some(msg.timestamp);
367                after_nonce = Some(msg.nonce);
368                break;
369            }
370        }
371
372        let (after_ts, after_n) = match (after_timestamp, after_nonce) {
373            (Some(ts), Some(n)) => (ts, n),
374            _ => return Ok(Vec::new()), // Message not found.
375        };
376
377        // Collect messages after this timestamp+nonce.
378        let mut messages = Vec::new();
379        for result in self.tree.scan_prefix(&conversation_id) {
380            let (_key, value) = result?;
381            let msg = self.deserialize_message(&value)?;
382
383            // Include messages that are strictly after (timestamp, nonce).
384            if msg.timestamp > after_ts || (msg.timestamp == after_ts && msg.nonce > after_n) {
385                messages.push(msg);
386                if messages.len() >= limit {
387                    break;
388                }
389            }
390        }
391
392        // Sort by timestamp and nonce.
393        messages.sort_by_key(|msg| (msg.timestamp, msg.nonce));
394
395        Ok(messages)
396    }
397
398    async fn update_delivery_status(
399        &self,
400        msg_id: &uuid::Uuid,
401        status: crate::types::DeliveryStatus,
402    ) -> Result<()> {
403        // Scan all messages to find the one with the given ID.
404        for result in self.tree.iter() {
405            let (key, value) = result?;
406            let mut msg = self.deserialize_message(&value)?;
407
408            if msg.id == *msg_id {
409                // Update the delivery status.
410                msg.delivery_status = status;
411
412                // Re-serialize and store.
413                let new_value = self.serialize_message(&msg)?;
414                self.tree.insert(key, new_value)?;
415                self.tree.flush_async().await?;
416                return Ok(());
417            }
418        }
419
420        // Message not found - not necessarily an error, might be old/deleted.
421        Ok(())
422    }
423}