1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
30
31mod provider;
32
33#[cfg(feature = "async-io")]
34pub use provider::async_io;
35
36#[cfg(feature = "tokio")]
37pub use provider::tokio;
38
39use futures::{
40 future::{self, Ready},
41 prelude::*,
42 stream::SelectAll,
43};
44use futures_timer::Delay;
45use if_watch::IfEvent;
46use libp2p_core::{
47 address_translation,
48 multiaddr::{Multiaddr, Protocol},
49 transport::{ListenerId, TransportError, TransportEvent},
50};
51use provider::{Incoming, Provider};
52use socket2::{Domain, Socket, Type};
53use std::{
54 collections::{HashSet, VecDeque},
55 io,
56 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener},
57 pin::Pin,
58 sync::{Arc, RwLock},
59 task::{Context, Poll, Waker},
60 time::Duration,
61};
62
/// Configuration of the TCP [`Transport`].
///
/// Created with [`Config::new`] (or `Default`) and adjusted through the
/// builder-style setters on this type.
#[derive(Clone, Debug)]
pub struct Config {
    /// TTL handed to `Socket::set_ttl` for every created socket; `None` leaves
    /// the socket's TTL untouched.
    ttl: Option<u32>,
    /// Value handed to `Socket::set_nodelay` for every created socket; `None`
    /// leaves the option untouched.
    nodelay: Option<bool>,
    /// Backlog passed to `listen` on listening sockets.
    backlog: u32,
    /// Whether the transport tracks listen addresses for port reuse.
    enable_port_reuse: bool,
}
75
/// A TCP port number.
type Port = u16;
77
/// Port reuse state of a [`Transport`] and its listeners.
#[derive(Debug, Clone)]
enum PortReuse {
    /// Port reuse is off; no listen addresses are tracked.
    Disabled,
    /// Port reuse is on.
    Enabled {
        /// The `(ip, port)` pairs of currently registered listen addresses.
        ///
        /// Held behind an `Arc` so that clones of this value (the transport
        /// hands one to every listener) observe the same set.
        listen_addrs: Arc<RwLock<HashSet<(IpAddr, Port)>>>,
    },
}
94
95impl PortReuse {
96 fn register(&mut self, ip: IpAddr, port: Port) {
100 if let PortReuse::Enabled { listen_addrs } = self {
101 tracing::trace!(%ip, %port, "Registering for port reuse");
102 listen_addrs
103 .write()
104 .expect("`register()` and `unregister()` never panic while holding the lock")
105 .insert((ip, port));
106 }
107 }
108
109 fn unregister(&mut self, ip: IpAddr, port: Port) {
113 if let PortReuse::Enabled { listen_addrs } = self {
114 tracing::trace!(%ip, %port, "Unregistering for port reuse");
115 listen_addrs
116 .write()
117 .expect("`register()` and `unregister()` never panic while holding the lock")
118 .remove(&(ip, port));
119 }
120 }
121
122 fn local_dial_addr(&self, remote_ip: &IpAddr) -> Option<SocketAddr> {
132 if let PortReuse::Enabled { listen_addrs } = self {
133 for (ip, port) in listen_addrs
134 .read()
135 .expect("`local_dial_addr` never panic while holding the lock")
136 .iter()
137 {
138 if ip.is_ipv4() == remote_ip.is_ipv4()
139 && ip.is_loopback() == remote_ip.is_loopback()
140 {
141 if remote_ip.is_ipv4() {
142 return Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port));
143 } else {
144 return Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port));
145 }
146 }
147 }
148 }
149
150 None
151 }
152}
153
154impl Config {
155 pub fn new() -> Self {
166 Self {
167 ttl: None,
168 nodelay: None,
169 backlog: 1024,
170 enable_port_reuse: false,
171 }
172 }
173
174 pub fn ttl(mut self, value: u32) -> Self {
176 self.ttl = Some(value);
177 self
178 }
179
180 pub fn nodelay(mut self, value: bool) -> Self {
182 self.nodelay = Some(value);
183 self
184 }
185
186 pub fn listen_backlog(mut self, backlog: u32) -> Self {
188 self.backlog = backlog;
189 self
190 }
191
192 pub fn port_reuse(mut self, port_reuse: bool) -> Self {
288 self.enable_port_reuse = port_reuse;
289 self
290 }
291}
292
293impl Default for Config {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
/// A TCP transport, generic over the [`Provider`] that supplies the
/// runtime-specific listener, stream and interface-watcher types.
pub struct Transport<T>
where
    T: Provider + Send,
{
    /// Socket and listener options.
    config: Config,

    /// Port reuse state; a clone is handed to every listener.
    port_reuse: PortReuse,
    /// All active listeners, polled together as one stream of events.
    listeners: SelectAll<ListenStream<T>>,
    /// Events queued by the transport itself (e.g. the `NewAddress` event
    /// produced in `do_listen`); drained by `poll` before the listeners are polled.
    pending_events:
        VecDeque<TransportEvent<<Self as libp2p_core::Transport>::ListenerUpgrade, io::Error>>,
}
321
322impl<T> Transport<T>
323where
324 T: Provider + Send,
325{
326 pub fn new(config: Config) -> Self {
335 let port_reuse = if config.enable_port_reuse {
336 PortReuse::Enabled {
337 listen_addrs: Arc::new(RwLock::new(HashSet::new())),
338 }
339 } else {
340 PortReuse::Disabled
341 };
342 Transport {
343 config,
344 port_reuse,
345 ..Default::default()
346 }
347 }
348
349 fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
350 let socket = Socket::new(
351 Domain::for_address(socket_addr),
352 Type::STREAM,
353 Some(socket2::Protocol::TCP),
354 )?;
355 if socket_addr.is_ipv6() {
356 socket.set_only_v6(true)?;
357 }
358 if let Some(ttl) = self.config.ttl {
359 socket.set_ttl(ttl)?;
360 }
361 if let Some(nodelay) = self.config.nodelay {
362 socket.set_nodelay(nodelay)?;
363 }
364 socket.set_reuse_address(true)?;
365 #[cfg(unix)]
366 if let PortReuse::Enabled { .. } = &self.port_reuse {
367 socket.set_reuse_port(true)?;
368 }
369 Ok(socket)
370 }
371
372 fn do_listen(
373 &mut self,
374 id: ListenerId,
375 socket_addr: SocketAddr,
376 ) -> io::Result<ListenStream<T>> {
377 let socket = self.create_socket(socket_addr)?;
378 socket.bind(&socket_addr.into())?;
379 socket.listen(self.config.backlog as _)?;
380 socket.set_nonblocking(true)?;
381 let listener: TcpListener = socket.into();
382 let local_addr = listener.local_addr()?;
383
384 if local_addr.ip().is_unspecified() {
385 return ListenStream::<T>::new(
386 id,
387 listener,
388 Some(T::new_if_watcher()?),
389 self.port_reuse.clone(),
390 );
391 }
392
393 self.port_reuse.register(local_addr.ip(), local_addr.port());
394 let listen_addr = ip_to_multiaddr(local_addr.ip(), local_addr.port());
395 self.pending_events.push_back(TransportEvent::NewAddress {
396 listener_id: id,
397 listen_addr,
398 });
399 ListenStream::<T>::new(id, listener, None, self.port_reuse.clone())
400 }
401}
402
403impl<T> Default for Transport<T>
404where
405 T: Provider + Send,
406{
407 fn default() -> Self {
411 let config = Config::default();
412 let port_reuse = if config.enable_port_reuse {
413 PortReuse::Enabled {
414 listen_addrs: Arc::new(RwLock::new(HashSet::new())),
415 }
416 } else {
417 PortReuse::Disabled
418 };
419 Transport {
420 port_reuse,
421 config,
422 listeners: SelectAll::new(),
423 pending_events: VecDeque::new(),
424 }
425 }
426}
427
428impl<T> libp2p_core::Transport for Transport<T>
429where
430 T: Provider + Send + 'static,
431 T::Listener: Unpin,
432 T::Stream: Unpin,
433{
434 type Output = T::Stream;
435 type Error = io::Error;
436 type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
437 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
438
439 fn listen_on(
440 &mut self,
441 id: ListenerId,
442 addr: Multiaddr,
443 ) -> Result<(), TransportError<Self::Error>> {
444 let socket_addr = if let Ok(sa) = multiaddr_to_socketaddr(addr.clone()) {
445 sa
446 } else {
447 return Err(TransportError::MultiaddrNotSupported(addr));
448 };
449 tracing::debug!(address=%socket_addr, "listening on address");
450 let listener = self
451 .do_listen(id, socket_addr)
452 .map_err(TransportError::Other)?;
453 self.listeners.push(listener);
454 Ok(())
455 }
456
457 fn remove_listener(&mut self, id: ListenerId) -> bool {
458 if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
459 listener.close(Ok(()));
460 true
461 } else {
462 false
463 }
464 }
465
466 fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
467 let socket_addr = if let Ok(socket_addr) = multiaddr_to_socketaddr(addr.clone()) {
468 if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
469 return Err(TransportError::MultiaddrNotSupported(addr));
470 }
471 socket_addr
472 } else {
473 return Err(TransportError::MultiaddrNotSupported(addr));
474 };
475 tracing::debug!(address=%socket_addr, "dialing address");
476
477 let socket = self
478 .create_socket(socket_addr)
479 .map_err(TransportError::Other)?;
480
481 if let Some(addr) = self.port_reuse.local_dial_addr(&socket_addr.ip()) {
482 tracing::trace!(address=%addr, "Binding dial socket to listen socket address");
483 socket.bind(&addr.into()).map_err(TransportError::Other)?;
484 }
485
486 socket
487 .set_nonblocking(true)
488 .map_err(TransportError::Other)?;
489
490 Ok(async move {
491 match socket.connect(&socket_addr.into()) {
494 Ok(()) => {}
495 Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
496 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
497 Err(err) => return Err(err),
498 };
499
500 let stream = T::new_stream(socket.into()).await?;
501 Ok(stream)
502 }
503 .boxed())
504 }
505
506 fn dial_as_listener(
507 &mut self,
508 addr: Multiaddr,
509 ) -> Result<Self::Dial, TransportError<Self::Error>> {
510 self.dial(addr)
511 }
512
513 fn address_translation(&self, listen: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
531 if !is_tcp_addr(listen) || !is_tcp_addr(observed) {
532 return None;
533 }
534 match &self.port_reuse {
535 PortReuse::Disabled => address_translation(listen, observed),
536 PortReuse::Enabled { .. } => Some(observed.clone()),
537 }
538 }
539
540 #[tracing::instrument(level = "trace", name = "Transport::poll", skip(self, cx))]
542 fn poll(
543 mut self: Pin<&mut Self>,
544 cx: &mut Context<'_>,
545 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
546 if let Some(event) = self.pending_events.pop_front() {
548 return Poll::Ready(event);
549 }
550
551 match self.listeners.poll_next_unpin(cx) {
552 Poll::Ready(Some(transport_event)) => Poll::Ready(transport_event),
553 _ => Poll::Pending,
554 }
555 }
556}
557
/// A stream of transport events produced by a single TCP listener.
struct ListenStream<T>
where
    T: Provider,
{
    /// Id of this listener, attached to every event it emits.
    listener_id: ListenerId,
    /// Local socket address the listener is bound to.
    listen_addr: SocketAddr,
    /// The provider-specific listener that accepts connections.
    listener: T::Listener,
    /// Interface watcher; `Some` only when the listener was created for an
    /// unspecified address.
    if_watcher: Option<T::IfWatcher>,
    /// Port reuse state cloned from the transport; addresses of this listener
    /// are registered with and unregistered from it.
    port_reuse: PortReuse,
    /// How long to pause after an error before polling again.
    sleep_on_error: Duration,
    /// Active pause timer, if any; nothing else is polled until it completes.
    pause: Option<Delay>,
    /// An event to hand out on the next poll (set by `close`).
    pending_event: Option<<Self as Stream>::Item>,
    /// Set once `close` has been called.
    is_closed: bool,
    /// Waker of the task that last polled this stream to `Pending`, so that
    /// `close` can wake it.
    close_listener_waker: Option<Waker>,
}
596
597impl<T> ListenStream<T>
598where
599 T: Provider,
600{
601 fn new(
604 listener_id: ListenerId,
605 listener: TcpListener,
606 if_watcher: Option<T::IfWatcher>,
607 port_reuse: PortReuse,
608 ) -> io::Result<Self> {
609 let listen_addr = listener.local_addr()?;
610 let listener = T::new_listener(listener)?;
611
612 Ok(ListenStream {
613 port_reuse,
614 listener,
615 listener_id,
616 listen_addr,
617 if_watcher,
618 pause: None,
619 sleep_on_error: Duration::from_millis(100),
620 pending_event: None,
621 is_closed: false,
622 close_listener_waker: None,
623 })
624 }
625
626 fn disable_port_reuse(&mut self) {
633 match &self.if_watcher {
634 Some(if_watcher) => {
635 for ip_net in T::addrs(if_watcher) {
636 self.port_reuse
637 .unregister(ip_net.addr(), self.listen_addr.port());
638 }
639 }
640 None => self
641 .port_reuse
642 .unregister(self.listen_addr.ip(), self.listen_addr.port()),
643 }
644 }
645
646 fn close(&mut self, reason: Result<(), io::Error>) {
651 if self.is_closed {
652 return;
653 }
654 self.pending_event = Some(TransportEvent::ListenerClosed {
655 listener_id: self.listener_id,
656 reason,
657 });
658 self.is_closed = true;
659
660 if let Some(waker) = self.close_listener_waker.take() {
662 waker.wake();
663 }
664 }
665
666 fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
668 let if_watcher = match self.if_watcher.as_mut() {
669 Some(if_watcher) => if_watcher,
670 None => return Poll::Pending,
671 };
672
673 let my_listen_addr_port = self.listen_addr.port();
674
675 while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) {
676 match event {
677 Ok(IfEvent::Up(inet)) => {
678 let ip = inet.addr();
679 if self.listen_addr.is_ipv4() == ip.is_ipv4() {
680 let ma = ip_to_multiaddr(ip, my_listen_addr_port);
681 tracing::debug!(address=%ma, "New listen address");
682 self.port_reuse.register(ip, my_listen_addr_port);
683 return Poll::Ready(TransportEvent::NewAddress {
684 listener_id: self.listener_id,
685 listen_addr: ma,
686 });
687 }
688 }
689 Ok(IfEvent::Down(inet)) => {
690 let ip = inet.addr();
691 if self.listen_addr.is_ipv4() == ip.is_ipv4() {
692 let ma = ip_to_multiaddr(ip, my_listen_addr_port);
693 tracing::debug!(address=%ma, "Expired listen address");
694 self.port_reuse.unregister(ip, my_listen_addr_port);
695 return Poll::Ready(TransportEvent::AddressExpired {
696 listener_id: self.listener_id,
697 listen_addr: ma,
698 });
699 }
700 }
701 Err(error) => {
702 self.pause = Some(Delay::new(self.sleep_on_error));
703 return Poll::Ready(TransportEvent::ListenerError {
704 listener_id: self.listener_id,
705 error,
706 });
707 }
708 }
709 }
710
711 Poll::Pending
712 }
713}
714
715impl<T> Drop for ListenStream<T>
716where
717 T: Provider,
718{
719 fn drop(&mut self) {
720 self.disable_port_reuse();
721 }
722}
723
/// Drives a single listener: reports address changes, incoming connections,
/// errors and, finally, the listener's closure.
impl<T> Stream for ListenStream<T>
where
    T: Provider,
    T::Listener: Unpin,
    T::Stream: Unpin,
{
    type Item = TransportEvent<Ready<Result<T::Stream, io::Error>>, io::Error>;

    /// Produces the next event of this listener.
    ///
    /// Order of checks: an active error pause, a queued event, the closed flag
    /// (which ends the stream), interface address changes, and finally the
    /// accept queue.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // While a pause (started after an error) is running, do nothing else.
        if let Some(mut pause) = self.pause.take() {
            match pause.poll_unpin(cx) {
                Poll::Ready(_) => {}
                Poll::Pending => {
                    // Put the timer back so the next poll resumes waiting on it.
                    self.pause = Some(pause);
                    return Poll::Pending;
                }
            }
        }

        // Hand out a queued event (e.g. `ListenerClosed` set by `close`).
        if let Some(event) = self.pending_event.take() {
            return Poll::Ready(Some(event));
        }

        // Closed and nothing left to report: end the stream.
        if self.is_closed {
            return Poll::Ready(None);
        }

        if let Poll::Ready(event) = self.poll_if_addr(cx) {
            return Poll::Ready(Some(event));
        }

        // Try to accept an incoming connection.
        match T::poll_accept(&mut self.listener, cx) {
            Poll::Ready(Ok(Incoming {
                local_addr,
                remote_addr,
                stream,
            })) => {
                let local_addr = ip_to_multiaddr(local_addr.ip(), local_addr.port());
                let remote_addr = ip_to_multiaddr(remote_addr.ip(), remote_addr.port());

                tracing::debug!(
                    remote_address=%remote_addr,
                    local_address=%local_addr,
                    "Incoming connection from remote at local"
                );

                return Poll::Ready(Some(TransportEvent::Incoming {
                    listener_id: self.listener_id,
                    // The stream is already established, so the upgrade is immediately ready.
                    upgrade: future::ok(stream),
                    local_addr,
                    send_back_addr: remote_addr,
                }));
            }
            Poll::Ready(Err(error)) => {
                // Back off before polling again, and report the error.
                self.pause = Some(Delay::new(self.sleep_on_error));
                return Poll::Ready(Some(TransportEvent::ListenerError {
                    listener_id: self.listener_id,
                    error,
                }));
            }
            Poll::Pending => {}
        }

        // Remember the waker so that `close` can wake this task.
        self.close_listener_waker = Some(cx.waker().clone());
        Poll::Pending
    }
}
794
795fn multiaddr_to_socketaddr(mut addr: Multiaddr) -> Result<SocketAddr, ()> {
800 let mut port = None;
804 while let Some(proto) = addr.pop() {
805 match proto {
806 Protocol::Ip4(ipv4) => match port {
807 Some(port) => return Ok(SocketAddr::new(ipv4.into(), port)),
808 None => return Err(()),
809 },
810 Protocol::Ip6(ipv6) => match port {
811 Some(port) => return Ok(SocketAddr::new(ipv6.into(), port)),
812 None => return Err(()),
813 },
814 Protocol::Tcp(portnum) => match port {
815 Some(_) => return Err(()),
816 None => port = Some(portnum),
817 },
818 Protocol::P2p(_) => {}
819 _ => return Err(()),
820 }
821 }
822 Err(())
823}
824
825fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr {
827 Multiaddr::empty().with(ip.into()).with(Protocol::Tcp(port))
828}
829
830fn is_tcp_addr(addr: &Multiaddr) -> bool {
831 use Protocol::*;
832
833 let mut iter = addr.iter();
834
835 let first = match iter.next() {
836 None => return false,
837 Some(p) => p,
838 };
839 let second = match iter.next() {
840 None => return false,
841 Some(p) => p,
842 };
843
844 matches!(first, Ip4(_) | Ip6(_) | Dns(_) | Dns4(_) | Dns6(_)) && matches!(second, Tcp(_))
845}
846
847#[cfg(test)]
848mod tests {
849 use super::*;
850 use futures::{
851 channel::{mpsc, oneshot},
852 future::poll_fn,
853 };
854 use libp2p_core::Transport as _;
855 use libp2p_identity::PeerId;
856
857 #[test]
858 fn multiaddr_to_tcp_conversion() {
859 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
860
861 assert!(
862 multiaddr_to_socketaddr("/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap())
863 .is_err()
864 );
865
866 assert_eq!(
867 multiaddr_to_socketaddr("/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap()),
868 Ok(SocketAddr::new(
869 IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
870 12345,
871 ))
872 );
873 assert_eq!(
874 multiaddr_to_socketaddr(
875 "/ip4/255.255.255.255/tcp/8080"
876 .parse::<Multiaddr>()
877 .unwrap()
878 ),
879 Ok(SocketAddr::new(
880 IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)),
881 8080,
882 ))
883 );
884 assert_eq!(
885 multiaddr_to_socketaddr("/ip6/::1/tcp/12345".parse::<Multiaddr>().unwrap()),
886 Ok(SocketAddr::new(
887 IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
888 12345,
889 ))
890 );
891 assert_eq!(
892 multiaddr_to_socketaddr(
893 "/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/tcp/8080"
894 .parse::<Multiaddr>()
895 .unwrap()
896 ),
897 Ok(SocketAddr::new(
898 IpAddr::V6(Ipv6Addr::new(
899 65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
900 )),
901 8080,
902 ))
903 );
904 }
905
    /// A dialer and a listener exchange three bytes in each direction over a
    /// loopback connection, for every enabled provider and for IPv4 and IPv6.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        // Listens on `addr`, forwards each reported listen address through
        // `ready_tx`, then serves exactly one connection: expects `[1, 2, 3]`
        // and answers `[4, 5, 6]`.
        async fn listener<T: Provider>(addr: Multiaddr, mut ready_tx: mpsc::Sender<Multiaddr>) {
            let mut tcp = Transport::<T>::default().boxed();
            tcp.listen_on(ListenerId::next(), addr).unwrap();
            loop {
                match tcp.select_next_some().await {
                    TransportEvent::NewAddress { listen_addr, .. } => {
                        ready_tx.send(listen_addr).await.unwrap();
                    }
                    TransportEvent::Incoming { upgrade, .. } => {
                        let mut upgrade = upgrade.await.unwrap();
                        let mut buf = [0u8; 3];
                        upgrade.read_exact(&mut buf).await.unwrap();
                        assert_eq!(buf, [1, 2, 3]);
                        upgrade.write_all(&[4, 5, 6]).await.unwrap();
                        return;
                    }
                    e => panic!("Unexpected transport event: {e:?}"),
                }
            }
        }

        // Waits for the listener's address, dials it, sends `[1, 2, 3]` and
        // expects `[4, 5, 6]` back.
        async fn dialer<T: Provider>(mut ready_rx: mpsc::Receiver<Multiaddr>) {
            let addr = ready_rx.next().await.unwrap();
            let mut tcp = Transport::<T>::default();

            let mut socket = tcp.dial(addr.clone()).unwrap().await.unwrap();
            socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap();

            let mut buf = [0u8; 3];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, [4, 5, 6]);
        }

        // Runs the listener/dialer pair once per enabled provider.
        fn test(addr: Multiaddr) {
            #[cfg(feature = "async-io")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx);
                let dialer = dialer::<async_io::Tcp>(ready_rx);
                let listener = async_std::task::spawn(listener);
                async_std::task::block_on(dialer);
                async_std::task::block_on(listener);
            }

            #[cfg(feature = "tokio")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let listener = listener::<tokio::Tcp>(addr, ready_tx);
                let dialer = dialer::<tokio::Tcp>(ready_rx);
                let rt = ::tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .unwrap();
                let tasks = ::tokio::task::LocalSet::new();
                let listener = tasks.spawn_local(listener);
                tasks.block_on(&rt, dialer);
                tasks.block_on(&rt, listener).unwrap();
            }
        }

        test("/ip4/127.0.0.1/tcp/0".parse().unwrap());
        test("/ip6/::1/tcp/0".parse().unwrap());
    }
