libp2p_mdns/
behaviour.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod iface;
22mod socket;
23mod timer;
24
25use self::iface::InterfaceState;
26use crate::behaviour::{socket::AsyncSocket, timer::Builder};
27use crate::Config;
28use futures::channel::mpsc;
29use futures::{Stream, StreamExt};
30use if_watch::IfEvent;
31use libp2p_core::{Endpoint, Multiaddr};
32use libp2p_identity::PeerId;
33use libp2p_swarm::behaviour::FromSwarm;
34use libp2p_swarm::{
35    dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour, THandler,
36    THandlerInEvent, THandlerOutEvent, ToSwarm,
37};
38use smallvec::SmallVec;
39use std::collections::hash_map::{Entry, HashMap};
40use std::future::Future;
41use std::sync::{Arc, RwLock};
42use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant};
43
44/// An abstraction to allow for compatibility with various async runtimes.
45pub trait Provider: 'static {
46    /// The Async Socket type.
47    type Socket: AsyncSocket;
48    /// The Async Timer type.
49    type Timer: Builder + Stream;
50    /// The IfWatcher type.
51    type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
52
53    type TaskHandle: Abort;
54
55    /// Create a new instance of the `IfWatcher` type.
56    fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
57
58    fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
59}
60
/// A handle to a spawned task that can be told to stop.
#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
    /// Requests cancellation of the task behind this handle, consuming the handle.
    fn abort(self);
}
65
/// `async-io`-based [`Provider`] together with the matching [`Behaviour`] alias.
#[cfg(feature = "async-io")]
pub mod async_io {
    use super::Provider;
    use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};
    use async_std::task::JoinHandle;
    use if_watch::smol::IfWatcher;
    use std::future::Future;

    /// Uninhabited marker type selecting the `async-io` runtime.
    #[doc(hidden)]
    pub enum AsyncIo {}

    impl Provider for AsyncIo {
        type Socket = AsyncUdpSocket;
        type Timer = AsyncTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            async_std::task::spawn(task)
        }
    }

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // `cancel` yields a future; it is driven to completion on its own detached task so
            // that aborting does not require an async context.
            let cancellation = self.cancel();
            async_std::task::spawn(cancellation);
        }
    }

    /// [`Behaviour`](super::Behaviour) running on the `async-io` runtime.
    pub type Behaviour = super::Behaviour<AsyncIo>;
}
101
/// `tokio`-based [`Provider`] together with the matching [`Behaviour`] alias.
#[cfg(feature = "tokio")]
pub mod tokio {
    use super::Provider;
    use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};
    use if_watch::tokio::IfWatcher;
    use std::future::Future;
    use tokio::task::JoinHandle;

    /// Uninhabited marker type selecting the `tokio` runtime.
    #[doc(hidden)]
    pub enum Tokio {}

    impl Provider for Tokio {
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            tokio::spawn(task)
        }
    }

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // Fully qualified so the inherent `JoinHandle::abort(&self)` is called; a plain
            // `self.abort()` would resolve to this trait method and recurse forever.
            JoinHandle::abort(&self);
        }
    }

    /// [`Behaviour`](super::Behaviour) running on the `tokio` runtime.
    pub type Behaviour = super::Behaviour<Tokio>;
}
137
138/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
139/// them to the topology.
140#[derive(Debug)]
141pub struct Behaviour<P>
142where
143    P: Provider,
144{
145    /// InterfaceState config.
146    config: Config,
147
148    /// Iface watcher.
149    if_watch: P::Watcher,
150
151    /// Handles to tasks running the mDNS queries.
152    if_tasks: HashMap<IpAddr, P::TaskHandle>,
153
154    query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
155    query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
156
157    /// List of nodes that we have discovered, the address, and when their TTL expires.
158    ///
159    /// Each combination of `PeerId` and `Multiaddr` can only appear once, but the same `PeerId`
160    /// can appear multiple times.
161    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
162
163    /// Future that fires when the TTL of at least one node in `discovered_nodes` expires.
164    ///
165    /// `None` if `discovered_nodes` is empty.
166    closest_expiration: Option<P::Timer>,
167
168    /// The current set of listen addresses.
169    ///
170    /// This is shared across all interface tasks using an [`RwLock`].
171    /// The [`Behaviour`] updates this upon new [`FromSwarm`] events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
172    listen_addresses: Arc<RwLock<ListenAddresses>>,
173
174    local_peer_id: PeerId,
175}
176
177impl<P> Behaviour<P>
178where
179    P: Provider,
180{
181    /// Builds a new `Mdns` behaviour.
182    pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
183        let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily.
184
185        Ok(Self {
186            config,
187            if_watch: P::new_watcher()?,
188            if_tasks: Default::default(),
189            query_response_receiver: rx,
190            query_response_sender: tx,
191            discovered_nodes: Default::default(),
192            closest_expiration: Default::default(),
193            listen_addresses: Default::default(),
194            local_peer_id,
195        })
196    }
197
198    /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
199    #[deprecated(note = "Use `discovered_nodes` iterator instead.")]
200    pub fn has_node(&self, peer_id: &PeerId) -> bool {
201        self.discovered_nodes().any(|p| p == peer_id)
202    }
203
204    /// Returns the list of nodes that we have discovered through mDNS and that are not expired.
205    pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
206        self.discovered_nodes.iter().map(|(p, _, _)| p)
207    }
208
209    /// Expires a node before the ttl.
210    #[deprecated(note = "Unused API. Will be removed in the next release.")]
211    pub fn expire_node(&mut self, peer_id: &PeerId) {
212        let now = Instant::now();
213        for (peer, _addr, expires) in &mut self.discovered_nodes {
214            if peer == peer_id {
215                *expires = now;
216            }
217        }
218        self.closest_expiration = Some(P::Timer::at(now));
219    }
220}
221
222impl<P> NetworkBehaviour for Behaviour<P>
223where
224    P: Provider,
225{
226    type ConnectionHandler = dummy::ConnectionHandler;
227    type ToSwarm = Event;
228
229    fn handle_established_inbound_connection(
230        &mut self,
231        _: ConnectionId,
232        _: PeerId,
233        _: &Multiaddr,
234        _: &Multiaddr,
235    ) -> Result<THandler<Self>, ConnectionDenied> {
236        Ok(dummy::ConnectionHandler)
237    }
238
239    fn handle_pending_outbound_connection(
240        &mut self,
241        _connection_id: ConnectionId,
242        maybe_peer: Option<PeerId>,
243        _addresses: &[Multiaddr],
244        _effective_role: Endpoint,
245    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
246        let peer_id = match maybe_peer {
247            None => return Ok(vec![]),
248            Some(peer) => peer,
249        };
250
251        Ok(self
252            .discovered_nodes
253            .iter()
254            .filter(|(peer, _, _)| peer == &peer_id)
255            .map(|(_, addr, _)| addr.clone())
256            .collect())
257    }
258
259    fn handle_established_outbound_connection(
260        &mut self,
261        _: ConnectionId,
262        _: PeerId,
263        _: &Multiaddr,
264        _: Endpoint,
265    ) -> Result<THandler<Self>, ConnectionDenied> {
266        Ok(dummy::ConnectionHandler)
267    }
268
269    fn on_connection_handler_event(
270        &mut self,
271        _: PeerId,
272        _: ConnectionId,
273        ev: THandlerOutEvent<Self>,
274    ) {
275        void::unreachable(ev)
276    }
277
278    fn on_swarm_event(&mut self, event: FromSwarm) {
279        self.listen_addresses
280            .write()
281            .unwrap_or_else(|e| e.into_inner())
282            .on_swarm_event(&event);
283    }
284
285    #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
286    fn poll(
287        &mut self,
288        cx: &mut Context<'_>,
289    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
290        // Poll ifwatch.
291        while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
292            match event {
293                Ok(IfEvent::Up(inet)) => {
294                    let addr = inet.addr();
295                    if addr.is_loopback() {
296                        continue;
297                    }
298                    if addr.is_ipv4() && self.config.enable_ipv6
299                        || addr.is_ipv6() && !self.config.enable_ipv6
300                    {
301                        continue;
302                    }
303                    if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
304                        match InterfaceState::<P::Socket, P::Timer>::new(
305                            addr,
306                            self.config.clone(),
307                            self.local_peer_id,
308                            self.listen_addresses.clone(),
309                            self.query_response_sender.clone(),
310                        ) {
311                            Ok(iface_state) => {
312                                e.insert(P::spawn(iface_state));
313                            }
314                            Err(err) => {
315                                tracing::error!("failed to create `InterfaceState`: {}", err)
316                            }
317                        }
318                    }
319                }
320                Ok(IfEvent::Down(inet)) => {
321                    if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
322                        tracing::info!(instance=%inet.addr(), "dropping instance");
323
324                        handle.abort();
325                    }
326                }
327                Err(err) => tracing::error!("if watch returned an error: {}", err),
328            }
329        }
330        // Emit discovered event.
331        let mut discovered = Vec::new();
332
333        while let Poll::Ready(Some((peer, addr, expiration))) =
334            self.query_response_receiver.poll_next_unpin(cx)
335        {
336            if let Some((_, _, cur_expires)) = self
337                .discovered_nodes
338                .iter_mut()
339                .find(|(p, a, _)| *p == peer && *a == addr)
340            {
341                *cur_expires = cmp::max(*cur_expires, expiration);
342            } else {
343                tracing::info!(%peer, address=%addr, "discovered peer on address");
344                self.discovered_nodes.push((peer, addr.clone(), expiration));
345                discovered.push((peer, addr));
346            }
347        }
348
349        if !discovered.is_empty() {
350            let event = Event::Discovered(discovered);
351            return Poll::Ready(ToSwarm::GenerateEvent(event));
352        }
353        // Emit expired event.
354        let now = Instant::now();
355        let mut closest_expiration = None;
356        let mut expired = Vec::new();
357        self.discovered_nodes.retain(|(peer, addr, expiration)| {
358            if *expiration <= now {
359                tracing::info!(%peer, address=%addr, "expired peer on address");
360                expired.push((*peer, addr.clone()));
361                return false;
362            }
363            closest_expiration = Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
364            true
365        });
366        if !expired.is_empty() {
367            let event = Event::Expired(expired);
368            return Poll::Ready(ToSwarm::GenerateEvent(event));
369        }
370        if let Some(closest_expiration) = closest_expiration {
371            let mut timer = P::Timer::at(closest_expiration);
372            let _ = Pin::new(&mut timer).poll_next(cx);
373
374            self.closest_expiration = Some(timer);
375        }
376        Poll::Pending
377    }
378}
379
380/// Event that can be produced by the `Mdns` behaviour.
381#[derive(Debug, Clone)]
382pub enum Event {
383    /// Discovered nodes through mDNS.
384    Discovered(Vec<(PeerId, Multiaddr)>),
385
386    /// The given combinations of `PeerId` and `Multiaddr` have expired.
387    ///
388    /// Each discovered record has a time-to-live. When this TTL expires and the address hasn't
389    /// been refreshed, we remove it from the list and emit it as an `Expired` event.
390    Expired(Vec<(PeerId, Multiaddr)>),
391}