yamux/frame/
io.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
4//
5// A copy of the Apache License, Version 2.0 is included in the software as
6// LICENSE-APACHE and a copy of the MIT license is included in the software
7// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
8// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
9// at https://opensource.org/licenses/MIT.
10
11use super::{
12    header::{self, HeaderDecodeError},
13    Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18    fmt, io,
19    pin::Pin,
20    task::{Context, Poll},
21};
22
/// Maximum Yamux frame body length, in bytes.
///
/// Limits the amount of bytes a remote can cause the local node to allocate at once when reading:
/// a data frame whose header announces a longer body is rejected with
/// `FrameDecodeError::FrameTooLarge` before the body buffer is allocated.
///
/// Chosen based on intuition in past iterations.
const MAX_FRAME_BODY_LEN: usize = crate::MIB;
29
/// A [`Stream`] and writer of [`Frame`] values.
///
/// Frames are read through the [`Stream`] implementation and written through the
/// [`Sink`] implementation. Both operate on the same underlying `io` object and
/// track their progress in separate state machines.
#[derive(Debug)]
pub(crate) struct Io<T> {
    /// Identifier included in the trace log messages emitted by this type.
    id: Id,
    /// The underlying I/O object frames are read from and written to.
    io: T,
    /// Progress of the frame that is currently being read.
    read_state: ReadState,
    /// Progress of the frame that is currently being written.
    write_state: WriteState,
}
38
39impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
40    pub(crate) fn new(id: Id, io: T) -> Self {
41        Io {
42            id,
43            io,
44            read_state: ReadState::Init,
45            write_state: WriteState::Init,
46        }
47    }
48}
49
/// The stages of writing a new `Frame`.
enum WriteState {
    /// Nothing is pending; a new frame can be accepted.
    Init,
    /// The encoded frame header is being written.
    Header {
        /// The encoded header bytes.
        header: [u8; header::HEADER_SIZE],
        /// The frame body that follows the header; may be empty.
        buffer: Vec<u8>,
        /// Number of bytes of `header` that have been written so far.
        offset: usize,
    },
    /// The frame body is being written (the header has been written completely).
    Body {
        /// The frame body.
        buffer: Vec<u8>,
        /// Number of bytes of `buffer` that have been written so far.
        offset: usize,
    },
}
63
64impl fmt::Debug for WriteState {
65    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
66        match self {
67            WriteState::Init => f.write_str("(WriteState::Init)"),
68            WriteState::Header { offset, .. } => {
69                write!(f, "(WriteState::Header (offset {offset}))")
70            }
71            WriteState::Body { offset, buffer } => {
72                write!(
73                    f,
74                    "(WriteState::Body (offset {}) (buffer-len {}))",
75                    offset,
76                    buffer.len()
77                )
78            }
79        }
80    }
81}
82
83impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
84    type Error = io::Error;
85
86    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        let this = Pin::into_inner(self);
88        loop {
89            log::trace!("{}: write: {:?}", this.id, this.write_state);
90            match &mut this.write_state {
91                WriteState::Init => return Poll::Ready(Ok(())),
92                WriteState::Header {
93                    header,
94                    buffer,
95                    ref mut offset,
96                } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
97                    Poll::Pending => return Poll::Pending,
98                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
99                    Poll::Ready(Ok(n)) => {
100                        if n == 0 {
101                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
102                        }
103                        *offset += n;
104
105                        if *offset > header.len() {
106                            return Poll::Ready(Err(io::Error::other(format!(
107                                "Writer header returned invalid write count n={n}: {offset} > {} ",
108                                header.len(),
109                            ))));
110                        }
111
112                        if *offset == header.len() {
113                            if !buffer.is_empty() {
114                                let buffer = std::mem::take(buffer);
115                                this.write_state = WriteState::Body { buffer, offset: 0 };
116                            } else {
117                                this.write_state = WriteState::Init;
118                            }
119                        }
120                    }
121                },
122                WriteState::Body {
123                    buffer,
124                    ref mut offset,
125                } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
126                    Poll::Pending => return Poll::Pending,
127                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
128                    Poll::Ready(Ok(n)) => {
129                        if n == 0 {
130                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
131                        }
132                        *offset += n;
133
134                        if *offset > buffer.len() {
135                            return Poll::Ready(Err(io::Error::other(format!(
136                                "Writer body returned invalid write count n={n}: {offset} > {} ",
137                                buffer.len(),
138                            ))));
139                        }
140
141                        if *offset == buffer.len() {
142                            this.write_state = WriteState::Init;
143                        }
144                    }
145                },
146            }
147        }
148    }
149
150    fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
151        let header = header::encode(&f.header);
152        let buffer = f.body;
153        self.get_mut().write_state = WriteState::Header {
154            header,
155            buffer,
156            offset: 0,
157        };
158        Ok(())
159    }
160
161    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
162        let this = Pin::into_inner(self);
163        ready!(this.poll_ready_unpin(cx))?;
164        Pin::new(&mut this.io).poll_flush(cx)
165    }
166
167    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        let this = Pin::into_inner(self);
169        ready!(this.poll_ready_unpin(cx))?;
170        Pin::new(&mut this.io).poll_close(cx)
171    }
172}
173
/// The stages of reading a new `Frame`.
enum ReadState {
    /// Initial reading state.
    Init,
    /// Reading the frame header.
    Header {
        /// Number of header bytes read into `buffer` so far.
        offset: usize,
        /// Storage for the raw, still encoded header bytes.
        buffer: [u8; header::HEADER_SIZE],
    },
    /// Reading the frame body.
    Body {
        /// The decoded header of the frame whose body is being read.
        header: header::Header<()>,
        /// Number of body bytes read into `buffer` so far.
        offset: usize,
        /// Storage for the body, allocated with the length announced by `header`.
        buffer: Vec<u8>,
    },
}
190
191impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
192    type Item = Result<Frame<()>, FrameDecodeError>;
193
194    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
195        let this = &mut *self;
196        loop {
197            log::trace!("{}: read: {:?}", this.id, this.read_state);
198            match this.read_state {
199                ReadState::Init => {
200                    this.read_state = ReadState::Header {
201                        offset: 0,
202                        buffer: [0; header::HEADER_SIZE],
203                    };
204                }
205                ReadState::Header {
206                    ref mut offset,
207                    ref mut buffer,
208                } => {
209                    if *offset == header::HEADER_SIZE {
210                        let header = match header::decode(buffer) {
211                            Ok(hd) => hd,
212                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
213                        };
214
215                        log::trace!("{}: read: {}", this.id, header);
216
217                        if header.tag() != header::Tag::Data {
218                            this.read_state = ReadState::Init;
219                            return Poll::Ready(Some(Ok(Frame::new(header))));
220                        }
221
222                        let body_len = header.len().val() as usize;
223
224                        if body_len > MAX_FRAME_BODY_LEN {
225                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
226                                body_len,
227                            ))));
228                        }
229
230                        this.read_state = ReadState::Body {
231                            header,
232                            offset: 0,
233                            buffer: vec![0; body_len],
234                        };
235
236                        continue;
237                    }
238
239                    let buf = &mut buffer[*offset..header::HEADER_SIZE];
240                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
241                        0 => {
242                            if *offset == 0 {
243                                return Poll::Ready(None);
244                            }
245                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
246                            return Poll::Ready(Some(Err(e)));
247                        }
248                        n => *offset += n,
249                    }
250                }
251                ReadState::Body {
252                    ref header,
253                    ref mut offset,
254                    ref mut buffer,
255                } => {
256                    let body_len = header.len().val() as usize;
257
258                    if *offset == body_len {
259                        let h = header.clone();
260                        let v = std::mem::take(buffer);
261                        this.read_state = ReadState::Init;
262                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
263                    }
264
265                    let buf = &mut buffer[*offset..body_len];
266                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
267                        0 => {
268                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
269                            return Poll::Ready(Some(Err(e)));
270                        }
271                        n => *offset += n,
272                    }
273                }
274            }
275        }
276    }
277}
278
279impl fmt::Debug for ReadState {
280    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
281        match self {
282            ReadState::Init => f.write_str("(ReadState::Init)"),
283            ReadState::Header { offset, .. } => {
284                write!(f, "(ReadState::Header (offset {offset}))")
285            }
286            ReadState::Body {
287                header,
288                offset,
289                buffer,
290            } => {
291                write!(
292                    f,
293                    "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
294                    header,
295                    offset,
296                    buffer.len()
297                )
298            }
299        }
300    }
301}
302
/// Possible errors while decoding a message frame.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// An I/O error.
    ///
    /// Also used, with kind `UnexpectedEof`, when the reader reports end-of-file
    /// in the middle of a frame.
    Io(io::Error),
    /// Decoding the frame header failed.
    Header(HeaderDecodeError),
    /// A data frame body length is larger than the configured maximum.
    ///
    /// Carries the body length that the frame header announced.
    FrameTooLarge(usize),
}
314
315impl std::fmt::Display for FrameDecodeError {
316    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
317        match self {
318            FrameDecodeError::Io(e) => write!(f, "i/o error: {e}"),
319            FrameDecodeError::Header(e) => write!(f, "decode error: {e}"),
320            FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({n})"),
321        }
322    }
323}
324
325impl std::error::Error for FrameDecodeError {
326    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
327        match self {
328            FrameDecodeError::Io(e) => Some(e),
329            FrameDecodeError::Header(e) => Some(e),
330            FrameDecodeError::FrameTooLarge(_) => None,
331        }
332    }
333}
334
335impl From<std::io::Error> for FrameDecodeError {
336    fn from(e: std::io::Error) -> Self {
337        FrameDecodeError::Io(e)
338    }
339}
340
341impl From<HeaderDecodeError> for FrameDecodeError {
342    fn from(e: HeaderDecodeError) -> Self {
343        FrameDecodeError::Header(e)
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    impl Arbitrary for Frame<()> {
        /// Generates a frame from an arbitrary header. Data frames get a random
        /// body whose length is capped below 4096 bytes; all other frames have
        /// an empty body.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            if header.tag() != header::Tag::Data {
                return Frame {
                    header,
                    body: Vec::new(),
                };
            }

            header.set_len(header.len().val() % 4096);
            let mut body = vec![0; header.len().val() as usize];
            rand::rng().fill_bytes(&mut body);
            Frame { header, body }
        }
    }

    /// Writing a frame and reading it back must yield an equal frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()));

                let written = io.send(f.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }

                // Rewind so that the frame just written is read back.
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(x)) if x == f)
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}
391            .quickcheck(property as fn(Frame<()>) -> bool)
392    }
393}