976
    /// Every reported listen address must be a concrete IP with a non-zero
    /// TCP port, and must be dialable.
    #[test]
    fn wildcard_expansion() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        // Listens on `addr`, validates every reported address and forwards it
        // through `ready_tx`; returns on the first incoming connection.
        async fn listener<T: Provider>(addr: Multiaddr, mut ready_tx: mpsc::Sender<Multiaddr>) {
            let mut tcp = Transport::<T>::default().boxed();
            tcp.listen_on(ListenerId::next(), addr).unwrap();

            loop {
                match tcp.select_next_some().await {
                    TransportEvent::NewAddress { listen_addr, .. } => {
                        let mut iter = listen_addr.iter();
                        match iter.next().expect("ip address") {
                            Protocol::Ip4(ip) => assert!(!ip.is_unspecified()),
                            Protocol::Ip6(ip) => assert!(!ip.is_unspecified()),
                            other => panic!("Unexpected protocol: {other}"),
                        }
                        if let Protocol::Tcp(port) = iter.next().expect("port") {
                            assert_ne!(0, port)
                        } else {
                            panic!("No TCP port in address: {listen_addr}")
                        }
                        // Send errors are ignored (`.ok()`), e.g. once the
                        // dialer has dropped its receiver.
                        ready_tx.send(listen_addr).await.ok();
                    }
                    TransportEvent::Incoming { .. } => {
                        return;
                    }
                    _ => {}
                }
            }
        }

        // Dials the first address the listener reports.
        async fn dialer<T: Provider>(mut ready_rx: mpsc::Receiver<Multiaddr>) {
            let dest_addr = ready_rx.next().await.unwrap();
            let mut tcp = Transport::<T>::default();
            tcp.dial(dest_addr).unwrap().await.unwrap();
        }

        // Runs the listener/dialer pair once per enabled provider.
        fn test(addr: Multiaddr) {
            #[cfg(feature = "async-io")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx);
                let dialer = dialer::<async_io::Tcp>(ready_rx);
                let listener = async_std::task::spawn(listener);
                async_std::task::block_on(dialer);
                async_std::task::block_on(listener);
            }

            #[cfg(feature = "tokio")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let listener = listener::<tokio::Tcp>(addr, ready_tx);
                let dialer = dialer::<tokio::Tcp>(ready_rx);
                let rt = ::tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .unwrap();
                let tasks = ::tokio::task::LocalSet::new();
                let listener = tasks.spawn_local(listener);
                tasks.block_on(&rt, dialer);
                tasks.block_on(&rt, listener).unwrap();
            }
        }

        test("/ip4/0.0.0.0/tcp/0".parse().unwrap());
        // NOTE(review): the IPv6 case uses the loopback address `::1`, not the
        // wildcard `::`, so only the IPv4 case exercises expansion — confirm
        // this is intentional.
        test("/ip6/::1/tcp/0".parse().unwrap());
    }
