yamux/connection/
closing.rs1use crate::connection::StreamCommand;
2use crate::frame::Frame;
3use crate::tagged_stream::TaggedStream;
4use crate::Result;
5use crate::{frame, StreamId};
6use futures::channel::mpsc;
7use futures::stream::{Fuse, SelectAll};
8use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
9use std::collections::VecDeque;
10use std::future::Future;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14#[must_use]
16pub struct Closing<T> {
17 state: State,
18 stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
19 pending_frames: VecDeque<Frame<()>>,
20 socket: Fuse<frame::Io<T>>,
21}
22
23impl<T> Closing<T>
24where
25 T: AsyncRead + AsyncWrite + Unpin,
26{
27 pub(crate) fn new(
28 stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
29 pending_frames: VecDeque<Frame<()>>,
30 socket: Fuse<frame::Io<T>>,
31 ) -> Self {
32 Self {
33 state: State::ClosingStreamReceiver,
34 stream_receivers,
35 pending_frames,
36 socket,
37 }
38 }
39}
40
41impl<T> Future for Closing<T>
42where
43 T: AsyncRead + AsyncWrite + Unpin,
44{
45 type Output = Result<()>;
46
47 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
48 let this = self.get_mut();
49
50 loop {
51 match this.state {
52 State::ClosingStreamReceiver => {
53 for stream in this.stream_receivers.iter_mut() {
54 stream.inner_mut().close();
55 }
56 this.state = State::DrainingStreamReceiver;
57 }
58
59 State::DrainingStreamReceiver => {
60 match this.stream_receivers.poll_next_unpin(cx) {
61 Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
62 this.pending_frames.push_back(frame.into());
63 }
64 Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
65 this.pending_frames
66 .push_back(Frame::close_stream(id, ack).into());
67 }
68 Poll::Ready(Some((_, None))) => {}
69 Poll::Pending | Poll::Ready(None) => {
70 this.pending_frames.push_back(Frame::term().into());
72 this.state = State::FlushingPendingFrames;
73 continue;
74 }
75 }
76 }
77 State::FlushingPendingFrames => {
78 ready!(this.socket.poll_ready_unpin(cx))?;
79
80 match this.pending_frames.pop_front() {
81 Some(frame) => this.socket.start_send_unpin(frame)?,
82 None => this.state = State::ClosingSocket,
83 }
84 }
85 State::ClosingSocket => {
86 ready!(this.socket.poll_close_unpin(cx))?;
87
88 return Poll::Ready(Ok(()));
89 }
90 }
91 }
92 }
93}
94
/// Phases the `Closing` future moves through, in order.
enum State {
    /// Close every stream command receiver.
    ClosingStreamReceiver,
    /// Turn the commands still available from the receivers into pending frames.
    DrainingStreamReceiver,
    /// Write the pending frames to the socket.
    FlushingPendingFrames,
    /// Close the socket.
    ClosingSocket,
}
101
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::poll_fn;
    use futures::FutureExt;

    /// Write-only in-memory transport that records everything written to it
    /// and whether it has been closed.
    struct Socket {
        written: Vec<u8>,
        closed: bool,
    }

    impl AsyncRead for Socket {
        // Reading is not implemented; calling this panics.
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            unimplemented!()
        }
    }

    impl AsyncWrite for Socket {
        // Accepts the whole buffer at once; writing after close is a test failure.
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            assert!(!self.closed);
            let accepted = buf.len();
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(accepted))
        }

        // Flushing is not implemented; calling this panics.
        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            unimplemented!()
        }

        // Marks the socket closed; closing twice is a test failure.
        fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            assert!(!self.closed);
            self.closed = true;
            Poll::Ready(Ok(()))
        }
    }

    /// Checks that already-pending frames, frames produced from buffered stream
    /// commands and the final `Term` frame are written in that order, and that
    /// the socket is closed afterwards.
    #[test]
    fn pending_frames() {
        /// Appends the wire encoding of `item` (header, plus body for data frames).
        fn encode(buf: &mut Vec<u8>, item: &Frame<()>) {
            buf.extend_from_slice(&frame::header::encode(item.header()));
            if item.header().tag() == frame::header::Tag::Data {
                buf.extend_from_slice(item.clone().into_data().body());
            }
        }

        /// Builds a receiver tagged with `item`'s stream ID that already holds `command`.
        fn receiver(
            item: &Frame<()>,
            command: StreamCommand,
        ) -> TaggedStream<StreamId, mpsc::Receiver<StreamCommand>> {
            let stream_id = item.header().stream_id();
            let rx = {
                let (mut tx, rx) = mpsc::channel(1);
                tx.try_send(command).unwrap();
                rx
            };
            TaggedStream::new(stream_id, rx)
        }

        let frame_pending: Frame<()> = Frame::data(StreamId::new(1), vec![2]).unwrap().into();
        let frame_data: Frame<()> = Frame::data(StreamId::new(3), vec![4]).unwrap().into();
        let frame_close: Frame<()> = Frame::close_stream(StreamId::new(5), false).into();
        let frame_close_ack: Frame<()> = Frame::close_stream(StreamId::new(6), true).into();
        let frame_term: Frame<()> = Frame::term().into();

        // Bytes the socket is expected to receive, in write order.
        let mut expected_written = vec![];
        for item in [
            &frame_pending,
            &frame_data,
            &frame_close,
            &frame_close_ack,
            &frame_term,
        ] {
            encode(&mut expected_written, item);
        }

        // One receiver per stream, each pre-loaded with a single command.
        let queued_commands = vec![
            (
                &frame_data,
                StreamCommand::SendFrame(frame_data.clone().into_data().left()),
            ),
            (&frame_close, StreamCommand::CloseStream { ack: false }),
            (&frame_close_ack, StreamCommand::CloseStream { ack: true }),
        ];
        let mut stream_receivers: SelectAll<_> = Default::default();
        for (item, command) in queued_commands {
            stream_receivers.push(receiver(item, command));
        }

        let pending_frames = vec![frame_pending];
        let mut socket = Socket {
            written: vec![],
            closed: false,
        };
        let mut closing = Closing::new(
            stream_receivers,
            pending_frames.into(),
            frame::Io::new(crate::connection::Id(0), &mut socket).fuse(),
        );

        futures::executor::block_on(poll_fn(|cx| closing.poll_unpin(cx))).unwrap();

        assert!(closing.pending_frames.is_empty());
        assert!(socket.closed);
        assert_eq!(socket.written, expected_written);
    }
}