1use crate::transport::{ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{channel::mpsc, future::Ready, prelude::*, task::Context, task::Poll};
24use multiaddr::{Multiaddr, Protocol};
25use once_cell::sync::Lazy;
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{
29 collections::{hash_map::Entry, VecDeque},
30 error, fmt, io,
31 num::NonZeroU64,
32 pin::Pin,
33};
34
/// Process-wide registry of occupied memory ports.
///
/// Every `MemoryTransport` in the process shares this hub, which is what lets a dialer in one
/// transport reach a listener in another.
static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));

/// Maps each occupied port to the sender on which new connections for that port are delivered.
/// The map is guarded by a mutex.
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);

/// Delivers a new connection to whoever registered a port.
///
/// Each item is the listener-side half of the connection plus the dialer's own port, which the
/// listener reports as the connection's send-back address.
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;

/// Receiving counterpart of [`ChannelSender`].
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
46
47impl Hub {
48 fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
53 let mut hub = self.0.lock();
54
55 let port = if let Some(port) = NonZeroU64::new(port) {
56 port
57 } else {
58 loop {
59 let Some(port) = NonZeroU64::new(rand::random()) else {
60 continue;
61 };
62 if !hub.contains_key(&port) {
63 break port;
64 }
65 }
66 };
67
68 let (tx, rx) = mpsc::channel(2);
69 match hub.entry(port) {
70 Entry::Occupied(_) => return None,
71 Entry::Vacant(e) => e.insert(tx),
72 };
73
74 Some((rx, port))
75 }
76
77 fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
78 self.0.lock().remove(port)
79 }
80
81 fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
82 self.0.lock().get(port).cloned()
83 }
84}
85
86#[derive(Default)]
88pub struct MemoryTransport {
89 listeners: VecDeque<Pin<Box<Listener>>>,
90}
91
92impl MemoryTransport {
93 pub fn new() -> Self {
94 Self::default()
95 }
96}
97
/// Future that hands a new connection to a listener and resolves to the dialer's end of it.
pub struct DialFuture {
    /// Random port registered in the hub for the dialer. The listener sees it as the
    /// connection's send-back address.
    dial_port: NonZeroU64,
    /// Sender half of the target listener's connection queue.
    sender: ChannelSender,
    /// Listener-side half of the connection; taken and sent when the future is polled.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// Dialer-side half of the connection; taken and returned once the other half was sent.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
111
112impl DialFuture {
113 fn new(port: NonZeroU64) -> Option<Self> {
114 let sender = HUB.get(&port)?;
115
116 let (_dial_port_channel, dial_port) = HUB
117 .register_port(0)
118 .expect("there to be some random unoccupied port.");
119
120 let (a_tx, a_rx) = mpsc::channel(4096);
121 let (b_tx, b_rx) = mpsc::channel(4096);
122 Some(DialFuture {
123 dial_port,
124 sender,
125 channel_to_send: Some(RwStreamSink::new(Chan {
126 incoming: a_rx,
127 outgoing: b_tx,
128 dial_port: None,
129 })),
130 channel_to_return: Some(RwStreamSink::new(Chan {
131 incoming: b_rx,
132 outgoing: a_tx,
133 dial_port: Some(dial_port),
134 })),
135 })
136 }
137}
138
139impl Future for DialFuture {
140 type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
141
142 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143 match self.sender.poll_ready(cx) {
144 Poll::Pending => return Poll::Pending,
145 Poll::Ready(Ok(())) => {}
146 Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
147 }
148
149 let channel_to_send = self
150 .channel_to_send
151 .take()
152 .expect("Future should not be polled again once complete");
153 let dial_port = self.dial_port;
154 if self
155 .sender
156 .start_send((channel_to_send, dial_port))
157 .is_err()
158 {
159 return Poll::Ready(Err(MemoryTransportError::Unreachable));
160 }
161
162 Poll::Ready(Ok(self
163 .channel_to_return
164 .take()
165 .expect("Future should not be polled again once complete")))
166 }
167}
168
169impl Transport for MemoryTransport {
170 type Output = Channel<Vec<u8>>;
171 type Error = MemoryTransportError;
172 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
173 type Dial = DialFuture;
174
175 fn listen_on(
176 &mut self,
177 id: ListenerId,
178 addr: Multiaddr,
179 ) -> Result<(), TransportError<Self::Error>> {
180 let port =
181 parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
182
183 let (rx, port) = HUB
184 .register_port(port)
185 .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
186
187 let listener = Listener {
188 id,
189 port,
190 addr: Protocol::Memory(port.get()).into(),
191 receiver: rx,
192 tell_listen_addr: true,
193 };
194 self.listeners.push_back(Box::pin(listener));
195
196 Ok(())
197 }
198
199 fn remove_listener(&mut self, id: ListenerId) -> bool {
200 if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
201 let listener = self.listeners.get_mut(index).unwrap();
202 let val_in = HUB.unregister_port(&listener.port);
203 debug_assert!(val_in.is_some());
204 listener.receiver.close();
205 true
206 } else {
207 false
208 }
209 }
210
211 fn dial(&mut self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
212 let port = if let Ok(port) = parse_memory_addr(&addr) {
213 if let Some(port) = NonZeroU64::new(port) {
214 port
215 } else {
216 return Err(TransportError::Other(MemoryTransportError::Unreachable));
217 }
218 } else {
219 return Err(TransportError::MultiaddrNotSupported(addr));
220 };
221
222 DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
223 }
224
225 fn dial_as_listener(
226 &mut self,
227 addr: Multiaddr,
228 ) -> Result<DialFuture, TransportError<Self::Error>> {
229 self.dial(addr)
230 }
231
232 fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
233 None
234 }
235
236 fn poll(
237 mut self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
240 where
241 Self: Sized,
242 {
243 let mut remaining = self.listeners.len();
244 while let Some(mut listener) = self.listeners.pop_back() {
245 if listener.tell_listen_addr {
246 listener.tell_listen_addr = false;
247 let listen_addr = listener.addr.clone();
248 let listener_id = listener.id;
249 self.listeners.push_front(listener);
250 return Poll::Ready(TransportEvent::NewAddress {
251 listen_addr,
252 listener_id,
253 });
254 }
255
256 let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
257 Poll::Pending => None,
258 Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
259 listener_id: listener.id,
260 upgrade: future::ready(Ok(channel)),
261 local_addr: listener.addr.clone(),
262 send_back_addr: Protocol::Memory(dial_port.get()).into(),
263 }),
264 Poll::Ready(None) => {
265 return Poll::Ready(TransportEvent::ListenerClosed {
267 listener_id: listener.id,
268 reason: Ok(()),
269 });
270 }
271 };
272
273 self.listeners.push_front(listener);
274 if let Some(event) = event {
275 return Poll::Ready(event);
276 } else {
277 remaining -= 1;
278 if remaining == 0 {
279 break;
280 }
281 }
282 }
283 Poll::Pending
284 }
285}
286
/// Error that can be produced by the `MemoryTransport`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTransportError {
    /// Nothing is listening on the given port, or the listener's queue is no longer usable.
    Unreachable,
    /// Tried to listen on a port that is already registered.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            MemoryTransportError::Unreachable => "No listener on the given port.",
            MemoryTransportError::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