1047
    /// With port reuse enabled on the dialing side, the listener must see the
    /// connection come from the dialer's own listen port.
    #[test]
    fn port_reuse_dialing() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        // Listens without port reuse. On the incoming connection it checks that
        // the last component of the remote address equals the port component
        // received through `port_reuse_rx`, then exchanges the test bytes.
        async fn listener<T: Provider>(
            addr: Multiaddr,
            mut ready_tx: mpsc::Sender<Multiaddr>,
            port_reuse_rx: oneshot::Receiver<Protocol<'_>>,
        ) {
            let mut tcp = Transport::<T>::new(Config::new()).boxed();
            tcp.listen_on(ListenerId::next(), addr).unwrap();
            loop {
                match tcp.select_next_some().await {
                    TransportEvent::NewAddress { listen_addr, .. } => {
                        ready_tx.send(listen_addr).await.ok();
                    }
                    TransportEvent::Incoming {
                        upgrade,
                        mut send_back_addr,
                        ..
                    } => {
                        let remote_port_reuse = port_reuse_rx.await.unwrap();
                        // The dialer must have used its listen port as source port.
                        assert_eq!(send_back_addr.pop().unwrap(), remote_port_reuse);

                        let mut upgrade = upgrade.await.unwrap();
                        let mut buf = [0u8; 3];
                        upgrade.read_exact(&mut buf).await.unwrap();
                        assert_eq!(buf, [1, 2, 3]);
                        upgrade.write_all(&[4, 5, 6]).await.unwrap();
                        return;
                    }
                    e => panic!("Unexpected event: {e:?}"),
                }
            }
        }

        // Listens with port reuse enabled, checks that the transport and its
        // listener agree on the local dial address, reports the reuse port
        // through `port_reuse_tx`, then dials the remote listener.
        async fn dialer<T: Provider>(
            addr: Multiaddr,
            mut ready_rx: mpsc::Receiver<Multiaddr>,
            port_reuse_tx: oneshot::Sender<Protocol<'_>>,
        ) {
            let dest_addr = ready_rx.next().await.unwrap();
            let mut tcp = Transport::<T>::new(Config::new().port_reuse(true));
            tcp.listen_on(ListenerId::next(), addr).unwrap();
            match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await {
                TransportEvent::NewAddress { .. } => {
                    let listener = tcp.listeners.iter().next().unwrap();
                    let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip());
                    let port_reuse_listener = listener
                        .port_reuse
                        .local_dial_addr(&listener.listen_addr.ip());
                    assert!(port_reuse_tcp.is_some());
                    assert_eq!(port_reuse_tcp, port_reuse_listener);

                    port_reuse_tx
                        .send(Protocol::Tcp(port_reuse_tcp.unwrap().port()))
                        .ok();

                    let mut socket = tcp.dial(dest_addr).unwrap().await.unwrap();
                    socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap();
                    let mut buf = [0u8; 3];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, [4, 5, 6]);
                }
                e => panic!("Unexpected transport event: {e:?}"),
            }
        }

        // Runs the listener/dialer pair once per enabled provider.
        fn test(addr: Multiaddr) {
            #[cfg(feature = "async-io")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let (port_reuse_tx, port_reuse_rx) = oneshot::channel();
                let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
                let dialer = dialer::<async_io::Tcp>(addr.clone(), ready_rx, port_reuse_tx);
                let listener = async_std::task::spawn(listener);
                async_std::task::block_on(dialer);
                async_std::task::block_on(listener);
            }

            #[cfg(feature = "tokio")]
            {
                let (ready_tx, ready_rx) = mpsc::channel(1);
                let (port_reuse_tx, port_reuse_rx) = oneshot::channel();
                let listener = listener::<tokio::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
                let dialer = dialer::<tokio::Tcp>(addr, ready_rx, port_reuse_tx);
                let rt = ::tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .unwrap();
                let tasks = ::tokio::task::LocalSet::new();
                let listener = tasks.spawn_local(listener);
                tasks.block_on(&rt, dialer);
                tasks.block_on(&rt, listener).unwrap();
            }
        }

        test("/ip4/127.0.0.1/tcp/0".parse().unwrap());
        test("/ip6/::1/tcp/0".parse().unwrap());
    }
