libp2p_swarm/
connection.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26pub use error::ConnectionError;
27pub(crate) use error::{
28    PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
29};
30pub use supported_protocols::SupportedProtocols;
31
32use crate::handler::{
33    AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound,
34    FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange,
35    UpgradeInfoSend,
36};
37use crate::stream::ActiveStreamCounter;
38use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
39use crate::{
40    ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
41};
42use futures::future::BoxFuture;
43use futures::stream::FuturesUnordered;
44use futures::StreamExt;
45use futures::{stream, FutureExt};
46use futures_timer::Delay;
47use instant::Instant;
48use libp2p_core::connection::ConnectedPoint;
49use libp2p_core::multiaddr::Multiaddr;
50use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
51use libp2p_core::upgrade;
52use libp2p_core::upgrade::{NegotiationError, ProtocolError};
53use libp2p_core::Endpoint;
54use libp2p_identity::PeerId;
55use std::collections::HashSet;
56use std::fmt::{Display, Formatter};
57use std::future::Future;
58use std::sync::atomic::{AtomicUsize, Ordering};
59use std::task::Waker;
60use std::time::Duration;
61use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
62
/// Source of fresh [`ConnectionId`]s handed out by [`ConnectionId::next`].
///
/// Starts at 1 and is bumped atomically on every call.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Connection identifier.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps a raw `usize` as a [`ConnectionId`] without any uniqueness check.
    ///
    /// [`Swarm`](crate::Swarm) enforces that [`ConnectionId`]s are unique and not reused.
    /// This constructor does not, hence the _unchecked_.
    ///
    /// It is primarily meant for allowing manual tests of
    /// [`NetworkBehaviour`](crate::NetworkBehaviour)s.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Hands out the next value of the process-wide counter as a [`ConnectionId`].
    pub(crate) fn next() -> Self {
        let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        Self::new_unchecked(raw)
    }
}

