netlink_proto/protocol/
protocol.rs1use std::{
4 collections::{hash_map, HashMap, VecDeque},
5 fmt::Debug,
6};
7
8use netlink_packet_core::{
9 constants::*, NetlinkDeserializable, NetlinkMessage, NetlinkPayload,
10 NetlinkSerializable,
11};
12
13use super::Request;
14use crate::sys::SocketAddr;
15
/// Key identifying an in-flight request: the sequence number stamped on the
/// outgoing message together with the port number of the peer it was sent to.
///
/// Incoming messages are matched against pending requests by rebuilding this
/// key from the message's sequence number and its source port.
///
/// Both fields are plain `u32`s, so the key is cheap to copy.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
struct RequestId {
    sequence_number: u32,
    port: u32,
}

impl RequestId {
    /// Builds a key from a sequence number and a peer port number.
    fn new(sequence_number: u32, port: u32) -> Self {
        Self {
            sequence_number,
            port,
        }
    }
}
30
31#[derive(Debug, Eq, PartialEq)]
32pub(crate) struct Response<T, M> {
33 pub done: bool,
34 pub message: NetlinkMessage<T>,
35 pub metadata: M,
36}
37
/// Book-keeping for a request that is still waiting for (more) responses.
#[derive(Debug)]
struct PendingRequest<M> {
    /// Whether the request header carried `NLM_F_ACK`. When set, a single
    /// non-multipart inner message does not complete the request.
    expecting_ack: bool,
    /// Caller-supplied value handed back with every response to the request.
    metadata: M,
}
43
44#[derive(Debug, Default)]
45pub(crate) struct Protocol<T, M> {
46 sequence_id: u32,
48
49 pending_requests: HashMap<RequestId, PendingRequest<M>>,
52
53 pub incoming_responses: VecDeque<Response<T, M>>,
55
56 pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
58
59 pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
61}
62
63impl<T, M> Protocol<T, M>
64where
65 T: Debug + NetlinkSerializable + NetlinkDeserializable,
66 M: Debug + Clone,
67{
68 pub fn new() -> Self {
69 Self {
70 sequence_id: 0,
71 pending_requests: HashMap::new(),
72 incoming_responses: VecDeque::new(),
73 incoming_requests: VecDeque::new(),
74 outgoing_messages: VecDeque::new(),
75 }
76 }
77
78 pub fn handle_message(
79 &mut self,
80 message: NetlinkMessage<T>,
81 source: SocketAddr,
82 ) {
83 let request_id = RequestId::new(
84 message.header.sequence_number,
85 source.port_number(),
86 );
87 trace!("handling messages (request id = {:?})", request_id);
88 if let hash_map::Entry::Occupied(entry) =
89 self.pending_requests.entry(request_id)
90 {
91 Self::handle_response(&mut self.incoming_responses, entry, message);
92 } else {
93 self.incoming_requests.push_back((message, source));
94 }
95 }
96
97 fn handle_response(
98 incoming_responses: &mut VecDeque<Response<T, M>>,
99 entry: hash_map::OccupiedEntry<RequestId, PendingRequest<M>>,
100 message: NetlinkMessage<T>,
101 ) {
102 let entry_key;
103 let mut request_id = entry.key();
104 trace!("handling response to request {:?}", request_id);
105
106 let done = match message.payload {
110 NetlinkPayload::InnerMessage(_)
111 if message.header.flags & NLM_F_MULTIPART
112 == NLM_F_MULTIPART =>
113 {
114 false
115 }
116 NetlinkPayload::InnerMessage(_) => !entry.get().expecting_ack,
117 _ => true,
118 };
119
120 let metadata = if done {
121 trace!("request {:?} fully processed", request_id);
122 let (k, v) = entry.remove_entry();
123 entry_key = k;
124 request_id = &entry_key;
125 v.metadata
126 } else {
127 trace!("more responses to request {:?} may come", request_id);
128 entry.get().metadata.clone()
129 };
130
131 let response = Response::<T, M> {
132 done,
133 message,
134 metadata,
135 };
136 incoming_responses.push_back(response);
137 trace!("done handling response to request {:?}", request_id);
138 }
139
140 pub fn request(&mut self, request: Request<T, M>) {
141 let Request {
142 mut message,
143 metadata,
144 destination,
145 } = request;
146
147 self.set_sequence_id(&mut message);
148 let request_id =
149 RequestId::new(self.sequence_id, destination.port_number());
150 let flags = message.header.flags;
151 self.outgoing_messages.push_back((message, destination));
152
153 let expecting_ack = flags & NLM_F_ACK == NLM_F_ACK;
161 if flags & NLM_F_REQUEST == NLM_F_REQUEST
162 || flags & NLM_F_ECHO == NLM_F_ECHO
163 || expecting_ack
164 {
165 self.pending_requests.insert(
166 request_id,
167 PendingRequest {
168 expecting_ack,
169 metadata,
170 },
171 );
172 }
173 }
174
175 fn set_sequence_id(&mut self, message: &mut NetlinkMessage<T>) {
176 self.sequence_id += 1;
177 message.header.sequence_number = self.sequence_id;
178 }
179}