1156
    /// With port reuse enabled, the same transport can listen twice on the
    /// same concrete address.
    #[test]
    fn port_reuse_listening() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        // Listens on `addr`, then listens again on the concrete address that
        // was reported, and expects the second listener to report that same address.
        async fn listen_twice<T: Provider>(addr: Multiaddr) {
            let mut tcp = Transport::<T>::new(Config::new().port_reuse(true));
            tcp.listen_on(ListenerId::next(), addr).unwrap();
            match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await {
                TransportEvent::NewAddress {
                    listen_addr: addr1, ..
                } => {
                    let listener1 = tcp.listeners.iter().next().unwrap();
                    // Transport and listener must agree on the local dial address.
                    let port_reuse_tcp =
                        tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip());
                    let port_reuse_listener1 = listener1
                        .port_reuse
                        .local_dial_addr(&listener1.listen_addr.ip());
                    assert!(port_reuse_tcp.is_some());
                    assert_eq!(port_reuse_tcp, port_reuse_listener1);

                    // Second listener on the very same address.
                    tcp.listen_on(ListenerId::next(), addr1.clone()).unwrap();
                    match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await {
                        TransportEvent::NewAddress {
                            listen_addr: addr2, ..
                        } => assert_eq!(addr1, addr2),
                        e => panic!("Unexpected transport event: {e:?}"),
                    }
                }
                e => panic!("Unexpected transport event: {e:?}"),
            }
        }

        // Runs the scenario once per enabled provider.
        fn test(addr: Multiaddr) {
            #[cfg(feature = "async-io")]
            {
                let listener = listen_twice::<async_io::Tcp>(addr.clone());
                async_std::task::block_on(listener);
            }

            #[cfg(feature = "tokio")]
            {
                let listener = listen_twice::<tokio::Tcp>(addr);
                let rt = ::tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .unwrap();
                rt.block_on(listener);
            }
        }

        test("/ip4/127.0.0.1/tcp/0".parse().unwrap());
    }
