1use crate::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError};
22use crate::{connection::ConnectedPoint, Negotiated};
23use futures::{future::Either, prelude::*};
24use multistream_select::{DialerSelectFuture, ListenerSelectFuture};
25use std::{mem, pin::Pin, task::Context, task::Poll};
26
27pub(crate) use multistream_select::Version;
28
29pub(crate) fn apply<C, U>(
32 conn: C,
33 up: U,
34 cp: ConnectedPoint,
35 v: Version,
36) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
37where
38 C: AsyncRead + AsyncWrite + Unpin,
39 U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
40{
41 match cp {
42 ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer() => {
43 Either::Right(apply_outbound(conn, up, v))
44 }
45 _ => Either::Left(apply_inbound(conn, up)),
46 }
47}
48
49pub(crate) fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
51where
52 C: AsyncRead + AsyncWrite + Unpin,
53 U: InboundConnectionUpgrade<Negotiated<C>>,
54{
55 InboundUpgradeApply {
56 inner: InboundUpgradeApplyState::Init {
57 future: multistream_select::listener_select_proto(conn, up.protocol_info()),
58 upgrade: up,
59 },
60 }
61}
62
63pub(crate) fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
65where
66 C: AsyncRead + AsyncWrite + Unpin,
67 U: OutboundConnectionUpgrade<Negotiated<C>>,
68{
69 OutboundUpgradeApply {
70 inner: OutboundUpgradeApplyState::Init {
71 future: multistream_select::dialer_select_proto(conn, up.protocol_info(), v),
72 upgrade: up,
73 },
74 }
75}
76
77pub struct InboundUpgradeApply<C, U>
79where
80 C: AsyncRead + AsyncWrite + Unpin,
81 U: InboundConnectionUpgrade<Negotiated<C>>,
82{
83 inner: InboundUpgradeApplyState<C, U>,
84}
85
86#[allow(clippy::large_enum_variant)]
87enum InboundUpgradeApplyState<C, U>
88where
89 C: AsyncRead + AsyncWrite + Unpin,
90 U: InboundConnectionUpgrade<Negotiated<C>>,
91{
92 Init {
93 future: ListenerSelectFuture<C, U::Info>,
94 upgrade: U,
95 },
96 Upgrade {
97 future: Pin<Box<U::Future>>,
98 name: String,
99 },
100 Undefined,
101}
102
103impl<C, U> Unpin for InboundUpgradeApply<C, U>
104where
105 C: AsyncRead + AsyncWrite + Unpin,
106 U: InboundConnectionUpgrade<Negotiated<C>>,
107{
108}
109
110impl<C, U> Future for InboundUpgradeApply<C, U>
111where
112 C: AsyncRead + AsyncWrite + Unpin,
113 U: InboundConnectionUpgrade<Negotiated<C>>,
114{
115 type Output = Result<U::Output, UpgradeError<U::Error>>;
116
117 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
118 loop {
119 match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
120 InboundUpgradeApplyState::Init {
121 mut future,
122 upgrade,
123 } => {
124 let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
125 Poll::Ready(x) => x,
126 Poll::Pending => {
127 self.inner = InboundUpgradeApplyState::Init { future, upgrade };
128 return Poll::Pending;
129 }
130 };
131 self.inner = InboundUpgradeApplyState::Upgrade {
132 future: Box::pin(upgrade.upgrade_inbound(io, info.clone())),
133 name: info.as_ref().to_owned(),
134 };
135 }
136 InboundUpgradeApplyState::Upgrade { mut future, name } => {
137 match Future::poll(Pin::new(&mut future), cx) {
138 Poll::Pending => {
139 self.inner = InboundUpgradeApplyState::Upgrade { future, name };
140 return Poll::Pending;
141 }
142 Poll::Ready(Ok(x)) => {
143 tracing::trace!(upgrade=%name, "Upgraded inbound stream");
144 return Poll::Ready(Ok(x));
145 }
146 Poll::Ready(Err(e)) => {
147 tracing::debug!(upgrade=%name, "Failed to upgrade inbound stream");
148 return Poll::Ready(Err(UpgradeError::Apply(e)));
149 }
150 }
151 }
152 InboundUpgradeApplyState::Undefined => {
153 panic!("InboundUpgradeApplyState::poll called after completion")
154 }
155 }
156 }
157 }
158}
159
160pub struct OutboundUpgradeApply<C, U>
162where
163 C: AsyncRead + AsyncWrite + Unpin,
164 U: OutboundConnectionUpgrade<Negotiated<C>>,
165{
166 inner: OutboundUpgradeApplyState<C, U>,
167}
168
169enum OutboundUpgradeApplyState<C, U>
170where
171 C: AsyncRead + AsyncWrite + Unpin,
172 U: OutboundConnectionUpgrade<Negotiated<C>>,
173{
174 Init {
175 future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
176 upgrade: U,
177 },
178 Upgrade {
179 future: Pin<Box<U::Future>>,
180 name: String,
181 },
182 Undefined,
183}
184
185impl<C, U> Unpin for OutboundUpgradeApply<C, U>
186where
187 C: AsyncRead + AsyncWrite + Unpin,
188 U: OutboundConnectionUpgrade<Negotiated<C>>,
189{
190}
191
192impl<C, U> Future for OutboundUpgradeApply<C, U>
193where
194 C: AsyncRead + AsyncWrite + Unpin,
195 U: OutboundConnectionUpgrade<Negotiated<C>>,
196{
197 type Output = Result<U::Output, UpgradeError<U::Error>>;
198
199 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
200 loop {
201 match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
202 OutboundUpgradeApplyState::Init {
203 mut future,
204 upgrade,
205 } => {
206 let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
207 Poll::Ready(x) => x,
208 Poll::Pending => {
209 self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
210 return Poll::Pending;
211 }
212 };
213 self.inner = OutboundUpgradeApplyState::Upgrade {
214 future: Box::pin(upgrade.upgrade_outbound(connection, info.clone())),
215 name: info.as_ref().to_owned(),
216 };
217 }
218 OutboundUpgradeApplyState::Upgrade { mut future, name } => {
219 match Future::poll(Pin::new(&mut future), cx) {
220 Poll::Pending => {
221 self.inner = OutboundUpgradeApplyState::Upgrade { future, name };
222 return Poll::Pending;
223 }
224 Poll::Ready(Ok(x)) => {
225 tracing::trace!(upgrade=%name, "Upgraded outbound stream");
226 return Poll::Ready(Ok(x));
227 }
228 Poll::Ready(Err(e)) => {
229 tracing::debug!(upgrade=%name, "Failed to upgrade outbound stream",);
230 return Poll::Ready(Err(UpgradeError::Apply(e)));
231 }
232 }
233 }
234 OutboundUpgradeApplyState::Undefined => {
235 panic!("OutboundUpgradeApplyState::poll called after completion")
236 }
237 }
238 }
239 }
240}