p2p_chat/web/websocket.rs
1//! This module handles WebSocket connections for the web UI.
2use axum::{
3 extract::{ws::WebSocket, State, WebSocketUpgrade},
4 response::Response,
5};
6use futures::{sink::SinkExt, stream::StreamExt};
7use serde::Serialize;
8use std::sync::Arc;
9use tokio::sync::broadcast;
10use tracing::{debug, error};
11
12/// Represents messages that can be sent over the WebSocket to the web UI.
13#[derive(Serialize, Clone)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum WebSocketMessage {
16 /// A new chat message has been received or sent.
17 NewMessage {
18 id: String,
19 sender: String,
20 recipient: String,
21 content: String,
22 timestamp: i64,
23 nonce: u64,
24 delivery_status: String,
25 },
26 /// A peer has connected to the network.
27 PeerConnected {
28 peer_id: String,
29 },
30 /// A peer has disconnected from the network.
31 PeerDisconnected {
32 peer_id: String,
33 },
34 /// The delivery status of a message has been updated.
35 DeliveryStatusUpdate {
36 message_id: String,
37 new_status: String,
38 },
39}
40
41/// The state shared across WebSocket connections.
42pub struct WebSocketState {
43 /// A broadcast sender for distributing messages to all connected WebSocket clients.
44 pub broadcast_tx: broadcast::Sender<WebSocketMessage>,
45}
46
47/// Handles the WebSocket upgrade request.
48///
49/// This function is an Axum handler that takes a `WebSocketUpgrade` and
50/// a `WebSocketState`, then upgrades the connection to a WebSocket and
51/// spawns a task to handle the socket.
52///
53/// # Arguments
54///
55/// * `ws` - The `WebSocketUpgrade` extractor.
56/// * `State(state)` - The shared `WebSocketState`.
57///
58/// # Returns
59///
60/// An Axum `Response`.
61pub async fn ws_handler(
62 ws: WebSocketUpgrade,
63 State(state): State<Arc<WebSocketState>>,
64) -> Response {
65 ws.on_upgrade(|socket| handle_socket(socket, state))
66}
67
68/// Handles a single WebSocket connection.
69///
70/// This asynchronous function manages sending messages from a broadcast channel
71/// to the client and handles incoming messages from the client (e.g., pings, close).
72///
73/// # Arguments
74///
75/// * `socket` - The established `WebSocket` connection.
76/// * `state` - The shared `WebSocketState`.
77async fn handle_socket(socket: WebSocket, state: Arc<WebSocketState>) {
78 let (sender, mut receiver) = socket.split();
79
80 // Subscribe to broadcast channel.
81 let mut broadcast_rx = state.broadcast_tx.subscribe();
82
83 // Task for sending messages from broadcast channel to WebSocket client.
84 let send_task = tokio::spawn(async move {
85 let mut sender = sender;
86 loop {
87 match broadcast_rx.recv().await {
88 Ok(msg) => {
89 let json = match serde_json::to_string(&msg) {
90 Ok(j) => j,
91 Err(e) => {
92 error!("Failed to serialize WebSocket message: {}", e);
93 continue;
94 }
95 };
96
97 if sender
98 .send(axum::extract::ws::Message::Text(json))
99 .await
100 .is_err()
101 {
102 break;
103 }
104 }
105 Err(broadcast::error::RecvError::Lagged(n)) => {
106 error!("WebSocket lagged by {} messages", n);
107 continue;
108 }
109 Err(broadcast::error::RecvError::Closed) => {
110 break;
111 }
112 }
113 }
114 });
115
116 // Handle incoming WebSocket messages (ping/pong, close).
117 while let Some(msg) = receiver.next().await {
118 match msg {
119 Ok(axum::extract::ws::Message::Close(_)) => {
120 debug!("WebSocket client disconnected");
121 break;
122 }
123 Ok(axum::extract::ws::Message::Ping(_)) => {
124 // Currently, pings are acknowledged implicitly by tokio-tungstenite.
125 // Explicit Pong response is not needed here.
126 }
127 Err(e) => {
128 error!("WebSocket error: {}", e);
129 break;
130 }
131 _ => {}
132 }
133 }
134
135 // Abort the send task when the receive loop ends (client disconnected or error).
136 send_task.abort();
137}