if_watch/
linux.rs

1use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
2use fnv::FnvHashSet;
3use futures::ready;
4use futures::stream::{FusedStream, Stream, TryStreamExt};
5use futures::StreamExt;
6use netlink_packet_core::NetlinkPayload;
7use netlink_packet_route::rtnl::address::nlas::Nla;
8use netlink_packet_route::rtnl::{AddressMessage, RtnlMessage};
9use netlink_proto::Connection;
10use netlink_sys::{AsyncSocket, SocketAddr};
11use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
12use std::collections::VecDeque;
13use std::future::Future;
14use std::io::{Error, ErrorKind, Result};
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
#[cfg(feature = "tokio")]
pub mod tokio {
    //! Interface watching driven by `netlink`'s [`TokioSocket`](netlink_sys::TokioSocket).

    /// Watches for interface changes.
    pub type IfWatcher = super::IfWatcher<netlink_sys::TokioSocket>;
}
27
#[cfg(feature = "smol")]
pub mod smol {
    //! Interface watching driven by `netlink`'s [`SmolSocket`](netlink_sys::SmolSocket).

    /// Watches for interface changes.
    pub type IfWatcher = super::IfWatcher<netlink_sys::SmolSocket>;
}
36
37pub struct IfWatcher<T> {
38    conn: Connection<RtnlMessage, T>,
39    messages: Pin<Box<dyn Stream<Item = Result<RtnlMessage>> + Send>>,
40    addrs: FnvHashSet<IpNet>,
41    queue: VecDeque<IfEvent>,
42}
43
44impl<T> std::fmt::Debug for IfWatcher<T> {
45    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46        f.debug_struct("IfWatcher")
47            .field("addrs", &self.addrs)
48            .finish_non_exhaustive()
49    }
50}
51
52impl<T> IfWatcher<T>
53where
54    T: AsyncSocket + Unpin,
55{
56    /// Create a watcher.
57    pub fn new() -> Result<Self> {
58        let (mut conn, handle, messages) = rtnetlink::new_connection_with_socket::<T>()?;
59        let groups = RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
60        let addr = SocketAddr::new(0, groups);
61        conn.socket_mut().socket_mut().bind(&addr)?;
62        let get_addrs_stream = handle
63            .address()
64            .get()
65            .execute()
66            .map_ok(RtnlMessage::NewAddress)
67            .map_err(|err| Error::new(ErrorKind::Other, err));
68        let msg_stream = messages.filter_map(|(msg, _)| async {
69            match msg.payload {
70                NetlinkPayload::Error(err) => Some(Err(err.to_io())),
71                NetlinkPayload::InnerMessage(msg) => Some(Ok(msg)),
72                _ => None,
73            }
74        });
75        let messages = get_addrs_stream.chain(msg_stream).boxed();
76        let addrs = FnvHashSet::default();
77        let queue = VecDeque::default();
78        Ok(Self {
79            conn,
80            messages,
81            addrs,
82            queue,
83        })
84    }
85
86    /// Iterate over current networks.
87    pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
88        self.addrs.iter()
89    }
90
91    fn add_address(&mut self, msg: AddressMessage) {
92        for net in iter_nets(msg) {
93            if self.addrs.insert(net) {
94                self.queue.push_back(IfEvent::Up(net));
95            }
96        }
97    }
98
99    fn rem_address(&mut self, msg: AddressMessage) {
100        for net in iter_nets(msg) {
101            if self.addrs.remove(&net) {
102                self.queue.push_back(IfEvent::Down(net));
103            }
104        }
105    }
106
107    /// Poll for an address change event.
108    pub fn poll_if_event(&mut self, cx: &mut Context) -> Poll<Result<IfEvent>> {
109        loop {
110            if let Some(event) = self.queue.pop_front() {
111                return Poll::Ready(Ok(event));
112            }
113            if Pin::new(&mut self.conn).poll(cx).is_ready() {
114                return Poll::Ready(Err(socket_err(
115                    "rtnetlink socket closed. Connection has been terminated.",
116                )));
117            }
118            let message = match ready!(self.messages.poll_next_unpin(cx)) {
119                Some(Ok(message)) => message,
120                Some(Err(error)) => {
121                    return Poll::Ready(Err(socket_err(&format!(
122                        "rtnetlink socket closed. {error}"
123                    ))));
124                }
125                None => {
126                    return Poll::Ready(Err(socket_err(
127                        "rtnetlink socket closed. Empty message has been returned.",
128                    )));
129                }
130            };
131            match message {
132                RtnlMessage::NewAddress(msg) => self.add_address(msg),
133                RtnlMessage::DelAddress(msg) => self.rem_address(msg),
134                _ => {}
135            }
136        }
137    }
138}
139
/// Builds the `BrokenPipe` I/O error used to report that the rtnetlink socket
/// can no longer deliver events; `error` becomes the error's message.
fn socket_err(error: &str) -> std::io::Error {
    let kind = ErrorKind::BrokenPipe;
    Error::new(kind, error.to_owned())
}
143
144fn iter_nets(msg: AddressMessage) -> impl Iterator<Item = IpNet> {
145    let prefix = msg.header.prefix_len;
146    let family = msg.header.family;
147    msg.nlas.into_iter().filter_map(move |nla| {
148        if let Nla::Address(octets) = nla {
149            match family {
150                2 => {
151                    let mut addr = [0; 4];
152                    addr.copy_from_slice(&octets);
153                    Some(IpNet::V4(
154                        Ipv4Net::new(Ipv4Addr::from(addr), prefix).unwrap(),
155                    ))
156                }
157                10 => {
158                    let mut addr = [0; 16];
159                    addr.copy_from_slice(&octets);
160                    Some(IpNet::V6(
161                        Ipv6Net::new(Ipv6Addr::from(addr), prefix).unwrap(),
162                    ))
163                }
164                _ => None,
165            }
166        } else {
167            None
168        }
169    })
170}
171
172impl<T> Stream for IfWatcher<T>
173where
174    T: AsyncSocket + Unpin,
175{
176    type Item = Result<IfEvent>;
177    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178        Pin::into_inner(self).poll_if_event(cx).map(Some)
179    }
180}
181
182impl<T> FusedStream for IfWatcher<T>
183where
184    T: AsyncSocket + AsyncSocket + Unpin,
185{
186    fn is_terminated(&self) -> bool {
187        false
188    }
189}