1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26pub use error::ConnectionError;
27pub(crate) use error::{
28 PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
29};
30pub use supported_protocols::SupportedProtocols;
31
32use crate::handler::{
33 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound,
34 FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange,
35 UpgradeInfoSend,
36};
37use crate::stream::ActiveStreamCounter;
38use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
39use crate::{
40 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
41};
42use futures::future::BoxFuture;
43use futures::stream::FuturesUnordered;
44use futures::StreamExt;
45use futures::{stream, FutureExt};
46use futures_timer::Delay;
47use instant::Instant;
48use libp2p_core::connection::ConnectedPoint;
49use libp2p_core::multiaddr::Multiaddr;
50use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
51use libp2p_core::upgrade;
52use libp2p_core::upgrade::{NegotiationError, ProtocolError};
53use libp2p_core::Endpoint;
54use libp2p_identity::PeerId;
55use std::collections::HashSet;
56use std::fmt::{Display, Formatter};
57use std::future::Future;
58use std::sync::atomic::{AtomicUsize, Ordering};
59use std::task::Waker;
60use std::time::Duration;
61use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
62
/// Process-wide counter from which [`ConnectionId::next`] draws fresh identifiers.
///
/// It starts at 1 and is bumped by one on every call to [`ConnectionId::next`].
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Connection identifier.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps an arbitrary `usize` in a [`ConnectionId`] without consulting the global counter.
    ///
    /// Nothing in this constructor checks that `id` is unique, so the result may collide with
    /// an identifier handed out by [`ConnectionId::next`] -- hence the "unchecked".
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Returns the next available [`ConnectionId`], advancing the global counter.
    pub(crate) fn next() -> Self {
        let id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(id)
    }
}
85
86impl Display for ConnectionId {
87 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
94pub(crate) struct Connected {
95 pub(crate) endpoint: ConnectedPoint,
97 pub(crate) peer_id: PeerId,
99}
100
101#[derive(Debug, Clone)]
103pub(crate) enum Event<T> {
104 Handler(T),
106 AddressChange(Multiaddr),
108}
109
110pub(crate) struct Connection<THandler>
112where
113 THandler: ConnectionHandler,
114{
115 muxing: StreamMuxerBox,
117 handler: THandler,
119 negotiating_in: FuturesUnordered<
121 StreamUpgrade<
122 THandler::InboundOpenInfo,
123 <THandler::InboundProtocol as InboundUpgradeSend>::Output,
124 <THandler::InboundProtocol as InboundUpgradeSend>::Error,
125 >,
126 >,
127 negotiating_out: FuturesUnordered<
129 StreamUpgrade<
130 THandler::OutboundOpenInfo,
131 <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
132 <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
133 >,
134 >,
135 shutdown: Shutdown,
137 substream_upgrade_protocol_override: Option<upgrade::Version>,
139 max_negotiating_inbound_streams: usize,
148 requested_substreams: FuturesUnordered<
153 SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
154 >,
155
156 local_supported_protocols: HashSet<StreamProtocol>,
157 remote_supported_protocols: HashSet<StreamProtocol>,
158 idle_timeout: Duration,
159 stream_counter: ActiveStreamCounter,
160}
161
162impl<THandler> fmt::Debug for Connection<THandler>
163where
164 THandler: ConnectionHandler + fmt::Debug,
165 THandler::OutboundOpenInfo: fmt::Debug,
166{
167 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168 f.debug_struct("Connection")
169 .field("handler", &self.handler)
170 .finish()
171 }
172}
173
174impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
175
176impl<THandler> Connection<THandler>
177where
178 THandler: ConnectionHandler,
179{
180 pub(crate) fn new(
183 muxer: StreamMuxerBox,
184 mut handler: THandler,
185 substream_upgrade_protocol_override: Option<upgrade::Version>,
186 max_negotiating_inbound_streams: usize,
187 idle_timeout: Duration,
188 ) -> Self {
189 let initial_protocols = gather_supported_protocols(&handler);
190 if !initial_protocols.is_empty() {
191 handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
192 ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)),
193 ));
194 }
195 Connection {
196 muxing: muxer,
197 handler,
198 negotiating_in: Default::default(),
199 negotiating_out: Default::default(),
200 shutdown: Shutdown::None,
201 substream_upgrade_protocol_override,
202 max_negotiating_inbound_streams,
203 requested_substreams: Default::default(),
204 local_supported_protocols: initial_protocols,
205 remote_supported_protocols: Default::default(),
206 idle_timeout,
207 stream_counter: ActiveStreamCounter::default(),
208 }
209 }
210
211 pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
213 self.handler.on_behaviour_event(event);
214 }
215
216 pub(crate) fn close(
218 self,
219 ) -> (
220 impl futures::Stream<Item = THandler::ToBehaviour>,
221 impl Future<Output = io::Result<()>>,
222 ) {
223 let Connection {
224 mut handler,
225 muxing,
226 ..
227 } = self;
228
229 (
230 stream::poll_fn(move |cx| handler.poll_close(cx)),
231 muxing.close(),
232 )
233 }
234
    /// Polls the handler and the muxer, forwarding events between the two.
    ///
    /// Each loop iteration tries the sources below in order and restarts from the top as soon
    /// as one of them made progress:
    ///
    /// 1. timeouts of outbound substream requests the muxer has not served yet,
    /// 2. the [`ConnectionHandler`],
    /// 3. negotiating outbound streams,
    /// 4. negotiating inbound streams,
    /// 5. the idle-shutdown check,
    /// 6. the muxer (address changes, new outbound and inbound substreams),
    /// 7. changes to the set of locally supported inbound protocols.
    ///
    /// Returns `Poll::Ready(Ok(_))` for an event to surface to the caller, `Poll::Ready(Err(_))`
    /// when the connection is to be closed (idle timeout or muxer error) and `Poll::Pending`
    /// when nothing can make progress.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // Drive the pending outbound requests: `Ok(())` means the request was already
            // extracted and handed to the muxer, `Err(info)` means it timed out while waiting.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the [`ConnectionHandler`].
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    // The timeout starts now, not when the muxer eventually opens the stream.
                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll handler until exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // Only actual changes are reported back to the handler.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, &protocols)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocols);
                    }

                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // Only actual changes are reported back to the handler.
                    if let Some(removed) =
                        ProtocolsChange::remove(remote_supported_protocols, &protocols)
                    {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                        remote_supported_protocols.retain(|p| !protocols.contains(p));
                    }

                    continue;
                }
            }

            // In case the [`ConnectionHandler`] can not make any more progress, poll the
            // negotiating outbound streams.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // In case both the [`ConnectionHandler`] and the negotiating outbound streams can
            // not make any more progress, poll the negotiating inbound streams.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                // The remaining inbound failures are only logged; the handler is not notified.
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Check if the connection (and handler) should be shut down.
            // As long as we are still negotiating or requesting substreams, or have any active
            // streams, the shutdown is postponed.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                // `None` means: keep the currently planned shutdown untouched.
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                // Not idle: cancel any planned shutdown.
                *shutdown = Shutdown::None;
            }

            // Poll the muxer itself; a muxer error is propagated via `?`.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Only ask the muxer for an outbound substream if there is a pending request.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        // The already running timeout is carried over into the upgrade.
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Respect the limit on concurrently negotiating inbound streams.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Detect changes in the handler's inbound protocols and report them back to it.
            let new_protocols = gather_supported_protocols(handler);
            let changes = ProtocolsChange::from_full_sets(supported_protocols, &new_protocols);

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }

                *supported_protocols = new_protocols;

                continue; // Go back to the top, handler can potentially make progress again.
            }

            return Poll::Pending; // Nothing can make progress, return `Pending`.
        }
    }