1212
    /// Listening on port 0 must report an address that does not contain `tcp/0`.
    #[test]
    fn listen_port_0() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        // Listens on `addr` and returns the first reported listen address.
        async fn listen<T: Provider>(addr: Multiaddr) -> Multiaddr {
            let mut tcp = Transport::<T>::default().boxed();
            tcp.listen_on(ListenerId::next(), addr).unwrap();
            tcp.select_next_some()
                .await
                .into_new_address()
                .expect("listen address")
        }

        // Runs the check once per enabled provider.
        fn test(addr: Multiaddr) {
            #[cfg(feature = "async-io")]
            {
                let new_addr = async_std::task::block_on(listen::<async_io::Tcp>(addr.clone()));
                assert!(!new_addr.to_string().contains("tcp/0"));
            }

            #[cfg(feature = "tokio")]
            {
                let rt = ::tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .unwrap();
                let new_addr = rt.block_on(listen::<tokio::Tcp>(addr));
                assert!(!new_addr.to_string().contains("tcp/0"));
            }
        }

        test("/ip6/::1/tcp/0".parse().unwrap());
        test("/ip4/127.0.0.1/tcp/0".parse().unwrap());
    }
1249
1250 #[test]
1251 fn listen_invalid_addr() {
1252 let _ = tracing_subscriber::fmt()
1253 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1254 .try_init();
1255
1256 fn test(addr: Multiaddr) {
1257 #[cfg(feature = "async-io")]
1258 {
1259 let mut tcp = async_io::Transport::default();
1260 assert!(tcp.listen_on(ListenerId::next(), addr.clone()).is_err());
1261 }
1262
1263 #[cfg(feature = "tokio")]
1264 {
1265 let mut tcp = tokio::Transport::default();
1266 assert!(tcp.listen_on(ListenerId::next(), addr).is_err());
1267 }
1268 }
1269
1270 test("/ip4/127.0.0.1/tcp/12345/tcp/12345".parse().unwrap());
1271 }
1272
    /// Runs the shared address-translation checks against the async-io transport.
    #[cfg(feature = "async-io")]
    #[test]
    fn test_address_translation_async_io() {
        test_address_translation::<async_io::Transport>()
    }
