1pub(crate) mod protocol;
22
23pub use protocol::ProtocolSupport;
24
25use crate::codec::Codec;
26use crate::handler::protocol::Protocol;
27use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};
28
29use futures::channel::mpsc;
30use futures::{channel::oneshot, prelude::*};
31use libp2p_swarm::handler::{
32 ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
33 ListenUpgradeError,
34};
35use libp2p_swarm::{
36 handler::{ConnectionHandler, ConnectionHandlerEvent, StreamUpgradeError},
37 SubstreamProtocol,
38};
39use smallvec::SmallVec;
40use std::{
41 collections::VecDeque,
42 fmt, io,
43 sync::{
44 atomic::{AtomicU64, Ordering},
45 Arc,
46 },
47 task::{Context, Poll},
48 time::Duration,
49};
50
51pub struct Handler<TCodec>
53where
54 TCodec: Codec,
55{
56 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
58 codec: TCodec,
60 pending_events: VecDeque<Event<TCodec>>,
62 pending_outbound: VecDeque<OutboundMessage<TCodec>>,
64
65 requested_outbound: VecDeque<OutboundMessage<TCodec>>,
66 inbound_receiver: mpsc::Receiver<(
68 InboundRequestId,
69 TCodec::Request,
70 oneshot::Sender<TCodec::Response>,
71 )>,
72 inbound_sender: mpsc::Sender<(
74 InboundRequestId,
75 TCodec::Request,
76 oneshot::Sender<TCodec::Response>,
77 )>,
78
79 inbound_request_id: Arc<AtomicU64>,
80
81 worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85enum RequestId {
86 Inbound(InboundRequestId),
87 Outbound(OutboundRequestId),
88}
89
90impl<TCodec> Handler<TCodec>
91where
92 TCodec: Codec + Send + Clone + 'static,
93{
94 pub(super) fn new(
95 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
96 codec: TCodec,
97 substream_timeout: Duration,
98 inbound_request_id: Arc<AtomicU64>,
99 max_concurrent_streams: usize,
100 ) -> Self {
101 let (inbound_sender, inbound_receiver) = mpsc::channel(0);
102 Self {
103 inbound_protocols,
104 codec,
105 pending_outbound: VecDeque::new(),
106 requested_outbound: Default::default(),
107 inbound_receiver,
108 inbound_sender,
109 pending_events: VecDeque::new(),
110 inbound_request_id,
111 worker_streams: futures_bounded::FuturesMap::new(
112 substream_timeout,
113 max_concurrent_streams,
114 ),
115 }
116 }
117
118 fn next_inbound_request_id(&mut self) -> InboundRequestId {
120 InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed))
121 }
122
123 fn on_fully_negotiated_inbound(
124 &mut self,
125 FullyNegotiatedInbound {
126 protocol: (mut stream, protocol),
127 info: (),
128 }: FullyNegotiatedInbound<
129 <Self as ConnectionHandler>::InboundProtocol,
130 <Self as ConnectionHandler>::InboundOpenInfo,
131 >,
132 ) {
133 let mut codec = self.codec.clone();
134 let request_id = self.next_inbound_request_id();
135 let mut sender = self.inbound_sender.clone();
136
137 let recv = async move {
138 let (rs_send, rs_recv) = oneshot::channel();
141
142 let read = codec.read_request(&protocol, &mut stream);
143 let request = read.await?;
144 sender
145 .send((request_id, request, rs_send))
146 .await
147 .expect("`ConnectionHandler` owns both ends of the channel");
148 drop(sender);
149
150 if let Ok(response) = rs_recv.await {
151 let write = codec.write_response(&protocol, &mut stream, response);
152 write.await?;
153
154 stream.close().await?;
155 Ok(Event::ResponseSent(request_id))
156 } else {
157 stream.close().await?;
158 Ok(Event::ResponseOmission(request_id))
159 }
160 };
161
162 if self
166 .worker_streams
167 .try_push(RequestId::Inbound(request_id), recv.boxed())
168 .is_err()
169 {
170 tracing::warn!("Dropping inbound stream because we are at capacity")
171 }
172 }
173
174 fn on_fully_negotiated_outbound(
175 &mut self,
176 FullyNegotiatedOutbound {
177 protocol: (mut stream, protocol),
178 info: (),
179 }: FullyNegotiatedOutbound<
180 <Self as ConnectionHandler>::OutboundProtocol,
181 <Self as ConnectionHandler>::OutboundOpenInfo,
182 >,
183 ) {
184 let message = self
185 .requested_outbound
186 .pop_front()
187 .expect("negotiated a stream without a pending message");
188
189 let mut codec = self.codec.clone();
190 let request_id = message.request_id;
191
192 let send = async move {
193 let write = codec.write_request(&protocol, &mut stream, message.request);
194 write.await?;
195 stream.close().await?;
196 let read = codec.read_response(&protocol, &mut stream);
197 let response = read.await?;
198
199 Ok(Event::Response {
200 request_id,
201 response,
202 })
203 };
204
205 if self
206 .worker_streams
207 .try_push(RequestId::Outbound(request_id), send.boxed())
208 .is_err()
209 {
210 self.pending_events.push_back(Event::OutboundStreamFailed {
211 request_id: message.request_id,
212 error: io::Error::new(io::ErrorKind::Other, "max sub-streams reached"),
213 });
214 }
215 }
216
217 fn on_dial_upgrade_error(
218 &mut self,
219 DialUpgradeError { error, info: () }: DialUpgradeError<
220 <Self as ConnectionHandler>::OutboundOpenInfo,
221 <Self as ConnectionHandler>::OutboundProtocol,
222 >,
223 ) {
224 let message = self
225 .requested_outbound
226 .pop_front()
227 .expect("negotiated a stream without a pending message");
228
229 match error {
230 StreamUpgradeError::Timeout => {
231 self.pending_events
232 .push_back(Event::OutboundTimeout(message.request_id));
233 }
234 StreamUpgradeError::NegotiationFailed => {
235 self.pending_events
241 .push_back(Event::OutboundUnsupportedProtocols(message.request_id));
242 }
243 StreamUpgradeError::Apply(e) => void::unreachable(e),
244 StreamUpgradeError::Io(e) => {
245 self.pending_events.push_back(Event::OutboundStreamFailed {
246 request_id: message.request_id,
247 error: e,
248 });
249 }
250 }
251 }
252 fn on_listen_upgrade_error(
253 &mut self,
254 ListenUpgradeError { error, .. }: ListenUpgradeError<
255 <Self as ConnectionHandler>::InboundOpenInfo,
256 <Self as ConnectionHandler>::InboundProtocol,
257 >,
258 ) {
259 void::unreachable(error)
260 }
261}
262
263pub enum Event<TCodec>
265where
266 TCodec: Codec,
267{
268 Request {
270 request_id: InboundRequestId,
271 request: TCodec::Request,
272 sender: oneshot::Sender<TCodec::Response>,
273 },
274 Response {
276 request_id: OutboundRequestId,
277 response: TCodec::Response,
278 },
279 ResponseSent(InboundRequestId),
281 ResponseOmission(InboundRequestId),
284 OutboundTimeout(OutboundRequestId),
287 OutboundUnsupportedProtocols(OutboundRequestId),
289 OutboundStreamFailed {
290 request_id: OutboundRequestId,
291 error: io::Error,
292 },
293 InboundTimeout(InboundRequestId),
296 InboundStreamFailed {
297 request_id: InboundRequestId,
298 error: io::Error,
299 },
300}
301
302impl<TCodec: Codec> fmt::Debug for Event<TCodec> {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 match self {
305 Event::Request {
306 request_id,
307 request: _,
308 sender: _,
309 } => f
310 .debug_struct("Event::Request")
311 .field("request_id", request_id)
312 .finish(),
313 Event::Response {
314 request_id,
315 response: _,
316 } => f
317 .debug_struct("Event::Response")
318 .field("request_id", request_id)
319 .finish(),
320 Event::ResponseSent(request_id) => f
321 .debug_tuple("Event::ResponseSent")
322 .field(request_id)
323 .finish(),
324 Event::ResponseOmission(request_id) => f
325 .debug_tuple("Event::ResponseOmission")
326 .field(request_id)
327 .finish(),
328 Event::OutboundTimeout(request_id) => f
329 .debug_tuple("Event::OutboundTimeout")
330 .field(request_id)
331 .finish(),
332 Event::OutboundUnsupportedProtocols(request_id) => f
333 .debug_tuple("Event::OutboundUnsupportedProtocols")
334 .field(request_id)
335 .finish(),
336 Event::OutboundStreamFailed { request_id, error } => f
337 .debug_struct("Event::OutboundStreamFailed")
338 .field("request_id", &request_id)
339 .field("error", &error)
340 .finish(),
341 Event::InboundTimeout(request_id) => f
342 .debug_tuple("Event::InboundTimeout")
343 .field(request_id)
344 .finish(),
345 Event::InboundStreamFailed { request_id, error } => f
346 .debug_struct("Event::InboundStreamFailed")
347 .field("request_id", &request_id)
348 .field("error", &error)
349 .finish(),
350 }
351 }
352}
353
354pub struct OutboundMessage<TCodec: Codec> {
355 pub(crate) request_id: OutboundRequestId,
356 pub(crate) request: TCodec::Request,
357 pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>,
358}
359
360impl<TCodec> fmt::Debug for OutboundMessage<TCodec>
361where
362 TCodec: Codec,
363{
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 f.debug_struct("OutboundMessage").finish_non_exhaustive()
366 }
367}
368
369impl<TCodec> ConnectionHandler for Handler<TCodec>
370where
371 TCodec: Codec + Send + Clone + 'static,
372{
373 type FromBehaviour = OutboundMessage<TCodec>;
374 type ToBehaviour = Event<TCodec>;
375 type InboundProtocol = Protocol<TCodec::Protocol>;
376 type OutboundProtocol = Protocol<TCodec::Protocol>;
377 type OutboundOpenInfo = ();
378 type InboundOpenInfo = ();
379
380 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
381 SubstreamProtocol::new(
382 Protocol {
383 protocols: self.inbound_protocols.clone(),
384 },
385 (),
386 )
387 }
388
389 fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
390 self.pending_outbound.push_back(request);
391 }
392
393 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
394 fn poll(
395 &mut self,
396 cx: &mut Context<'_>,
397 ) -> Poll<ConnectionHandlerEvent<Protocol<TCodec::Protocol>, (), Self::ToBehaviour>> {
398 match self.worker_streams.poll_unpin(cx) {
399 Poll::Ready((_, Ok(Ok(event)))) => {
400 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
401 }
402 Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
403 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
404 Event::InboundStreamFailed {
405 request_id: id,
406 error: e,
407 },
408 ));
409 }
410 Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
411 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
412 Event::OutboundStreamFailed {
413 request_id: id,
414 error: e,
415 },
416 ));
417 }
418 Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
419 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
420 Event::InboundTimeout(id),
421 ));
422 }
423 Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
424 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
425 Event::OutboundTimeout(id),
426 ));
427 }
428 Poll::Pending => {}
429 }
430
431 if let Some(event) = self.pending_events.pop_front() {
433 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
434 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
435 self.pending_events.shrink_to_fit();
436 }
437
438 if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) {
440 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
443 request_id: id,
444 request: rq,
445 sender: rs_sender,
446 }));
447 }
448
449 if let Some(request) = self.pending_outbound.pop_front() {
451 let protocols = request.protocols.clone();
452 self.requested_outbound.push_back(request);
453
454 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
455 protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
456 });
457 }
458
459 debug_assert!(self.pending_outbound.is_empty());
460
461 if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
462 self.pending_outbound.shrink_to_fit();
463 }
464
465 Poll::Pending
466 }
467
468 fn on_connection_event(
469 &mut self,
470 event: ConnectionEvent<
471 Self::InboundProtocol,
472 Self::OutboundProtocol,
473 Self::InboundOpenInfo,
474 Self::OutboundOpenInfo,
475 >,
476 ) {
477 match event {
478 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
479 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
480 }
481 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
482 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
483 }
484 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
485 self.on_dial_upgrade_error(dial_upgrade_error)
486 }
487 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
488 self.on_listen_upgrade_error(listen_upgrade_error)
489 }
490 _ => {}
491 }
492 }
493}