450
451 #[cfg(test)]
452 fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
453 Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
454 }
455}
456
457fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<StreamProtocol> {
458 handler
459 .listen_protocol()
460 .upgrade()
461 .protocol_info()
462 .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok())
463 .collect()
464}
465
466fn compute_new_shutdown(
467 handler_keep_alive: bool,
468 current_shutdown: &Shutdown,
469 idle_timeout: Duration,
470) -> Option<Shutdown> {
471 match (current_shutdown, handler_keep_alive) {
472 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
473 (Shutdown::Later(_), false) => None, (_, false) => {
475 let now = Instant::now();
476 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
477
478 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
479 }
480 (_, true) => Some(Shutdown::None),
481 }
482}
483
484fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
489 while start.checked_add(duration).is_none() {
490 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
491
492 duration /= 2;
493 }
494
495 duration
496}
497
498#[derive(Debug, Copy, Clone)]
500pub(crate) struct IncomingInfo<'a> {
501 pub(crate) local_addr: &'a Multiaddr,
503 pub(crate) send_back_addr: &'a Multiaddr,
505}
506
507impl<'a> IncomingInfo<'a> {
508 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
510 ConnectedPoint::Listener {
511 local_addr: self.local_addr.clone(),
512 send_back_addr: self.send_back_addr.clone(),
513 }
514 }
515}
516
/// A substream upgrade in progress, bundled with its timeout and the user data that is
/// handed back alongside the result.
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Returned together with the outcome when the future resolves; `None` once taken.
    user_data: Option<UserData>,
    /// When this timer fires first, the upgrade resolves with `StreamUpgradeError::Timeout`.
    timeout: Delay,
    /// The boxed future performing protocol negotiation followed by the actual upgrade.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