1278
    /// Runs the shared address-translation checks against the tokio transport.
    #[cfg(feature = "tokio")]
    #[test]
    fn test_address_translation_tokio() {
        test_address_translation::<tokio::Transport>()
    }
1284
1285 fn test_address_translation<T>()
1286 where
1287 T: Default + libp2p_core::Transport,
1288 {
1289 let transport = T::default();
1290
1291 let port = 42;
1292 let tcp_listen_addr = Multiaddr::empty()
1293 .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1)))
1294 .with(Protocol::Tcp(port));
1295 let observed_ip = Ipv4Addr::new(123, 45, 67, 8);
1296 let tcp_observed_addr = Multiaddr::empty()
1297 .with(Protocol::Ip4(observed_ip))
1298 .with(Protocol::Tcp(1))
1299 .with(Protocol::P2p(PeerId::random()));
1300
1301 let translated = transport
1302 .address_translation(&tcp_listen_addr, &tcp_observed_addr)
1303 .unwrap();
1304 let mut iter = translated.iter();
1305 assert_eq!(iter.next(), Some(Protocol::Ip4(observed_ip)));
1306 assert_eq!(iter.next(), Some(Protocol::Tcp(port)));
1307 assert_eq!(iter.next(), None);
1308
1309 let quic_addr = Multiaddr::empty()
1310 .with(Protocol::Ip4(Ipv4Addr::new(87, 65, 43, 21)))
1311 .with(Protocol::Udp(1))
1312 .with(Protocol::QuicV1);
1313
1314 assert!(transport
1315 .address_translation(&tcp_listen_addr, &quic_addr)
1316 .is_none());
1317 assert!(transport
1318 .address_translation(&quic_addr, &tcp_observed_addr)
1319 .is_none());
1320 }
1321
1322 #[test]
1323 fn test_remove_listener() {
1324 let _ = tracing_subscriber::fmt()
1325 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1326 .try_init();
1327
1328 async fn cycle_listeners<T: Provider>() -> bool {
1329 let mut tcp = Transport::<T>::default().boxed();
1330 let listener_id = ListenerId::next();
1331 tcp.listen_on(listener_id, "/ip4/127.0.0.1/tcp/0".parse().unwrap())
1332 .unwrap();
1333 tcp.remove_listener(listener_id)
1334 }
1335
1336 #[cfg(feature = "async-io")]
1337 {
1338 assert!(async_std::task::block_on(cycle_listeners::<async_io::Tcp>()));
1339 }
1340
1341 #[cfg(feature = "tokio")]
1342 {
1343 let rt = ::tokio::runtime::Builder::new_current_thread()
1344 .enable_io()
1345 .build()
1346 .unwrap();
1347 assert!(rt.block_on(cycle_listeners::<tokio::Tcp>()));
1348 }
1349 }
1350
1351 #[test]
1352 fn test_listens_ipv4_ipv6_separately() {
1353 fn test<T: Provider>() {
1354 let port = {
1355 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1356 listener.local_addr().unwrap().port()
1357 };
1358 let mut tcp = Transport::<T>::default().boxed();
1359 let listener_id = ListenerId::next();
1360 tcp.listen_on(
1361 listener_id,
1362 format!("/ip4/0.0.0.0/tcp/{port}").parse().unwrap(),
1363 )
1364 .unwrap();
1365 tcp.listen_on(
1366 ListenerId::next(),
1367 format!("/ip6/::/tcp/{port}").parse().unwrap(),
1368 )
1369 .unwrap();
1370 }
1371 #[cfg(feature = "async-io")]
1372 {
1373 async_std::task::block_on(async {
1374 test::<async_io::Tcp>();
1375 })
1376 }
1377 #[cfg(feature = "tokio")]
1378 {
1379 let rt = ::tokio::runtime::Builder::new_current_thread()
1380 .enable_io()
1381 .build()
1382 .unwrap();
1383 rt.block_on(async {
1384 test::<async_io::Tcp>();
1385 });
1386 }
1387 }
1388}