impl Display for ConnectionId {
    /// Renders the identifier as its bare decimal number.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = *self;
        write!(f, "{}", raw)
    }
}
91
92/// Information about a successfully established connection.
93#[derive(Debug, Clone, PartialEq, Eq)]
94pub(crate) struct Connected {
95    /// The connected endpoint, including network address information.
96    pub(crate) endpoint: ConnectedPoint,
97    /// Information obtained from the transport.
98    pub(crate) peer_id: PeerId,
99}
100
101/// Event generated by a [`Connection`].
102#[derive(Debug, Clone)]
103pub(crate) enum Event<T> {
104    /// Event generated by the [`ConnectionHandler`].
105    Handler(T),
106    /// Address of the remote has changed.
107    AddressChange(Multiaddr),
108}
109
110/// A multiplexed connection to a peer with an associated [`ConnectionHandler`].
111pub(crate) struct Connection<THandler>
112where
113    THandler: ConnectionHandler,
114{
115    /// Node that handles the muxing.
116    muxing: StreamMuxerBox,
117    /// The underlying handler.
118    handler: THandler,
119    /// Futures that upgrade incoming substreams.
120    negotiating_in: FuturesUnordered<
121        StreamUpgrade<
122            THandler::InboundOpenInfo,
123            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
124            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
125        >,
126    >,
127    /// Futures that upgrade outgoing substreams.
128    negotiating_out: FuturesUnordered<
129        StreamUpgrade<
130            THandler::OutboundOpenInfo,
131            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
132            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
133        >,
134    >,
135    /// The currently planned connection & handler shutdown.
136    shutdown: Shutdown,
137    /// The substream upgrade protocol override, if any.
138    substream_upgrade_protocol_override: Option<upgrade::Version>,
139    /// The maximum number of inbound streams concurrently negotiating on a
140    /// connection. New inbound streams exceeding the limit are dropped and thus
141    /// reset.
142    ///
143    /// Note: This only enforces a limit on the number of concurrently
144    /// negotiating inbound streams. The total number of inbound streams on a
145    /// connection is the sum of negotiating and negotiated streams. A limit on
146    /// the total number of streams can be enforced at the [`StreamMuxerBox`] level.
147    max_negotiating_inbound_streams: usize,
148    /// Contains all upgrades that are waiting for a new outbound substream.
149    ///
150    /// The upgrade timeout is already ticking here so this may fail in case the remote is not quick
151    /// enough in providing us with a new stream.
152    requested_substreams: FuturesUnordered<
153        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
154    >,
155
156    local_supported_protocols: HashSet<StreamProtocol>,
157    remote_supported_protocols: HashSet<StreamProtocol>,
158    idle_timeout: Duration,
159    stream_counter: ActiveStreamCounter,
160}
161
162impl<THandler> fmt::Debug for Connection<THandler>
163where
164    THandler: ConnectionHandler + fmt::Debug,
165    THandler::OutboundOpenInfo: fmt::Debug,
166{
167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168        f.debug_struct("Connection")
169            .field("handler", &self.handler)
170            .finish()
171    }
172}
173
174impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
175
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from the given substream multiplexer
    /// and connection handler.
    ///
    /// The handler's initial listen protocols are gathered up front and, if there are any,
    /// reported to the handler as a `LocalProtocolsChange` before the connection is returned.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)),
            ));
        }
        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            // Remembered so that later polls only report the *difference* to the handler.
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Notifies the connection handler of an event coming from the behaviour.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Begins an orderly shutdown of the connection.
    ///
    /// Consumes the connection and returns:
    /// - a stream of final events, produced by repeatedly calling the handler's `poll_close`;
    /// - a `Future` that closes the muxer and resolves when connection shutdown is complete.
    ///
    /// All other state (negotiating and requested substreams) is dropped together with `self`.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Polls the handler and the substream, forwarding events from the former to the latter and
    /// vice versa.
    ///
    /// Returns:
    /// - `Ready(Ok(Event::Handler(..)))` when the handler emits `NotifyBehaviour`;
    /// - `Ready(Ok(Event::AddressChange(..)))` when the muxer reports a new remote address;
    /// - `Ready(Err(ConnectionError::KeepAliveTimeout))` when the idle shutdown triggers;
    /// - `Ready(Err(..))` when the muxer returns an error (propagated with `?`);
    /// - `Pending` only after a full pass in which no source made progress.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        // `Connection` is `Unpin`, so we can project to `&mut Self` and borrow every field
        // independently for the rest of the function.
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        // Sources of progress are polled in a fixed order. Whenever one of them makes progress we
        // `continue`, so that the handler (polled near the top) gets a chance to react to it.
        loop {
            // Outbound substream requests: `Ok(())` means the request was already fulfilled (see
            // `SubstreamRequested::extract`); `Err(info)` means its upgrade timeout fired before
            // the muxer opened a substream, which is reported to the handler as a timeout.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the [`ConnectionHandler`].
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    // The upgrade timeout starts ticking now, i.e. it also covers the time spent
                    // waiting for the muxer to open the substream.
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll handler until exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                // Remote protocol reports: the handler is only notified (and our record only
                // updated) when `ProtocolsChange` reports an actual change.
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, &protocols)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocols);
                    }

                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    if let Some(removed) =
                        ProtocolsChange::remove(remote_supported_protocols, &protocols)
                    {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                        remote_supported_protocols.retain(|p| !protocols.contains(p));
                    }

                    continue;
                }
            }

            // In case the [`ConnectionHandler`] can not make any more progress, poll the negotiating outbound streams.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                // Every outbound failure (timeout, negotiation, I/O, upgrade) goes to the handler.
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // In case both the [`ConnectionHandler`] and the negotiating outbound streams can not
            // make any more progress, poll the negotiating inbound streams.
            //
            // Only errors from the upgrade itself (`Apply`) are reported to the handler; I/O,
            // negotiation and timeout failures of inbound streams are merely logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Check if the connection (and handler) should be shut down.
            // As long as we're still negotiating substreams or have any active streams shutdown is always postponed.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                // `None` from `compute_new_shutdown` means "keep the timer that is already
                // ticking".
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                // Not idle: cancel any planned shutdown; a fresh one is planned once idle again.
                *shutdown = Shutdown::None;
            }

            // Muxer-level events. `?` on `Poll<Result<..>>` returns early with
            // `Poll::Ready(Err(..))` if the muxer fails.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Only ask the muxer for an outbound substream while at least one is requested. The
            // substream is assigned to whichever request `iter_mut` yields first.
            // NOTE(review): `FuturesUnordered` does not document its iteration order, so this is
            // not necessarily FIFO — confirm that is acceptable.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        // Leaves the request `Done` and wakes it, so the next
                        // `requested_substreams.poll_next_unpin` removes it with `Ok(())`.
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Back-pressure: stop accepting inbound substreams while the maximum number of
            // inbound streams is already negotiating.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Re-read the handler's listen protocols and report only what changed since the
            // last report.
            let new_protocols = gather_supported_protocols(handler);
            let changes = ProtocolsChange::from_full_sets(supported_protocols, &new_protocols);

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }

                *supported_protocols = new_protocols;

                continue; // Go back to the top, handler can potentially make progress again.
            }

            return Poll::Pending; // Nothing can make progress, return `Pending`.
        }
    }

    /// Test helper: polls the connection once with a no-op waker.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
