1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25use either::Either;
26use futures::{prelude::*, ready};
27use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
28use libp2p_core::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeInfo};
29use std::collections::VecDeque;
30use std::io::{IoSlice, IoSliceMut};
31use std::task::Waker;
32use std::{
33 io, iter,
34 pin::Pin,
35 task::{Context, Poll},
36};
37use thiserror::Error;
38
39#[derive(Debug)]
41pub struct Muxer<C> {
42 connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>,
43 inbound_stream_buffer: VecDeque<Stream>,
53 inbound_stream_waker: Option<Waker>,
55}
56
57const MAX_BUFFERED_INBOUND_STREAMS: usize = 256;
62
63impl<C> Muxer<C>
64where
65 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
66{
67 fn new(connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>) -> Self {
69 Muxer {
70 connection,
71 inbound_stream_buffer: VecDeque::default(),
72 inbound_stream_waker: None,
73 }
74 }
75}
76
77impl<C> StreamMuxer for Muxer<C>
78where
79 C: AsyncRead + AsyncWrite + Unpin + 'static,
80{
81 type Substream = Stream;
82 type Error = Error;
83
84 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_inbound", skip(self, cx))]
85 fn poll_inbound(
86 mut self: Pin<&mut Self>,
87 cx: &mut Context<'_>,
88 ) -> Poll<Result<Self::Substream, Self::Error>> {
89 if let Some(stream) = self.inbound_stream_buffer.pop_front() {
90 return Poll::Ready(Ok(stream));
91 }
92
93 if let Poll::Ready(res) = self.poll_inner(cx) {
94 return Poll::Ready(res);
95 }
96
97 self.inbound_stream_waker = Some(cx.waker().clone());
98 Poll::Pending
99 }
100
101 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_outbound", skip(self, cx))]
102 fn poll_outbound(
103 mut self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 ) -> Poll<Result<Self::Substream, Self::Error>> {
106 let stream = match self.connection.as_mut() {
107 Either::Left(c) => ready!(c.poll_new_outbound(cx))
108 .map_err(|e| Error(Either::Left(e)))
109 .map(|s| Stream(Either::Left(s))),
110 Either::Right(c) => ready!(c.poll_new_outbound(cx))
111 .map_err(|e| Error(Either::Right(e)))
112 .map(|s| Stream(Either::Right(s))),
113 }?;
114 Poll::Ready(Ok(stream))
115 }
116
117 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_close", skip(self, cx))]
118 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 match self.connection.as_mut() {
120 Either::Left(c) => c.poll_close(cx).map_err(|e| Error(Either::Left(e))),
121 Either::Right(c) => c.poll_close(cx).map_err(|e| Error(Either::Right(e))),
122 }
123 }
124
125 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll", skip(self, cx))]
126 fn poll(
127 self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
130 let this = self.get_mut();
131
132 let inbound_stream = ready!(this.poll_inner(cx))?;
133
134 if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS {
135 tracing::warn!(
136 stream=%inbound_stream.0,
137 "dropping stream because buffer is full"
138 );
139 drop(inbound_stream);
140 } else {
141 this.inbound_stream_buffer.push_back(inbound_stream);
142
143 if let Some(waker) = this.inbound_stream_waker.take() {
144 waker.wake()
145 }
146 }
147
148 cx.waker().wake_by_ref();
150 Poll::Pending
151 }
152}
153
154#[derive(Debug)]
156pub struct Stream(Either<yamux012::Stream, yamux013::Stream>);
157
158impl AsyncRead for Stream {
159 fn poll_read(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &mut [u8],
163 ) -> Poll<io::Result<usize>> {
164 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read(cx, buf))
165 }
166
167 fn poll_read_vectored(
168 mut self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 bufs: &mut [IoSliceMut<'_>],
171 ) -> Poll<io::Result<usize>> {
172 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read_vectored(cx, bufs))
173 }
174}
175
176impl AsyncWrite for Stream {
177 fn poll_write(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<io::Result<usize>> {
182 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write(cx, buf))
183 }
184
185 fn poll_write_vectored(
186 mut self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 bufs: &[IoSlice<'_>],
189 ) -> Poll<io::Result<usize>> {
190 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write_vectored(cx, bufs))
191 }
192
193 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_flush(cx))
195 }
196
197 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_close(cx))
199 }
200}
201
202impl<C> Muxer<C>
203where
204 C: AsyncRead + AsyncWrite + Unpin + 'static,
205{
206 fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream, Error>> {
207 let stream = match self.connection.as_mut() {
208 Either::Left(c) => ready!(c.poll_next_inbound(cx))
209 .ok_or(Error(Either::Left(yamux012::ConnectionError::Closed)))?
210 .map_err(|e| Error(Either::Left(e)))
211 .map(|s| Stream(Either::Left(s)))?,
212 Either::Right(c) => ready!(c.poll_next_inbound(cx))
213 .ok_or(Error(Either::Right(yamux013::ConnectionError::Closed)))?
214 .map_err(|e| Error(Either::Right(e)))
215 .map(|s| Stream(Either::Right(s)))?,
216 };
217
218 Poll::Ready(Ok(stream))
219 }
220}
221
222#[derive(Debug, Clone)]
224pub struct Config(Either<Config012, Config013>);
225
226impl Default for Config {
227 fn default() -> Self {
228 Self(Either::Right(Config013::default()))
229 }
230}
231
232#[derive(Debug, Clone)]
233struct Config012 {
234 inner: yamux012::Config,
235 mode: Option<yamux012::Mode>,
236}
237
238impl Default for Config012 {
239 fn default() -> Self {
240 let mut inner = yamux012::Config::default();
241 inner.set_read_after_close(false);
244 Self { inner, mode: None }
245 }
246}
247
248pub struct WindowUpdateMode(yamux012::WindowUpdateMode);
251
252impl WindowUpdateMode {
253 #[deprecated(note = "Use `WindowUpdateMode::on_read` instead.")]
266 pub fn on_receive() -> Self {
267 #[allow(deprecated)]
268 WindowUpdateMode(yamux012::WindowUpdateMode::OnReceive)
269 }
270
271 pub fn on_read() -> Self {
286 WindowUpdateMode(yamux012::WindowUpdateMode::OnRead)
287 }
288}
289
290impl Config {
291 #[deprecated(note = "Will be removed with the next breaking release.")]
294 pub fn client() -> Self {
295 Self(Either::Left(Config012 {
296 mode: Some(yamux012::Mode::Client),
297 ..Default::default()
298 }))
299 }
300
301 #[deprecated(note = "Will be removed with the next breaking release.")]
304 pub fn server() -> Self {
305 Self(Either::Left(Config012 {
306 mode: Some(yamux012::Mode::Server),
307 ..Default::default()
308 }))
309 }
310
311 #[deprecated(
313 note = "Will be replaced in the next breaking release with a connection receive window size limit."
314 )]
315 pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
316 self.set(|cfg| cfg.set_receive_window(num_bytes))
317 }
318
319 #[deprecated(note = "Will be removed with the next breaking release.")]
321 pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
322 self.set(|cfg| cfg.set_max_buffer_size(num_bytes))
323 }
324
325 pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
327 self.set(|cfg| cfg.set_max_num_streams(num_streams))
328 }
329
330 #[deprecated(
333 note = "`WindowUpdate::OnRead` is the default. `WindowUpdate::OnReceive` breaks backpressure, is thus not recommended, and will be removed in the next breaking release. Thus this method becomes obsolete and will be removed with the next breaking release."
334 )]
335 pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
336 self.set(|cfg| cfg.set_window_update_mode(mode.0))
337 }
338
339 fn set(&mut self, f: impl FnOnce(&mut yamux012::Config) -> &mut yamux012::Config) -> &mut Self {
340 let cfg012 = match self.0.as_mut() {
341 Either::Left(c) => &mut c.inner,
342 Either::Right(_) => {
343 self.0 = Either::Left(Config012::default());
344 &mut self.0.as_mut().unwrap_left().inner
345 }
346 };
347
348 f(cfg012);
349
350 self
351 }
352}
353
354impl UpgradeInfo for Config {
355 type Info = &'static str;
356 type InfoIter = iter::Once<Self::Info>;
357
358 fn protocol_info(&self) -> Self::InfoIter {
359 iter::once("/yamux/1.0.0")
360 }
361}
362
363impl<C> InboundConnectionUpgrade<C> for Config
364where
365 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
366{
367 type Output = Muxer<C>;
368 type Error = io::Error;
369 type Future = future::Ready<Result<Self::Output, Self::Error>>;
370
371 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
372 let connection = match self.0 {
373 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
374 io,
375 inner,
376 mode.unwrap_or(yamux012::Mode::Server),
377 )),
378 Either::Right(Config013(cfg)) => {
379 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Server))
380 }
381 };
382
383 future::ready(Ok(Muxer::new(connection)))
384 }
385}
386
387impl<C> OutboundConnectionUpgrade<C> for Config
388where
389 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
390{
391 type Output = Muxer<C>;
392 type Error = io::Error;
393 type Future = future::Ready<Result<Self::Output, Self::Error>>;
394
395 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
396 let connection = match self.0 {
397 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
398 io,
399 inner,
400 mode.unwrap_or(yamux012::Mode::Client),
401 )),
402 Either::Right(Config013(cfg)) => {
403 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Client))
404 }
405 };
406
407 future::ready(Ok(Muxer::new(connection)))
408 }
409}
410
411#[derive(Debug, Clone)]
412struct Config013(yamux013::Config);
413
414impl Default for Config013 {
415 fn default() -> Self {
416 let mut cfg = yamux013::Config::default();
417 cfg.set_read_after_close(false);
420 Self(cfg)
421 }
422}
423
424#[derive(Debug, Error)]
426#[error(transparent)]
427pub struct Error(Either<yamux012::ConnectionError, yamux013::ConnectionError>);
428
429impl From<Error> for io::Error {
430 fn from(err: Error) -> Self {
431 match err.0 {
432 Either::Left(err) => match err {
433 yamux012::ConnectionError::Io(e) => e,
434 e => io::Error::new(io::ErrorKind::Other, e),
435 },
436 Either::Right(err) => match err {
437 yamux013::ConnectionError::Io(e) => e,
438 e => io::Error::new(io::ErrorKind::Other, e),
439 },
440 }
441 }
442}
443
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn config_set_switches_to_v012() {
        let mut cfg = Config::default();
        // The default configuration is backed by yamux013.
        assert!(matches!(cfg.0, Either::Right(Config013(_))));

        cfg.set_max_num_streams(42);
        // Any setter moves the configuration to the yamux012 backend.
        assert!(matches!(cfg.0, Either::Left(Config012 { .. })));
    }
}