522
523impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
524 fn new_outbound<Upgrade>(
525 substream: SubstreamBox,
526 user_data: UserData,
527 timeout: Delay,
528 upgrade: Upgrade,
529 version_override: Option<upgrade::Version>,
530 counter: ActiveStreamCounter,
531 ) -> Self
532 where
533 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
534 {
535 let effective_version = match version_override {
536 Some(version_override) if version_override != upgrade::Version::default() => {
537 tracing::debug!(
538 "Substream upgrade protocol override: {:?} -> {:?}",
539 upgrade::Version::default(),
540 version_override
541 );
542
543 version_override
544 }
545 _ => upgrade::Version::default(),
546 };
547 let protocols = upgrade.protocol_info();
548
549 Self {
550 user_data: Some(user_data),
551 timeout,
552 upgrade: Box::pin(async move {
553 let (info, stream) = multistream_select::dialer_select_proto(
554 substream,
555 protocols,
556 effective_version,
557 )
558 .await
559 .map_err(to_stream_upgrade_error)?;
560
561 let output = upgrade
562 .upgrade_outbound(Stream::new(stream, counter), info)
563 .await
564 .map_err(StreamUpgradeError::Apply)?;
565
566 Ok(output)
567 }),
568 }
569 }
570}
571
572impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
573 fn new_inbound<Upgrade>(
574 substream: SubstreamBox,
575 protocol: SubstreamProtocol<Upgrade, UserData>,
576 counter: ActiveStreamCounter,
577 ) -> Self
578 where
579 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
580 {
581 let timeout = *protocol.timeout();
582 let (upgrade, open_info) = protocol.into_upgrade();
583 let protocols = upgrade.protocol_info();
584
585 Self {
586 user_data: Some(open_info),
587 timeout: Delay::new(timeout),
588 upgrade: Box::pin(async move {
589 let (info, stream) =
590 multistream_select::listener_select_proto(substream, protocols)
591 .await
592 .map_err(to_stream_upgrade_error)?;
593
594 let output = upgrade
595 .upgrade_inbound(Stream::new(stream, counter), info)
596 .await
597 .map_err(StreamUpgradeError::Apply)?;
598
599 Ok(output)
600 }),
601 }
602 }
603}
604
605fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
606 match e {
607 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
608 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
609 NegotiationError::ProtocolError(other) => {
610 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
611 }
612 }
613}
614
615impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
616
617impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
618 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
619
620 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
621 match self.timeout.poll_unpin(cx) {
622 Poll::Ready(()) => {
623 return Poll::Ready((
624 self.user_data
625 .take()
626 .expect("Future not to be polled again once ready."),
627 Err(StreamUpgradeError::Timeout),
628 ))
629 }
630
631 Poll::Pending => {}
632 }
633
634 let result = futures::ready!(self.upgrade.poll_unpin(cx));
635 let user_data = self
636 .user_data
637 .take()
638 .expect("Future not to be polled again once ready.");
639
640 Poll::Ready((user_data, result))
641 }
642}
643
/// An outbound substream request that is waiting for the muxer to open a substream.
enum SubstreamRequested<UserData, Upgrade> {
    /// The request has not been served yet.
    Waiting {
        /// Opaque data handed back when the request times out or is extracted.
        user_data: UserData,
        /// Timer of the request; if it fires while waiting, the future resolves with `Err`.
        timeout: Delay,
        /// The upgrade to apply once a substream is available.
        upgrade: Upgrade,
        /// Waker of the task that last polled this future.
        ///
        /// It is woken when the data is extracted, so that the future gets polled again,
        /// resolves with `Ok(())` and can be removed from the collection holding it.
        extracted_waker: Option<Waker>,
    },
    /// The data has been extracted, or the future has already resolved.
    Done,
}
657
658impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
659 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
660 Self::Waiting {
661 user_data,
662 timeout: Delay::new(timeout),
663 upgrade,
664 extracted_waker: None,
665 }
666 }
667
668 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
669 match mem::replace(self, Self::Done) {
670 SubstreamRequested::Waiting {
671 user_data,
672 timeout,
673 upgrade,
674 extracted_waker: waker,
675 } => {
676 if let Some(waker) = waker {
677 waker.wake();
678 }
679
680 (user_data, timeout, upgrade)
681 }
682 SubstreamRequested::Done => panic!("cannot extract twice"),
683 }
684 }
685}
686
687impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
688
689impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
690 type Output = Result<(), UserData>;
691
692 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
693 let this = self.get_mut();
694
695 match mem::replace(this, Self::Done) {
696 SubstreamRequested::Waiting {
697 user_data,
698 upgrade,
699 mut timeout,
700 ..
701 } => match timeout.poll_unpin(cx) {
702 Poll::Ready(()) => Poll::Ready(Err(user_data)),
703 Poll::Pending => {
704 *this = Self::Waiting {
705 user_data,
706 upgrade,
707 timeout,
708 extracted_waker: Some(cx.waker().clone()),
709 };
710 Poll::Pending
711 }
712 },
713 SubstreamRequested::Done => Poll::Ready(Ok(())),
714 }
715 }
716}
717
/// The options for a planned connection & handler shutdown.
///
/// `Connection::poll` re-plans the shutdown from the handler's `connection_keep_alive()`
/// whenever the connection is idle, i.e. no substreams are requested or negotiating and no
/// streams are active. While the connection is not idle, the plan is reset to `None`.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is planned.
    None,
    /// A shutdown is planned as soon as possible.
    Asap,
    /// A shutdown is planned for when the `Delay` has elapsed.
    Later(Delay),
}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dummy;
    use futures::future;
    use futures::AsyncRead;
    use futures::AsyncWrite;
    use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
    use libp2p_core::StreamMuxer;
    use quickcheck::*;
    use std::sync::{Arc, Weak};
    use std::time::Instant;
    use tracing_subscriber::EnvFilter;
    use void::Void;

    /// The connection must not pull more inbound streams from the muxer than allowed by
    /// `max_negotiating_inbound_streams`.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            // Every substream handed out by the muxer holds a `Weak` to this `Arc`, so the
            // weak count equals the number of substreams currently alive.
            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The timeout of an outbound request starts when the handler requests the stream, even
    /// if the muxer never opens one.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    /// Changes to the handler's inbound protocols are reported back as added/removed deltas.
    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Second, listen on two protocols.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Third, stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    /// Reports about remote protocols only reach the handler when they change the known set.
    #[test]
    fn only_propagtes_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, the remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Second, it adds a protocol but still includes the first one.
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Third, it removes a protocol that was never advertised; this must be a no-op.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Fourth, it removes a protocol that was previously supported.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    /// A handler that does not ask to be kept alive is shut down once the idle timeout elapsed.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// Even an absurdly large duration is reduced to something that can be added to `start`.
    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    /// `compute_new_shutdown` must not panic for any combination of inputs.
    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    Shutdown::Later(_) => Shutdown::Later(
                        // `compute_new_shutdown` does not touch the delay and `Delay` does
                        // not implement `Clone`, thus use a placeholder delay.
                        Delay::new(Duration::from_secs(1)),
                    ),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// A [`StreamMuxer`] that always has a new inbound substream ready and never opens an
    /// outbound one.
    struct DummyStreamMuxer {
        /// Each substream handed out holds a `Weak` reference to this counter.
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A [`StreamMuxer`] which never returns a stream.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A substream on which every read and write operation stays pending forever.
    struct PendingSubstream {
        /// Keeps the substream visible in the weak count of the muxer's counter.
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request a single outbound stream on demand and records the last
    /// dial upgrade error it was given.
    struct MockConnectionHandler {
        /// Set by `open_new_outbound`; consumed by the next `poll`.
        outbound_requested: bool,
        /// The most recent `DialUpgradeError` reported to the handler.
        error: Option<StreamUpgradeError<Void>>,
        /// Timeout applied to both the listen protocol and outbound requests.
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        /// Makes the next `poll` emit an `OutboundSubstreamRequest`.
        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose inbound protocols and remote-protocol reports can be scripted, and which
    /// records every protocol change event it receives.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        /// Events to emit from `poll`, popped from the back.
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Void>>,
        /// Protocols currently advertised through `listen_protocol`.
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        /// Replaces the set of advertised inbound protocols.
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        /// Queues a report that the remote added support for `protocols`.
        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        /// Queues a report that the remote removed support for `protocols`.
        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade that advertises an arbitrary list of protocols and passes the stream through
    /// unchanged.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1313
1314#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1316enum PendingPoint {
1317 Dialer {
1323 role_override: Endpoint,
1325 },
1326 Listener {
1328 local_addr: Multiaddr,
1330 send_back_addr: Multiaddr,
1332 },
1333}
1334
1335impl From<ConnectedPoint> for PendingPoint {
1336 fn from(endpoint: ConnectedPoint) -> Self {
1337 match endpoint {
1338 ConnectedPoint::Dialer { role_override, .. } => PendingPoint::Dialer { role_override },
1339 ConnectedPoint::Listener {
1340 local_addr,
1341 send_back_addr,
1342 } => PendingPoint::Listener {
1343 local_addr,
1344 send_back_addr,
1345 },
1346 }
1347 }
1348}