1use super::{
12 header::{self, HeaderDecodeError},
13 Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18 fmt, io,
19 pin::Pin,
20 task::{Context, Poll},
21};
22
/// Maximum accepted body length, in bytes, of an incoming data frame
/// (`crate::MIB`, presumably one mebibyte — confirm against the crate root).
///
/// `poll_next` rejects any data frame whose header announces a longer body with
/// `FrameDecodeError::FrameTooLarge` *before* allocating the body buffer, which bounds how
/// much memory a single frame from the remote can make us allocate.
const MAX_FRAME_BODY_LEN: usize = crate::MIB;
29
/// Frame codec over an I/O resource `T`.
///
/// Implements `Sink<Frame<()>>` to encode and write frames and `Stream` to read and decode
/// them. Reading and writing track their progress independently, so a partially read frame
/// and a partially written frame can be in flight at the same time.
#[derive(Debug)]
pub(crate) struct Io<T> {
    // Connection id; only used to prefix trace log lines.
    id: Id,
    // The underlying I/O resource frames are read from and written to.
    io: T,
    // Progress of the frame currently being read.
    read_state: ReadState,
    // Progress of the frame currently being written.
    write_state: WriteState,
}
38
39impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
40 pub(crate) fn new(id: Id, io: T) -> Self {
41 Io {
42 id,
43 io,
44 read_state: ReadState::Init,
45 write_state: WriteState::Init,
46 }
47 }
48}
49
/// Progress of the frame currently being written by the `Sink` implementation.
enum WriteState {
    /// No frame in flight; the sink can accept a new frame.
    Init,
    /// Writing the encoded frame header.
    Header {
        /// Encoded header bytes produced by `header::encode`.
        header: [u8; header::HEADER_SIZE],
        /// Frame body to write once the header is done (empty for header-only frames).
        buffer: Vec<u8>,
        /// Number of header bytes already written.
        offset: usize,
    },
    /// Writing the frame body; the header has been fully written.
    Body {
        /// Frame body bytes.
        buffer: Vec<u8>,
        /// Number of body bytes already written.
        offset: usize,
    },
}
63
64impl fmt::Debug for WriteState {
65 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
66 match self {
67 WriteState::Init => f.write_str("(WriteState::Init)"),
68 WriteState::Header { offset, .. } => {
69 write!(f, "(WriteState::Header (offset {offset}))")
70 }
71 WriteState::Body { offset, buffer } => {
72 write!(
73 f,
74 "(WriteState::Body (offset {}) (buffer-len {}))",
75 offset,
76 buffer.len()
77 )
78 }
79 }
80 }
81}
82
83impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
84 type Error = io::Error;
85
86 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 let this = Pin::into_inner(self);
88 loop {
89 log::trace!("{}: write: {:?}", this.id, this.write_state);
90 match &mut this.write_state {
91 WriteState::Init => return Poll::Ready(Ok(())),
92 WriteState::Header {
93 header,
94 buffer,
95 ref mut offset,
96 } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
97 Poll::Pending => return Poll::Pending,
98 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
99 Poll::Ready(Ok(n)) => {
100 if n == 0 {
101 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
102 }
103 *offset += n;
104
105 if *offset > header.len() {
106 return Poll::Ready(Err(io::Error::other(format!(
107 "Writer header returned invalid write count n={n}: {offset} > {} ",
108 header.len(),
109 ))));
110 }
111
112 if *offset == header.len() {
113 if !buffer.is_empty() {
114 let buffer = std::mem::take(buffer);
115 this.write_state = WriteState::Body { buffer, offset: 0 };
116 } else {
117 this.write_state = WriteState::Init;
118 }
119 }
120 }
121 },
122 WriteState::Body {
123 buffer,
124 ref mut offset,
125 } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
126 Poll::Pending => return Poll::Pending,
127 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
128 Poll::Ready(Ok(n)) => {
129 if n == 0 {
130 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
131 }
132 *offset += n;
133
134 if *offset > buffer.len() {
135 return Poll::Ready(Err(io::Error::other(format!(
136 "Writer body returned invalid write count n={n}: {offset} > {} ",
137 buffer.len(),
138 ))));
139 }
140
141 if *offset == buffer.len() {
142 this.write_state = WriteState::Init;
143 }
144 }
145 },
146 }
147 }
148 }
149
150 fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
151 let header = header::encode(&f.header);
152 let buffer = f.body;
153 self.get_mut().write_state = WriteState::Header {
154 header,
155 buffer,
156 offset: 0,
157 };
158 Ok(())
159 }
160
161 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
162 let this = Pin::into_inner(self);
163 ready!(this.poll_ready_unpin(cx))?;
164 Pin::new(&mut this.io).poll_flush(cx)
165 }
166
167 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 let this = Pin::into_inner(self);
169 ready!(this.poll_ready_unpin(cx))?;
170 Pin::new(&mut this.io).poll_close(cx)
171 }
172}
173
/// Progress of the frame currently being read by the `Stream` implementation.
enum ReadState {
    /// No frame in progress; the next poll starts reading a fresh header.
    Init,
    /// Reading the fixed-size frame header.
    Header {
        /// Number of header bytes read so far.
        offset: usize,
        /// Header bytes; decoded once `offset == header::HEADER_SIZE`.
        buffer: [u8; header::HEADER_SIZE],
    },
    /// Reading the body of a data frame whose header has been decoded.
    Body {
        /// The decoded header; its `len()` is the expected body length.
        header: header::Header<()>,
        /// Number of body bytes read so far.
        offset: usize,
        /// Body buffer, allocated with the full announced body length.
        buffer: Vec<u8>,
    },
}
190
191impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
192 type Item = Result<Frame<()>, FrameDecodeError>;
193
194 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
195 let this = &mut *self;
196 loop {
197 log::trace!("{}: read: {:?}", this.id, this.read_state);
198 match this.read_state {
199 ReadState::Init => {
200 this.read_state = ReadState::Header {
201 offset: 0,
202 buffer: [0; header::HEADER_SIZE],
203 };
204 }
205 ReadState::Header {
206 ref mut offset,
207 ref mut buffer,
208 } => {
209 if *offset == header::HEADER_SIZE {
210 let header = match header::decode(buffer) {
211 Ok(hd) => hd,
212 Err(e) => return Poll::Ready(Some(Err(e.into()))),
213 };
214
215 log::trace!("{}: read: {}", this.id, header);
216
217 if header.tag() != header::Tag::Data {
218 this.read_state = ReadState::Init;
219 return Poll::Ready(Some(Ok(Frame::new(header))));
220 }
221
222 let body_len = header.len().val() as usize;
223
224 if body_len > MAX_FRAME_BODY_LEN {
225 return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
226 body_len,
227 ))));
228 }
229
230 this.read_state = ReadState::Body {
231 header,
232 offset: 0,
233 buffer: vec![0; body_len],
234 };
235
236 continue;
237 }
238
239 let buf = &mut buffer[*offset..header::HEADER_SIZE];
240 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
241 0 => {
242 if *offset == 0 {
243 return Poll::Ready(None);
244 }
245 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
246 return Poll::Ready(Some(Err(e)));
247 }
248 n => *offset += n,
249 }
250 }
251 ReadState::Body {
252 ref header,
253 ref mut offset,
254 ref mut buffer,
255 } => {
256 let body_len = header.len().val() as usize;
257
258 if *offset == body_len {
259 let h = header.clone();
260 let v = std::mem::take(buffer);
261 this.read_state = ReadState::Init;
262 return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
263 }
264
265 let buf = &mut buffer[*offset..body_len];
266 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
267 0 => {
268 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
269 return Poll::Ready(Some(Err(e)));
270 }
271 n => *offset += n,
272 }
273 }
274 }
275 }
276 }
277}
278
279impl fmt::Debug for ReadState {
280 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
281 match self {
282 ReadState::Init => f.write_str("(ReadState::Init)"),
283 ReadState::Header { offset, .. } => {
284 write!(f, "(ReadState::Header (offset {offset}))")
285 }
286 ReadState::Body {
287 header,
288 offset,
289 buffer,
290 } => {
291 write!(
292 f,
293 "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
294 header,
295 offset,
296 buffer.len()
297 )
298 }
299 }
300 }
301}
302
/// Errors that can occur while reading and decoding a frame.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// Reading from the underlying I/O resource failed. This includes
    /// `io::ErrorKind::UnexpectedEof` when the stream ends in the middle of a frame.
    Io(io::Error),
    /// The frame header could not be decoded.
    Header(HeaderDecodeError),
    /// A data frame announced a body longer than the maximum accepted length; carries the
    /// announced body length in bytes.
    FrameTooLarge(usize),
}
314
315impl std::fmt::Display for FrameDecodeError {
316 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
317 match self {
318 FrameDecodeError::Io(e) => write!(f, "i/o error: {e}"),
319 FrameDecodeError::Header(e) => write!(f, "decode error: {e}"),
320 FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({n})"),
321 }
322 }
323}
324
325impl std::error::Error for FrameDecodeError {
326 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
327 match self {
328 FrameDecodeError::Io(e) => Some(e),
329 FrameDecodeError::Header(e) => Some(e),
330 FrameDecodeError::FrameTooLarge(_) => None,
331 }
332 }
333}
334
335impl From<std::io::Error> for FrameDecodeError {
336 fn from(e: std::io::Error) -> Self {
337 FrameDecodeError::Io(e)
338 }
339}
340
341impl From<HeaderDecodeError> for FrameDecodeError {
342 fn from(e: HeaderDecodeError) -> Self {
343 FrameDecodeError::Header(e)
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    impl Arbitrary for Frame<()> {
        /// Generates a random frame. Non-data frames have an empty body; data frames get a
        /// random body of fewer than 4096 bytes whose length matches the header.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut hdr: header::Header<()> = Arbitrary::arbitrary(g);
            if hdr.tag() != header::Tag::Data {
                return Frame {
                    header: hdr,
                    body: Vec::new(),
                };
            }
            hdr.set_len(hdr.len().val() % 4096);
            let mut body = vec![0; hdr.len().val() as usize];
            rand::rng().fill_bytes(&mut body);
            Frame { header: hdr, body }
        }
    }

    /// Round trip: a frame written through the `Sink` side is read back unchanged through
    /// the `Stream` side.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()));
                let written = io.send(f.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }
                // Rewind the in-memory cursor so the reader sees what was just written.
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(decoded)) if decoded == f)
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}