// File: p2p_chat/web/websocket.rs

//! This module handles WebSocket connections for the web UI.
2use axum::{
3    extract::{ws::WebSocket, State, WebSocketUpgrade},
4    response::Response,
5};
6use futures::{sink::SinkExt, stream::StreamExt};
7use serde::Serialize;
8use std::sync::Arc;
9use tokio::sync::broadcast;
10use tracing::{debug, error};
11
12/// Represents messages that can be sent over the WebSocket to the web UI.
13#[derive(Serialize, Clone)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum WebSocketMessage {
16    /// A new chat message has been received or sent.
17    NewMessage {
18        id: String,
19        sender: String,
20        recipient: String,
21        content: String,
22        timestamp: i64,
23        nonce: u64,
24        delivery_status: String,
25    },
26    /// A peer has connected to the network.
27    PeerConnected {
28        peer_id: String,
29    },
30    /// A peer has disconnected from the network.
31    PeerDisconnected {
32        peer_id: String,
33    },
34    /// The delivery status of a message has been updated.
35    DeliveryStatusUpdate {
36        message_id: String,
37        new_status: String,
38    },
39}
40
41/// The state shared across WebSocket connections.
42pub struct WebSocketState {
43    /// A broadcast sender for distributing messages to all connected WebSocket clients.
44    pub broadcast_tx: broadcast::Sender<WebSocketMessage>,
45}
46
47/// Handles the WebSocket upgrade request.
48///
49/// This function is an Axum handler that takes a `WebSocketUpgrade` and
50/// a `WebSocketState`, then upgrades the connection to a WebSocket and
51/// spawns a task to handle the socket.
52///
53/// # Arguments
54///
55/// * `ws` - The `WebSocketUpgrade` extractor.
56/// * `State(state)` - The shared `WebSocketState`.
57///
58/// # Returns
59///
60/// An Axum `Response`.
61pub async fn ws_handler(
62    ws: WebSocketUpgrade,
63    State(state): State<Arc<WebSocketState>>,
64) -> Response {
65    ws.on_upgrade(|socket| handle_socket(socket, state))
66}
67
68/// Handles a single WebSocket connection.
69///
70/// This asynchronous function manages sending messages from a broadcast channel
71/// to the client and handles incoming messages from the client (e.g., pings, close).
72///
73/// # Arguments
74///
75/// * `socket` - The established `WebSocket` connection.
76/// * `state` - The shared `WebSocketState`.
77async fn handle_socket(socket: WebSocket, state: Arc<WebSocketState>) {
78    let (sender, mut receiver) = socket.split();
79
80    // Subscribe to broadcast channel.
81    let mut broadcast_rx = state.broadcast_tx.subscribe();
82
83    // Task for sending messages from broadcast channel to WebSocket client.
84    let send_task = tokio::spawn(async move {
85        let mut sender = sender;
86        loop {
87            match broadcast_rx.recv().await {
88                Ok(msg) => {
89                    let json = match serde_json::to_string(&msg) {
90                        Ok(j) => j,
91                        Err(e) => {
92                            error!("Failed to serialize WebSocket message: {}", e);
93                            continue;
94                        }
95                    };
96
97                    if sender
98                        .send(axum::extract::ws::Message::Text(json))
99                        .await
100                        .is_err()
101                    {
102                        break;
103                    }
104                }
105                Err(broadcast::error::RecvError::Lagged(n)) => {
106                    error!("WebSocket lagged by {} messages", n);
107                    continue;
108                }
109                Err(broadcast::error::RecvError::Closed) => {
110                    break;
111                }
112            }
113        }
114    });
115
116    // Handle incoming WebSocket messages (ping/pong, close).
117    while let Some(msg) = receiver.next().await {
118        match msg {
119            Ok(axum::extract::ws::Message::Close(_)) => {
120                debug!("WebSocket client disconnected");
121                break;
122            }
123            Ok(axum::extract::ws::Message::Ping(_)) => {
124                // Currently, pings are acknowledged implicitly by tokio-tungstenite.
125                // Explicit Pong response is not needed here.
126            }
127            Err(e) => {
128                error!("WebSocket error: {}", e);
129                break;
130            }
131            _ => {}
132        }
133    }
134
135    // Abort the send task when the receive loop ends (client disconnected or error).
136    send_task.abort();
137}