netlink_proto/protocol/
protocol.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    collections::{hash_map, HashMap, VecDeque},
5    fmt::Debug,
6};
7
8use netlink_packet_core::{
9    constants::*, NetlinkDeserializable, NetlinkMessage, NetlinkPayload,
10    NetlinkSerializable,
11};
12
13use super::Request;
14use crate::sys::SocketAddr;
15
/// Key used to match incoming messages to in-flight requests.
///
/// It pairs the sequence number stamped on a request with the port number
/// of the peer involved: the destination port when the request is sent,
/// the source port when a message is received.
#[derive(Debug, PartialEq, Eq, Hash)]
struct RequestId {
    sequence_number: u32,
    port: u32,
}

impl RequestId {
    /// Builds a key from a sequence number and a peer port number.
    fn new(sequence_number: u32, port: u32) -> Self {
        RequestId {
            sequence_number,
            port,
        }
    }
}
30
31#[derive(Debug, Eq, PartialEq)]
32pub(crate) struct Response<T, M> {
33    pub done: bool,
34    pub message: NetlinkMessage<T>,
35    pub metadata: M,
36}
37
/// Bookkeeping kept for a request that is still awaiting replies.
#[derive(Debug)]
struct PendingRequest<M> {
    /// Whether the request was sent with the `NLM_F_ACK` flag
    expecting_ack: bool,
    /// Caller-supplied metadata, copied into every response
    metadata: M,
}
43
44#[derive(Debug, Default)]
45pub(crate) struct Protocol<T, M> {
46    /// Counter that is incremented for each message sent
47    sequence_id: u32,
48
49    /// Requests for which we're awaiting a response. Metadata are
50    /// associated with each request.
51    pending_requests: HashMap<RequestId, PendingRequest<M>>,
52
53    /// Responses to pending requests
54    pub incoming_responses: VecDeque<Response<T, M>>,
55
56    /// Requests from remote peers
57    pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
58
59    /// The messages to be sent out
60    pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
61}
62
63impl<T, M> Protocol<T, M>
64where
65    T: Debug + NetlinkSerializable + NetlinkDeserializable,
66    M: Debug + Clone,
67{
68    pub fn new() -> Self {
69        Self {
70            sequence_id: 0,
71            pending_requests: HashMap::new(),
72            incoming_responses: VecDeque::new(),
73            incoming_requests: VecDeque::new(),
74            outgoing_messages: VecDeque::new(),
75        }
76    }
77
78    pub fn handle_message(
79        &mut self,
80        message: NetlinkMessage<T>,
81        source: SocketAddr,
82    ) {
83        let request_id = RequestId::new(
84            message.header.sequence_number,
85            source.port_number(),
86        );
87        trace!("handling messages (request id = {:?})", request_id);
88        if let hash_map::Entry::Occupied(entry) =
89            self.pending_requests.entry(request_id)
90        {
91            Self::handle_response(&mut self.incoming_responses, entry, message);
92        } else {
93            self.incoming_requests.push_back((message, source));
94        }
95    }
96
97    fn handle_response(
98        incoming_responses: &mut VecDeque<Response<T, M>>,
99        entry: hash_map::OccupiedEntry<RequestId, PendingRequest<M>>,
100        message: NetlinkMessage<T>,
101    ) {
102        let entry_key;
103        let mut request_id = entry.key();
104        trace!("handling response to request {:?}", request_id);
105
106        // A request is processed if we receive an Ack, Error,
107        // Done, Overrun, or InnerMessage without the
108        // multipart flag and we were not expecting an Ack
109        let done = match message.payload {
110            NetlinkPayload::InnerMessage(_)
111                if message.header.flags & NLM_F_MULTIPART
112                    == NLM_F_MULTIPART =>
113            {
114                false
115            }
116            NetlinkPayload::InnerMessage(_) => !entry.get().expecting_ack,
117            _ => true,
118        };
119
120        let metadata = if done {
121            trace!("request {:?} fully processed", request_id);
122            let (k, v) = entry.remove_entry();
123            entry_key = k;
124            request_id = &entry_key;
125            v.metadata
126        } else {
127            trace!("more responses to request {:?} may come", request_id);
128            entry.get().metadata.clone()
129        };
130
131        let response = Response::<T, M> {
132            done,
133            message,
134            metadata,
135        };
136        incoming_responses.push_back(response);
137        trace!("done handling response to request {:?}", request_id);
138    }
139
140    pub fn request(&mut self, request: Request<T, M>) {
141        let Request {
142            mut message,
143            metadata,
144            destination,
145        } = request;
146
147        self.set_sequence_id(&mut message);
148        let request_id =
149            RequestId::new(self.sequence_id, destination.port_number());
150        let flags = message.header.flags;
151        self.outgoing_messages.push_back((message, destination));
152
153        // If we expect a response, we store the request id so that we
154        // can map the response to this specific request.
155        //
156        // Note that we expect responses in three cases only:
157        //  - when the request has the NLM_F_REQUEST flag
158        //  - when the request has the NLM_F_ACK flag
159        //  - when the request has the NLM_F_ECHO flag
160        let expecting_ack = flags & NLM_F_ACK == NLM_F_ACK;
161        if flags & NLM_F_REQUEST == NLM_F_REQUEST
162            || flags & NLM_F_ECHO == NLM_F_ECHO
163            || expecting_ack
164        {
165            self.pending_requests.insert(
166                request_id,
167                PendingRequest {
168                    expecting_ack,
169                    metadata,
170                },
171            );
172        }
173    }
174
175    fn set_sequence_id(&mut self, message: &mut NetlinkMessage<T>) {
176        self.sequence_id += 1;
177        message.header.sequence_number = self.sequence_id;
178    }
179}