306
307pub struct Listener {
309 id: ListenerId,
310 port: NonZeroU64,
312 addr: Multiaddr,
314 receiver: ChannelReceiver,
316 tell_listen_addr: bool,
318}
319
320fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
322 let mut protocols = a.iter();
323 match protocols.next() {
324 Some(Protocol::Memory(port)) => match protocols.next() {
325 None | Some(Protocol::P2p(_)) => Ok(port),
326 _ => Err(()),
327 },
328 _ => Err(()),
329 }
330}
331
/// One end of an in-memory connection, wrapped in [`RwStreamSink`] so it can be read from and
/// written to as a byte stream.
pub type Channel<T> = RwStreamSink<Chan<T>>;
336
/// One endpoint of an in-memory connection.
///
/// Items written through the `Sink` side arrive at the peer's `Stream` side, and vice versa.
pub struct Chan<T = Vec<u8>> {
    /// Items written by the peer.
    incoming: mpsc::Receiver<T>,
    /// Items written by this endpoint, delivered to the peer.
    outgoing: mpsc::Sender<T>,

    /// Hub port owned by this endpoint, if any.
    ///
    /// Only the dialer's endpoint carries one. It is unregistered from the hub when the endpoint
    /// is dropped.
    dial_port: Option<NonZeroU64>,
}

// The `Stream`/`Sink` impls only reach the fields through `&mut` (e.g. `Pin::new(&mut
// self.incoming)`), so `Chan` does not rely on being pinned and is `Unpin` for every `T`.
impl<T> Unpin for Chan<T> {}
354
355impl<T> Stream for Chan<T> {
356 type Item = Result<T, io::Error>;
357
358 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359 match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
360 Poll::Pending => Poll::Pending,
361 Poll::Ready(None) => Poll::Ready(None),
362 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
363 }
364 }
365}
366
367impl<T> Sink<T> for Chan<T> {
368 type Error = io::Error;
369
370 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
371 self.outgoing
372 .poll_ready(cx)
373 .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
374 }
375
376 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
377 self.outgoing
378 .start_send(item)
379 .map_err(|_| io::ErrorKind::BrokenPipe.into())
380 }
381
382 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383 Poll::Ready(Ok(()))
384 }
385
386 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387 Poll::Ready(Ok(()))
388 }
389}
390
391impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
392 fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
393 RwStreamSink::new(channel)
394 }
395}
396
397impl<T> Drop for Chan<T> {
398 fn drop(&mut self) {
399 if let Some(port) = self.dial_port {
400 let channel_sender = HUB.unregister_port(&port);
401 debug_assert!(channel_sender.is_some());
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    // Only a leading `/memory/<port>` component, optionally followed by `/p2p/...`, is accepted.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Components after the second one are not inspected by the parser.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    // A port can be held by only one listener at a time. It becomes available again once its
    // listener is removed.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are now occupied.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        // Freeing one port must not free the other.
        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    // Dialing only succeeds once something listens on the port.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_ok());
    }

    // Removing a listener makes `poll` report `ListenerClosed`. After that, the id is unknown.
    #[test]
    fn stop_listening() {
        // `saturating_add(1)` keeps the port non-zero, so a concrete port is requested.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            // The closed listener has been dropped from the transport.
            assert!(!transport.remove_listener(listener_id));
        })
    }

    // Bytes written by the dialer arrive at the listener's end of the connection.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Set up the listener.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            // Skip events (such as `NewAddress`) until an incoming connection arrives.
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Set up the dialer.
        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Drive both sides to completion.
        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The send-back address seen by the listener is the dialer's own random port, not the
    // listener's address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's port stays in the hub while its connection is alive. It is removed when the
    // dialer drops its end of the connection.
    #[test]
    fn dialer_port_is_deregistered() {
        // `terminate` tells the dialer to drop its connection; `terminated` confirms it did so.
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}