libp2p_core/transport/
memory.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::transport::{ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{channel::mpsc, future::Ready, prelude::*, task::Context, task::Poll};
24use multiaddr::{Multiaddr, Protocol};
25use once_cell::sync::Lazy;
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{
29    collections::{hash_map::Entry, VecDeque},
30    error, fmt, io,
31    num::NonZeroU64,
32    pin::Pin,
33};
34
/// Global registry shared by every [`MemoryTransport`] in the process.
///
/// Maps each registered port to the sender half of the channel created for it.
/// The map starts out empty and is created lazily on first access.
static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));
36
/// Mutex-guarded table of registered ports, keyed by port number and storing
/// the [`ChannelSender`] that was created when the port was registered.
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
38
/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
/// port of the dialer to a [`Listener`].
///
/// This is the value stored in the [`HUB`] for every registered port.
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
42
/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
/// the port of the dialer from a [`DialFuture`].
///
/// Counterpart of [`ChannelSender`]; handed out by `Hub::register_port`.
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
46
47impl Hub {
48    /// Registers the given port on the hub.
49    ///
50    /// Randomizes port when given port is `0`. Returns [`None`] when given port
51    /// is already occupied.
52    fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
53        let mut hub = self.0.lock();
54
55        let port = if let Some(port) = NonZeroU64::new(port) {
56            port
57        } else {
58            loop {
59                let Some(port) = NonZeroU64::new(rand::random()) else {
60                    continue;
61                };
62                if !hub.contains_key(&port) {
63                    break port;
64                }
65            }
66        };
67
68        let (tx, rx) = mpsc::channel(2);
69        match hub.entry(port) {
70            Entry::Occupied(_) => return None,
71            Entry::Vacant(e) => e.insert(tx),
72        };
73
74        Some((rx, port))
75    }
76
77    fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
78        self.0.lock().remove(port)
79    }
80
81    fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
82        self.0.lock().get(port).cloned()
83    }
84}
85
/// Transport that supports `/memory/N` multiaddresses.
#[derive(Default)]
pub struct MemoryTransport {
    /// Listeners added through `listen_on`. `poll` takes them from the back and
    /// re-inserts them at the front, so they are serviced in rotation.
    listeners: VecDeque<Pin<Box<Listener>>>,
}
91
92impl MemoryTransport {
93    pub fn new() -> Self {
94        Self::default()
95    }
96}
97
/// Connection to a `MemoryTransport` currently being opened.
pub struct DialFuture {
    /// Ephemeral source port.
    ///
    /// These ports mimic TCP ephemeral source ports but are not actually used
    /// by the memory transport due to the direct use of channels. They merely
    /// ensure that every connection has a unique address for each dialer, which
    /// is not at the same time a listen address (analogous to TCP).
    dial_port: NonZeroU64,
    /// Sender registered in the hub for the port being dialed; used to hand the
    /// listener its end of the connection.
    sender: ChannelSender,
    /// The listener's end of the connection. Taken (and sent) when the future
    /// is polled to completion.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// The dialer's end of the connection. Taken and returned as the future's
    /// output once the listener's end has been sent.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
111
112impl DialFuture {
113    fn new(port: NonZeroU64) -> Option<Self> {
114        let sender = HUB.get(&port)?;
115
116        let (_dial_port_channel, dial_port) = HUB
117            .register_port(0)
118            .expect("there to be some random unoccupied port.");
119
120        let (a_tx, a_rx) = mpsc::channel(4096);
121        let (b_tx, b_rx) = mpsc::channel(4096);
122        Some(DialFuture {
123            dial_port,
124            sender,
125            channel_to_send: Some(RwStreamSink::new(Chan {
126                incoming: a_rx,
127                outgoing: b_tx,
128                dial_port: None,
129            })),
130            channel_to_return: Some(RwStreamSink::new(Chan {
131                incoming: b_rx,
132                outgoing: a_tx,
133                dial_port: Some(dial_port),
134            })),
135        })
136    }
137}
138
139impl Future for DialFuture {
140    type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
141
142    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143        match self.sender.poll_ready(cx) {
144            Poll::Pending => return Poll::Pending,
145            Poll::Ready(Ok(())) => {}
146            Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
147        }
148
149        let channel_to_send = self
150            .channel_to_send
151            .take()
152            .expect("Future should not be polled again once complete");
153        let dial_port = self.dial_port;
154        if self
155            .sender
156            .start_send((channel_to_send, dial_port))
157            .is_err()
158        {
159            return Poll::Ready(Err(MemoryTransportError::Unreachable));
160        }
161
162        Poll::Ready(Ok(self
163            .channel_to_return
164            .take()
165            .expect("Future should not be polled again once complete")))
166    }
167}
168
169impl Transport for MemoryTransport {
170    type Output = Channel<Vec<u8>>;
171    type Error = MemoryTransportError;
172    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
173    type Dial = DialFuture;
174
175    fn listen_on(
176        &mut self,
177        id: ListenerId,
178        addr: Multiaddr,
179    ) -> Result<(), TransportError<Self::Error>> {
180        let port =
181            parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
182
183        let (rx, port) = HUB
184            .register_port(port)
185            .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
186
187        let listener = Listener {
188            id,
189            port,
190            addr: Protocol::Memory(port.get()).into(),
191            receiver: rx,
192            tell_listen_addr: true,
193        };
194        self.listeners.push_back(Box::pin(listener));
195
196        Ok(())
197    }
198
199    fn remove_listener(&mut self, id: ListenerId) -> bool {
200        if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
201            let listener = self.listeners.get_mut(index).unwrap();
202            let val_in = HUB.unregister_port(&listener.port);
203            debug_assert!(val_in.is_some());
204            listener.receiver.close();
205            true
206        } else {
207            false
208        }
209    }
210
211    fn dial(&mut self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
212        let port = if let Ok(port) = parse_memory_addr(&addr) {
213            if let Some(port) = NonZeroU64::new(port) {
214                port
215            } else {
216                return Err(TransportError::Other(MemoryTransportError::Unreachable));
217            }
218        } else {
219            return Err(TransportError::MultiaddrNotSupported(addr));
220        };
221
222        DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
223    }
224
225    fn dial_as_listener(
226        &mut self,
227        addr: Multiaddr,
228    ) -> Result<DialFuture, TransportError<Self::Error>> {
229        self.dial(addr)
230    }
231
232    fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
233        None
234    }
235
236    fn poll(
237        mut self: Pin<&mut Self>,
238        cx: &mut Context<'_>,
239    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
240    where
241        Self: Sized,
242    {
243        let mut remaining = self.listeners.len();
244        while let Some(mut listener) = self.listeners.pop_back() {
245            if listener.tell_listen_addr {
246                listener.tell_listen_addr = false;
247                let listen_addr = listener.addr.clone();
248                let listener_id = listener.id;
249                self.listeners.push_front(listener);
250                return Poll::Ready(TransportEvent::NewAddress {
251                    listen_addr,
252                    listener_id,
253                });
254            }
255
256            let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
257                Poll::Pending => None,
258                Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
259                    listener_id: listener.id,
260                    upgrade: future::ready(Ok(channel)),
261                    local_addr: listener.addr.clone(),
262                    send_back_addr: Protocol::Memory(dial_port.get()).into(),
263                }),
264                Poll::Ready(None) => {
265                    // Listener was closed.
266                    return Poll::Ready(TransportEvent::ListenerClosed {
267                        listener_id: listener.id,
268                        reason: Ok(()),
269                    });
270                }
271            };
272
273            self.listeners.push_front(listener);
274            if let Some(event) = event {
275                return Poll::Ready(event);
276            } else {
277                remaining -= 1;
278                if remaining == 0 {
279                    break;
280                }
281            }
282        }
283        Poll::Pending
284    }
285}
286
/// Error that can be produced from the `MemoryTransport`.
#[derive(Debug, Copy, Clone)]
pub enum MemoryTransportError {
    /// There's no listener on the given port.
    Unreachable,
    /// An attempt was made to listen on a port that is already in use.
    AlreadyInUse,
}
295
296impl fmt::Display for MemoryTransportError {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        match *self {
299            MemoryTransportError::Unreachable => write!(f, "No listener on the given port."),
300            MemoryTransportError::AlreadyInUse => write!(f, "Port already occupied."),
301        }
302    }
303}
304
// Marks `MemoryTransportError` as a standard error type; every method of the
// trait keeps its default implementation.
impl error::Error for MemoryTransportError {}
306
/// Listener for memory connections.
pub struct Listener {
    /// Id under which the listener was registered via `listen_on`; included in
    /// every event emitted for this listener.
    id: ListenerId,
    /// Port we're listening on.
    port: NonZeroU64,
    /// The address we are listening on.
    addr: Multiaddr,
    /// Receives incoming connections.
    receiver: ChannelReceiver,
    /// Generate [`TransportEvent::NewAddress`] to inform about our listen address.
    ///
    /// Starts out `true` and is cleared once the event has been emitted.
    tell_listen_addr: bool,
}
319
320/// If the address is `/memory/n`, returns the value of `n`.
321fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
322    let mut protocols = a.iter();
323    match protocols.next() {
324        Some(Protocol::Memory(port)) => match protocols.next() {
325            None | Some(Protocol::P2p(_)) => Ok(port),
326            _ => Err(()),
327        },
328        _ => Err(()),
329    }
330}
331
/// A channel represents an established, in-memory, logical connection between two endpoints.
///
/// Implements `AsyncRead` and `AsyncWrite` by wrapping a [`Chan`] in an
/// [`RwStreamSink`].
pub type Channel<T> = RwStreamSink<Chan<T>>;
336
/// A channel represents an established, in-memory, logical connection between two endpoints.
///
/// Implements `Sink` and `Stream`.
pub struct Chan<T = Vec<u8>> {
    /// Items sent by the remote endpoint.
    incoming: mpsc::Receiver<T>,
    /// Items destined for the remote endpoint.
    outgoing: mpsc::Sender<T>,

    // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
    // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
    // [`None`] when [`Chan`] of listener.
    //
    // Note: In this file the listening port is unregistered by
    // `MemoryTransport::remove_listener`; no [`Drop`] implementation for
    // [`Listener`] is visible here.
    dial_port: Option<NonZeroU64>,
}
352
// `Chan` is `Unpin` for every `T`. The `Stream` and `Sink` implementations
// below rely on this to access fields mutably through `Pin<&mut Self>`.
impl<T> Unpin for Chan<T> {}
354
355impl<T> Stream for Chan<T> {
356    type Item = Result<T, io::Error>;
357
358    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359        match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
360            Poll::Pending => Poll::Pending,
361            Poll::Ready(None) => Poll::Ready(None),
362            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
363        }
364    }
365}
366
367impl<T> Sink<T> for Chan<T> {
368    type Error = io::Error;
369
370    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
371        self.outgoing
372            .poll_ready(cx)
373            .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
374    }
375
376    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
377        self.outgoing
378            .start_send(item)
379            .map_err(|_| io::ErrorKind::BrokenPipe.into())
380    }
381
382    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383        Poll::Ready(Ok(()))
384    }
385
386    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387        Poll::Ready(Ok(()))
388    }
389}
390
391impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
392    fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
393        RwStreamSink::new(channel)
394    }
395}
396
397impl<T> Drop for Chan<T> {
398    fn drop(&mut self) {
399        if let Some(port) = self.dial_port {
400            let channel_sender = HUB.unregister_port(&port);
401            debug_assert!(channel_sender.is_some());
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_memory_addr` accepts `/memory/N`, optionally followed by
    /// `/p2p/...`, and rejects addresses that start with or continue into
    /// another protocol.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Components after the `/p2p/...` component are not inspected.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    /// A port can be listened on again once its listener has been removed, but
    /// not while a listener still holds it.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are occupied now.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        // Freeing the first port must not free the second one.
        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    /// Dialing fails while nobody listens on the port and succeeds afterwards.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_ok());
    }

    /// Removing a listener produces a `ListenerClosed` event, after which the
    /// listener id is no longer known to the transport.
    #[test]
    fn stop_listening() {
        // `saturating_add(1)` keeps the port non-zero; port 0 would ask the hub
        // for a random port instead.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            // The closed listener was dropped by `poll`, so it cannot be
            // removed a second time.
            assert!(!transport.remove_listener(listener_id));
        })
    }

    /// Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Setup listener.

        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            // Skip events (such as the new-address event) until a connection
            // comes in.
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Setup dialer.

        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Wait for both to finish.

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The address reported for the dialer differs from the listen address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The dialer's ephemeral port stays registered in the hub while the
    /// dialer's channel is alive and is removed once that channel is dropped.
    #[test]
    fn dialer_port_is_deregistered() {
        // `terminate` tells the dialer to drop its channel; `terminated` tells
        // the listener that the drop has happened.
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}