456
457fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<StreamProtocol> {
458    handler
459        .listen_protocol()
460        .upgrade()
461        .protocol_info()
462        .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok())
463        .collect()
464}
465
466fn compute_new_shutdown(
467    handler_keep_alive: bool,
468    current_shutdown: &Shutdown,
469    idle_timeout: Duration,
470) -> Option<Shutdown> {
471    match (current_shutdown, handler_keep_alive) {
472        (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
473        (Shutdown::Later(_), false) => None, // Do nothing, i.e. let the shutdown timer continue to tick.
474        (_, false) => {
475            let now = Instant::now();
476            let safe_keep_alive = checked_add_fraction(now, idle_timeout);
477
478            Some(Shutdown::Later(Delay::new(safe_keep_alive)))
479        }
480        (_, true) => Some(Shutdown::None),
481    }
482}
483
484/// Repeatedly halves and adds the [`Duration`] to the [`Instant`] until [`Instant::checked_add`] succeeds.
485///
486/// [`Instant`] depends on the underlying platform and has a limit of which points in time it can represent.
487/// The [`Duration`] computed by the this function may not be the longest possible that we can add to `now` but it will work.
488fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
489    while start.checked_add(duration).is_none() {
490        tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
491
492        duration /= 2;
493    }
494
495    duration
496}
497
498/// Borrowed information about an incoming connection currently being negotiated.
499#[derive(Debug, Copy, Clone)]
500pub(crate) struct IncomingInfo<'a> {
501    /// Local connection address.
502    pub(crate) local_addr: &'a Multiaddr,
503    /// Address used to send back data to the remote.
504    pub(crate) send_back_addr: &'a Multiaddr,
505}
506
507impl<'a> IncomingInfo<'a> {
508    /// Builds the [`ConnectedPoint`] corresponding to the incoming connection.
509    pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
510        ConnectedPoint::Listener {
511            local_addr: self.local_addr.clone(),
512            send_back_addr: self.send_back_addr.clone(),
513        }
514    }
515}
516
517struct StreamUpgrade<UserData, TOk, TErr> {
518    user_data: Option<UserData>,
519    timeout: Delay,
520    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
521}
522
523impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
524    fn new_outbound<Upgrade>(
525        substream: SubstreamBox,
526        user_data: UserData,
527        timeout: Delay,
528        upgrade: Upgrade,
529        version_override: Option<upgrade::Version>,
530        counter: ActiveStreamCounter,
531    ) -> Self
532    where
533        Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
534    {
535        let effective_version = match version_override {
536            Some(version_override) if version_override != upgrade::Version::default() => {
537                tracing::debug!(
538                    "Substream upgrade protocol override: {:?} -> {:?}",
539                    upgrade::Version::default(),
540                    version_override
541                );
542
543                version_override
544            }
545            _ => upgrade::Version::default(),
546        };
547        let protocols = upgrade.protocol_info();
548
549        Self {
550            user_data: Some(user_data),
551            timeout,
552            upgrade: Box::pin(async move {
553                let (info, stream) = multistream_select::dialer_select_proto(
554                    substream,
555                    protocols,
556                    effective_version,
557                )
558                .await
559                .map_err(to_stream_upgrade_error)?;
560
561                let output = upgrade
562                    .upgrade_outbound(Stream::new(stream, counter), info)
563                    .await
564                    .map_err(StreamUpgradeError::Apply)?;
565
566                Ok(output)
567            }),
568        }
569    }
570}
571
572impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
573    fn new_inbound<Upgrade>(
574        substream: SubstreamBox,
575        protocol: SubstreamProtocol<Upgrade, UserData>,
576        counter: ActiveStreamCounter,
577    ) -> Self
578    where
579        Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
580    {
581        let timeout = *protocol.timeout();
582        let (upgrade, open_info) = protocol.into_upgrade();
583        let protocols = upgrade.protocol_info();
584
585        Self {
586            user_data: Some(open_info),
587            timeout: Delay::new(timeout),
588            upgrade: Box::pin(async move {
589                let (info, stream) =
590                    multistream_select::listener_select_proto(substream, protocols)
591                        .await
592                        .map_err(to_stream_upgrade_error)?;
593
594                let output = upgrade
595                    .upgrade_inbound(Stream::new(stream, counter), info)
596                    .await
597                    .map_err(StreamUpgradeError::Apply)?;
598
599                Ok(output)
600            }),
601        }
602    }
603}
604
605fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
606    match e {
607        NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
608        NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
609        NegotiationError::ProtocolError(other) => {
610            StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
611        }
612    }
613}
614
615impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
616
617impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
618    type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
619
620    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
621        match self.timeout.poll_unpin(cx) {
622            Poll::Ready(()) => {
623                return Poll::Ready((
624                    self.user_data
625                        .take()
626                        .expect("Future not to be polled again once ready."),
627                    Err(StreamUpgradeError::Timeout),
628                ))
629            }
630
631            Poll::Pending => {}
632        }
633
634        let result = futures::ready!(self.upgrade.poll_unpin(cx));
635        let user_data = self
636            .user_data
637            .take()
638            .expect("Future not to be polled again once ready.");
639
640        Poll::Ready((user_data, result))
641    }
642}
643
/// An outbound upgrade that is waiting for the muxer to open a substream for it.
///
/// As a `Future` it resolves to:
/// - `Err(user_data)` if `timeout` fires while still `Waiting`, handing the user data back so a
///   timeout can be reported;
/// - `Ok(())` once the data has been taken out via [`SubstreamRequested::extract`].
enum SubstreamRequested<UserData, Upgrade> {
    /// No substream has been assigned yet; the upgrade timeout is already running.
    Waiting {
        user_data: UserData,
        timeout: Delay,
        upgrade: Upgrade,
        /// A waker to notify our [`FuturesUnordered`] that we have extracted the data.
        ///
        /// This will ensure that we will get polled again in the next iteration which allows us to
        /// resolve with `Ok(())` and be removed from the [`FuturesUnordered`].
        extracted_waker: Option<Waker>,
    },
    /// The data has been extracted, or the timeout has fired and the data was returned.
    Done,
}

impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
    /// Creates a `Waiting` request. The `Delay` is created here, so the timeout is measured
    /// from the moment the substream was requested.
    fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
        Self::Waiting {
            user_data,
            timeout: Delay::new(timeout),
            upgrade,
            extracted_waker: None,
        }
    }

    /// Takes the user data, the running timeout and the upgrade out of a `Waiting` request.
    ///
    /// Leaves `Done` behind and wakes the task that last polled this request (if any), so the
    /// next poll resolves it with `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the request is already `Done`.
    fn extract(&mut self) -> (UserData, Delay, Upgrade) {
        match mem::replace(self, Self::Done) {
            SubstreamRequested::Waiting {
                user_data,
                timeout,
                upgrade,
                extracted_waker: waker,
            } => {
                if let Some(waker) = waker {
                    waker.wake();
                }

                (user_data, timeout, upgrade)
            }
            SubstreamRequested::Done => panic!("cannot extract twice"),
        }
    }
}

impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}

impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
    type Output = Result<(), UserData>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Move the state out (temporarily leaving `Done`) so the fields can be owned; it is put
        // back below if the request is still pending.
        match mem::replace(this, Self::Done) {
            SubstreamRequested::Waiting {
                user_data,
                upgrade,
                mut timeout,
                ..
            } => match timeout.poll_unpin(cx) {
                // Timed out: stay `Done` and hand the user data back to the caller.
                Poll::Ready(()) => Poll::Ready(Err(user_data)),
                Poll::Pending => {
                    // Restore the state, remembering the latest waker so `extract` can wake us.
                    *this = Self::Waiting {
                        user_data,
                        upgrade,
                        timeout,
                        extracted_waker: Some(cx.waker().clone()),
                    };
                    Poll::Pending
                }
            },
            SubstreamRequested::Done => Poll::Ready(Ok(())),
        }
    }
}
717
718/// The options for a planned connection & handler shutdown.
719///
720/// A shutdown is planned anew based on the return value of
721/// [`ConnectionHandler::connection_keep_alive`] of the underlying handler
722/// after every invocation of [`ConnectionHandler::poll`].
723///
724/// A planned shutdown is always postponed for as long as there are ingoing
725/// or outgoing substreams being negotiated, i.e. it is a graceful, "idle"
726/// shutdown.
727#[derive(Debug)]
728enum Shutdown {
729    /// No shutdown is planned.
730    None,
731    /// A shut down is planned as soon as possible.
732    Asap,
733    /// A shut down is planned for when a `Delay` has elapsed.
734    Later(Delay),
735}
736
737#[cfg(test)]
738mod tests {
739    use super::*;
740    use crate::dummy;
741    use futures::future;
742    use futures::AsyncRead;
743    use futures::AsyncWrite;
744    use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
745    use libp2p_core::StreamMuxer;
746    use quickcheck::*;
747    use std::sync::{Arc, Weak};
748    use std::time::Instant;
749    use tracing_subscriber::EnvFilter;
750    use void::Void;
751
752    #[test]
753    fn max_negotiating_inbound_streams() {
754        let _ = tracing_subscriber::fmt()
755            .with_env_filter(EnvFilter::from_default_env())
756            .try_init();
757
758        fn prop(max_negotiating_inbound_streams: u8) {
759            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();
760
761            let alive_substream_counter = Arc::new(());
762            let mut connection = Connection::new(
763                StreamMuxerBox::new(DummyStreamMuxer {
764                    counter: alive_substream_counter.clone(),
765                }),
766                MockConnectionHandler::new(Duration::from_secs(10)),
767                None,
768                max_negotiating_inbound_streams,
769                Duration::ZERO,
770            );
771
772            let result = connection.poll_noop_waker();
773
774            assert!(result.is_pending());
775            assert_eq!(
776                Arc::weak_count(&alive_substream_counter),
777                max_negotiating_inbound_streams,
778                "Expect no more than the maximum number of allowed streams"
779            );
780        }
781
782        QuickCheck::new().quickcheck(prop as fn(_));
783    }
784
785    #[test]
786    fn outbound_stream_timeout_starts_on_request() {
787        let upgrade_timeout = Duration::from_secs(1);
788        let mut connection = Connection::new(
789            StreamMuxerBox::new(PendingStreamMuxer),
790            MockConnectionHandler::new(upgrade_timeout),
791            None,
792            2,
793            Duration::ZERO,
794        );
795
796        connection.handler.open_new_outbound();
797        let _ = connection.poll_noop_waker();
798
799        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));
800
801        let _ = connection.poll_noop_waker();
802
803        assert!(matches!(
804            connection.handler.error.unwrap(),
805            StreamUpgradeError::Timeout
806        ))
807    }
808
809    #[test]
810    fn propagates_changes_to_supported_inbound_protocols() {
811        let mut connection = Connection::new(
812            StreamMuxerBox::new(PendingStreamMuxer),
813            ConfigurableProtocolConnectionHandler::default(),
814            None,
815            0,
816            Duration::ZERO,
817        );
818
819        // First, start listening on a single protocol.
820        connection.handler.listen_on(&["/foo"]);
821        let _ = connection.poll_noop_waker();
822
823        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
824        assert!(connection.handler.local_removed.is_empty());
825
826        // Second, listen on two protocols.
827        connection.handler.listen_on(&["/foo", "/bar"]);
828        let _ = connection.poll_noop_waker();
829
830        assert_eq!(
831            connection.handler.local_added,
832            vec![vec!["/foo"], vec!["/bar"]],
833            "expect to only receive an event for the newly added protocols"
834        );
835        assert!(connection.handler.local_removed.is_empty());
836
837        // Third, stop listening on the first protocol.
838        connection.handler.listen_on(&["/bar"]);
839        let _ = connection.poll_noop_waker();
840
841        assert_eq!(
842            connection.handler.local_added,
843            vec![vec!["/foo"], vec!["/bar"]]
844        );
845        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
846    }
847
848    #[test]
849    fn only_propagtes_actual_changes_to_remote_protocols_to_handler() {
850        let mut connection = Connection::new(
851            StreamMuxerBox::new(PendingStreamMuxer),
852            ConfigurableProtocolConnectionHandler::default(),
853            None,
854            0,
855            Duration::ZERO,
856        );
857
858        // First, remote supports a single protocol.
859        connection.handler.remote_adds_support_for(&["/foo"]);
860        let _ = connection.poll_noop_waker();
861
862        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
863        assert!(connection.handler.remote_removed.is_empty());
864
865        // Second, it adds a protocol but also still includes the first one.
866        connection
867            .handler
868            .remote_adds_support_for(&["/foo", "/bar"]);
869        let _ = connection.poll_noop_waker();
870
871        assert_eq!(
872            connection.handler.remote_added,
873            vec![vec!["/foo"], vec!["/bar"]],
874            "expect to only receive an event for the newly added protocol"
875        );
876        assert!(connection.handler.remote_removed.is_empty());
877
878        // Third, stop listening on a protocol it never advertised (we can't control what handlers do so this needs to be handled gracefully).
879        connection.handler.remote_removes_support_for(&["/baz"]);
880        let _ = connection.poll_noop_waker();
881
882        assert_eq!(
883            connection.handler.remote_added,
884            vec![vec!["/foo"], vec!["/bar"]]
885        );
886        assert!(&connection.handler.remote_removed.is_empty());
887
888        // Fourth, stop listening on a protocol that was previously supported
889        connection.handler.remote_removes_support_for(&["/bar"]);
890        let _ = connection.poll_noop_waker();
891
892        assert_eq!(
893            connection.handler.remote_added,
894            vec![vec!["/foo"], vec!["/bar"]]
895        );
896        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
897    }
898
899    #[tokio::test]
900    async fn idle_timeout_with_keep_alive_no() {
901        let idle_timeout = Duration::from_millis(100);
902
903        let mut connection = Connection::new(
904            StreamMuxerBox::new(PendingStreamMuxer),
905            dummy::ConnectionHandler,
906            None,
907            0,
908            idle_timeout,
909        );
910
911        assert!(connection.poll_noop_waker().is_pending());
912
913        tokio::time::sleep(idle_timeout).await;
914
915        assert!(matches!(
916            connection.poll_noop_waker(),
917            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
918        ));
919    }
920
921    #[test]
922    fn checked_add_fraction_can_add_u64_max() {
923        let _ = tracing_subscriber::fmt()
924            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
925            .try_init();
926        let start = Instant::now();
927
928        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));
929
930        assert!(start.checked_add(duration).is_some())
931    }
932
933    #[test]
934    fn compute_new_shutdown_does_not_panic() {
935        let _ = tracing_subscriber::fmt()
936            .with_env_filter(EnvFilter::from_default_env())
937            .try_init();
938
939        #[derive(Debug)]
940        struct ArbitraryShutdown(Shutdown);
941
942        impl Clone for ArbitraryShutdown {
943            fn clone(&self) -> Self {
944                let shutdown = match self.0 {
945                    Shutdown::None => Shutdown::None,
946                    Shutdown::Asap => Shutdown::Asap,
947                    Shutdown::Later(_) => Shutdown::Later(
948                        // compute_new_shutdown does not touch the delay. Delay does not
949                        // implement Clone. Thus use a placeholder delay.
950                        Delay::new(Duration::from_secs(1)),
951                    ),
952                };
953
954                ArbitraryShutdown(shutdown)
955            }
956        }
957
958        impl Arbitrary for ArbitraryShutdown {
959            fn arbitrary(g: &mut Gen) -> Self {
960                let shutdown = match g.gen_range(1u8..4) {
961                    1 => Shutdown::None,
962                    2 => Shutdown::Asap,
963                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
964                    _ => unreachable!(),
965                };
966
967                Self(shutdown)
968            }
969        }
970
971        fn prop(
972            handler_keep_alive: bool,
973            current_shutdown: ArbitraryShutdown,
974            idle_timeout: Duration,
975        ) {
976            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
977        }
978
979        QuickCheck::new().quickcheck(prop as fn(_, _, _));
980    }
981
982    struct DummyStreamMuxer {
983        counter: Arc<()>,
984    }
985
986    impl StreamMuxer for DummyStreamMuxer {
987        type Substream = PendingSubstream;
988        type Error = Void;
989
990        fn poll_inbound(
991            self: Pin<&mut Self>,
992            _: &mut Context<'_>,
993        ) -> Poll<Result<Self::Substream, Self::Error>> {
994            Poll::Ready(Ok(PendingSubstream {
995                _weak: Arc::downgrade(&self.counter),
996            }))
997        }
998
999        fn poll_outbound(
1000            self: Pin<&mut Self>,
1001            _: &mut Context<'_>,
1002        ) -> Poll<Result<Self::Substream, Self::Error>> {
1003            Poll::Pending
1004        }
1005
1006        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1007            Poll::Ready(Ok(()))
1008        }
1009
1010        fn poll(
1011            self: Pin<&mut Self>,
1012            _: &mut Context<'_>,
1013        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
1014            Poll::Pending
1015        }
1016    }
1017
1018    /// A [`StreamMuxer`] which never returns a stream.
1019    struct PendingStreamMuxer;
1020
1021    impl StreamMuxer for PendingStreamMuxer {
1022        type Substream = PendingSubstream;
1023        type Error = Void;
1024
1025        fn poll_inbound(
1026            self: Pin<&mut Self>,
1027            _: &mut Context<'_>,
1028        ) -> Poll<Result<Self::Substream, Self::Error>> {
1029            Poll::Pending
1030        }
1031
1032        fn poll_outbound(
1033            self: Pin<&mut Self>,
1034            _: &mut Context<'_>,
1035        ) -> Poll<Result<Self::Substream, Self::Error>> {
1036            Poll::Pending
1037        }
1038
1039        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1040            Poll::Pending
1041        }
1042
1043        fn poll(
1044            self: Pin<&mut Self>,
1045            _: &mut Context<'_>,
1046        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
1047            Poll::Pending
1048        }
1049    }
1050
1051    struct PendingSubstream {
1052        _weak: Weak<()>,
1053    }
1054
1055    impl AsyncRead for PendingSubstream {
1056        fn poll_read(
1057            self: Pin<&mut Self>,
1058            _cx: &mut Context<'_>,
1059            _buf: &mut [u8],
1060        ) -> Poll<std::io::Result<usize>> {
1061            Poll::Pending
1062        }
1063    }
1064
1065    impl AsyncWrite for PendingSubstream {
1066        fn poll_write(
1067            self: Pin<&mut Self>,
1068            _cx: &mut Context<'_>,
1069            _buf: &[u8],
1070        ) -> Poll<std::io::Result<usize>> {
1071            Poll::Pending
1072        }
1073
1074        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1075            Poll::Pending
1076        }
1077
1078        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1079            Poll::Pending
1080        }
1081    }
1082
1083    struct MockConnectionHandler {
1084        outbound_requested: bool,
1085        error: Option<StreamUpgradeError<Void>>,
1086        upgrade_timeout: Duration,
1087    }
1088
1089    impl MockConnectionHandler {
1090        fn new(upgrade_timeout: Duration) -> Self {
1091            Self {
1092                outbound_requested: false,
1093                error: None,
1094                upgrade_timeout,
1095            }
1096        }
1097
1098        fn open_new_outbound(&mut self) {
1099            self.outbound_requested = true;
1100        }
1101    }
1102
1103    #[derive(Default)]
1104    struct ConfigurableProtocolConnectionHandler {
1105        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Void>>,
1106        active_protocols: HashSet<StreamProtocol>,
1107        local_added: Vec<Vec<StreamProtocol>>,
1108        local_removed: Vec<Vec<StreamProtocol>>,
1109        remote_added: Vec<Vec<StreamProtocol>>,
1110        remote_removed: Vec<Vec<StreamProtocol>>,
1111    }
1112
1113    impl ConfigurableProtocolConnectionHandler {
1114        fn listen_on(&mut self, protocols: &[&'static str]) {
1115            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
1116        }
1117
1118        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
1119            self.events
1120                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
1121                    ProtocolSupport::Added(
1122                        protocols.iter().copied().map(StreamProtocol::new).collect(),
1123                    ),
1124                ));
1125        }
1126
1127        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
1128            self.events
1129                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
1130                    ProtocolSupport::Removed(
1131                        protocols.iter().copied().map(StreamProtocol::new).collect(),
1132                    ),
1133                ));
1134        }
1135    }
1136
1137    impl ConnectionHandler for MockConnectionHandler {
1138        type FromBehaviour = Void;
1139        type ToBehaviour = Void;
1140        type InboundProtocol = DeniedUpgrade;
1141        type OutboundProtocol = DeniedUpgrade;
1142        type InboundOpenInfo = ();
1143        type OutboundOpenInfo = ();
1144
1145        fn listen_protocol(
1146            &self,
1147        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
1148            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
1149        }
1150
1151        fn on_connection_event(
1152            &mut self,
1153            event: ConnectionEvent<
1154                Self::InboundProtocol,
1155                Self::OutboundProtocol,
1156                Self::InboundOpenInfo,
1157                Self::OutboundOpenInfo,
1158            >,
1159        ) {
1160            match event {
1161                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
1162                    protocol,
1163                    ..
1164                }) => void::unreachable(protocol),
1165                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
1166                    protocol,
1167                    ..
1168                }) => void::unreachable(protocol),
1169                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
1170                    self.error = Some(error)
1171                }
1172                ConnectionEvent::AddressChange(_)
1173                | ConnectionEvent::ListenUpgradeError(_)
1174                | ConnectionEvent::LocalProtocolsChange(_)
1175                | ConnectionEvent::RemoteProtocolsChange(_) => {}
1176            }
1177        }
1178
1179        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
1180            void::unreachable(event)
1181        }
1182
1183        fn connection_keep_alive(&self) -> bool {
1184            true
1185        }
1186
1187        fn poll(
1188            &mut self,
1189            _: &mut Context<'_>,
1190        ) -> Poll<
1191            ConnectionHandlerEvent<
1192                Self::OutboundProtocol,
1193                Self::OutboundOpenInfo,
1194                Self::ToBehaviour,
1195            >,
1196        > {
1197            if self.outbound_requested {
1198                self.outbound_requested = false;
1199                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
1200                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
1201                        .with_timeout(self.upgrade_timeout),
1202                });
1203            }
1204
1205            Poll::Pending
1206        }
1207    }
1208
1209    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
1210        type FromBehaviour = Void;
1211        type ToBehaviour = Void;
1212        type InboundProtocol = ManyProtocolsUpgrade;
1213        type OutboundProtocol = DeniedUpgrade;
1214        type InboundOpenInfo = ();
1215        type OutboundOpenInfo = ();
1216
1217        fn listen_protocol(
1218            &self,
1219        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
1220            SubstreamProtocol::new(
1221                ManyProtocolsUpgrade {
1222                    protocols: Vec::from_iter(self.active_protocols.clone()),
1223                },
1224                (),
1225            )
1226        }
1227
1228        fn on_connection_event(
1229            &mut self,
1230            event: ConnectionEvent<
1231                Self::InboundProtocol,
1232                Self::OutboundProtocol,
1233                Self::InboundOpenInfo,
1234                Self::OutboundOpenInfo,
1235            >,
1236        ) {
1237            match event {
1238                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
1239                    self.local_added.push(added.cloned().collect())
1240                }
1241                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
1242                    self.local_removed.push(removed.cloned().collect())
1243                }
1244                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
1245                    self.remote_added.push(added.cloned().collect())
1246                }
1247                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
1248                    self.remote_removed.push(removed.cloned().collect())
1249                }
1250                _ => {}
1251            }
1252        }
1253
1254        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
1255            void::unreachable(event)
1256        }
1257
1258        fn connection_keep_alive(&self) -> bool {
1259            true
1260        }
1261
1262        fn poll(
1263            &mut self,
1264            _: &mut Context<'_>,
1265        ) -> Poll<
1266            ConnectionHandlerEvent<
1267                Self::OutboundProtocol,
1268                Self::OutboundOpenInfo,
1269                Self::ToBehaviour,
1270            >,
1271        > {
1272            if let Some(event) = self.events.pop() {
1273                return Poll::Ready(event);
1274            }
1275
1276            Poll::Pending
1277        }
1278    }
1279
1280    struct ManyProtocolsUpgrade {
1281        protocols: Vec<StreamProtocol>,
1282    }
1283
1284    impl UpgradeInfo for ManyProtocolsUpgrade {
1285        type Info = StreamProtocol;
1286        type InfoIter = std::vec::IntoIter<Self::Info>;
1287
1288        fn protocol_info(&self) -> Self::InfoIter {
1289            self.protocols.clone().into_iter()
1290        }
1291    }
1292
1293    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
1294        type Output = C;
1295        type Error = Void;
1296        type Future = future::Ready<Result<Self::Output, Self::Error>>;
1297
1298        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
1299            future::ready(Ok(stream))
1300        }
1301    }
1302
1303    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
1304        type Output = C;
1305        type Error = Void;
1306        type Future = future::Ready<Result<Self::Output, Self::Error>>;
1307
1308        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
1309            future::ready(Ok(stream))
1310        }
1311    }
1312}
1313
1314/// The endpoint roles associated with a pending peer-to-peer connection.
1315#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1316enum PendingPoint {
1317    /// The socket comes from a dialer.
1318    ///
1319    /// There is no single address associated with the Dialer of a pending
1320    /// connection. Addresses are dialed in parallel. Only once the first dial
1321    /// is successful is the address of the connection known.
1322    Dialer {
1323        /// Same as [`ConnectedPoint::Dialer`] `role_override`.
1324        role_override: Endpoint,
1325    },
1326    /// The socket comes from a listener.
1327    Listener {
1328        /// Local connection address.
1329        local_addr: Multiaddr,
1330        /// Address used to send back data to the remote.
1331        send_back_addr: Multiaddr,
1332    },
1333}
1334
1335impl From<ConnectedPoint> for PendingPoint {
1336    fn from(endpoint: ConnectedPoint) -> Self {
1337        match endpoint {
1338            ConnectedPoint::Dialer { role_override, .. } => PendingPoint::Dialer { role_override },
1339            ConnectedPoint::Listener {
1340                local_addr,
1341                send_back_addr,
1342            } => PendingPoint::Listener {
1343                local_addr,
1344                send_back_addr,
1345            },
1346        }
1347    }
1348}