1use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
2use fnv::FnvHashSet;
3use futures::ready;
4use futures::stream::{FusedStream, Stream, TryStreamExt};
5use futures::StreamExt;
6use netlink_packet_core::NetlinkPayload;
7use netlink_packet_route::rtnl::address::nlas::Nla;
8use netlink_packet_route::rtnl::{AddressMessage, RtnlMessage};
9use netlink_proto::Connection;
10use netlink_sys::{AsyncSocket, SocketAddr};
11use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
12use std::collections::VecDeque;
13use std::future::Future;
14use std::io::{Error, ErrorKind, Result};
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
#[cfg(feature = "tokio")]
pub mod tokio {
    //! Interface watcher specialised to the Tokio-backed netlink socket.

    /// [`super::IfWatcher`] driven by a [`netlink_sys::TokioSocket`].
    pub type IfWatcher = super::IfWatcher<netlink_sys::TokioSocket>;
}
27
#[cfg(feature = "smol")]
pub mod smol {
    //! Interface watcher specialised to the smol-backed netlink socket.

    /// [`super::IfWatcher`] driven by a [`netlink_sys::SmolSocket`].
    pub type IfWatcher = super::IfWatcher<netlink_sys::SmolSocket>;
}
36
/// Watches the host's IP interface addresses through an rtnetlink socket.
///
/// `T` is the async socket implementation; the `tokio` and `smol` modules
/// provide ready-made aliases when their features are enabled.
pub struct IfWatcher<T> {
    // Netlink connection; it is polled in `poll_if_event`, and its completion
    // is reported to the caller as the socket having been closed.
    conn: Connection<RtnlMessage, T>,
    // Initial address dump (as `NewAddress` messages) chained with the live
    // multicast notifications; built in `new`.
    messages: Pin<Box<dyn Stream<Item = Result<RtnlMessage>> + Send>>,
    // Addresses currently known to be assigned; used so that only real
    // additions/removals produce `Up`/`Down` events.
    addrs: FnvHashSet<IpNet>,
    // Events already decoded but not yet handed to the caller.
    queue: VecDeque<IfEvent>,
}
43
44impl<T> std::fmt::Debug for IfWatcher<T> {
45 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46 f.debug_struct("IfWatcher")
47 .field("addrs", &self.addrs)
48 .finish_non_exhaustive()
49 }
50}
51
52impl<T> IfWatcher<T>
53where
54 T: AsyncSocket + Unpin,
55{
56 pub fn new() -> Result<Self> {
58 let (mut conn, handle, messages) = rtnetlink::new_connection_with_socket::<T>()?;
59 let groups = RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
60 let addr = SocketAddr::new(0, groups);
61 conn.socket_mut().socket_mut().bind(&addr)?;
62 let get_addrs_stream = handle
63 .address()
64 .get()
65 .execute()
66 .map_ok(RtnlMessage::NewAddress)
67 .map_err(|err| Error::new(ErrorKind::Other, err));
68 let msg_stream = messages.filter_map(|(msg, _)| async {
69 match msg.payload {
70 NetlinkPayload::Error(err) => Some(Err(err.to_io())),
71 NetlinkPayload::InnerMessage(msg) => Some(Ok(msg)),
72 _ => None,
73 }
74 });
75 let messages = get_addrs_stream.chain(msg_stream).boxed();
76 let addrs = FnvHashSet::default();
77 let queue = VecDeque::default();
78 Ok(Self {
79 conn,
80 messages,
81 addrs,
82 queue,
83 })
84 }
85
86 pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
88 self.addrs.iter()
89 }
90
91 fn add_address(&mut self, msg: AddressMessage) {
92 for net in iter_nets(msg) {
93 if self.addrs.insert(net) {
94 self.queue.push_back(IfEvent::Up(net));
95 }
96 }
97 }
98
99 fn rem_address(&mut self, msg: AddressMessage) {
100 for net in iter_nets(msg) {
101 if self.addrs.remove(&net) {
102 self.queue.push_back(IfEvent::Down(net));
103 }
104 }
105 }
106
107 pub fn poll_if_event(&mut self, cx: &mut Context) -> Poll<Result<IfEvent>> {
109 loop {
110 if let Some(event) = self.queue.pop_front() {
111 return Poll::Ready(Ok(event));
112 }
113 if Pin::new(&mut self.conn).poll(cx).is_ready() {
114 return Poll::Ready(Err(socket_err(
115 "rtnetlink socket closed. Connection has been terminated.",
116 )));
117 }
118 let message = match ready!(self.messages.poll_next_unpin(cx)) {
119 Some(Ok(message)) => message,
120 Some(Err(error)) => {
121 return Poll::Ready(Err(socket_err(&format!(
122 "rtnetlink socket closed. {error}"
123 ))));
124 }
125 None => {
126 return Poll::Ready(Err(socket_err(
127 "rtnetlink socket closed. Empty message has been returned.",
128 )));
129 }
130 };
131 match message {
132 RtnlMessage::NewAddress(msg) => self.add_address(msg),
133 RtnlMessage::DelAddress(msg) => self.rem_address(msg),
134 _ => {}
135 }
136 }
137 }
138}
139
/// Wraps `error` in an I/O error of kind [`ErrorKind::BrokenPipe`], the kind
/// used for every failure reported by `poll_if_event`.
fn socket_err(error: &str) -> Error {
    Error::new(ErrorKind::BrokenPipe, error)
}
143
144fn iter_nets(msg: AddressMessage) -> impl Iterator<Item = IpNet> {
145 let prefix = msg.header.prefix_len;
146 let family = msg.header.family;
147 msg.nlas.into_iter().filter_map(move |nla| {
148 if let Nla::Address(octets) = nla {
149 match family {
150 2 => {
151 let mut addr = [0; 4];
152 addr.copy_from_slice(&octets);
153 Some(IpNet::V4(
154 Ipv4Net::new(Ipv4Addr::from(addr), prefix).unwrap(),
155 ))
156 }
157 10 => {
158 let mut addr = [0; 16];
159 addr.copy_from_slice(&octets);
160 Some(IpNet::V6(
161 Ipv6Net::new(Ipv6Addr::from(addr), prefix).unwrap(),
162 ))
163 }
164 _ => None,
165 }
166 } else {
167 None
168 }
169 })
170}
171
172impl<T> Stream for IfWatcher<T>
173where
174 T: AsyncSocket + Unpin,
175{
176 type Item = Result<IfEvent>;
177 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178 Pin::into_inner(self).poll_if_event(cx).map(Some)
179 }
180}
181
182impl<T> FusedStream for IfWatcher<T>
183where
184 T: AsyncSocket + AsyncSocket + Unpin,
185{
186 fn is_terminated(&self) -> bool {
187 false
188 }
189}