// libp2p_request_response/handler.rs

// Copyright 2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21pub(crate) mod protocol;
22
23pub use protocol::ProtocolSupport;
24
25use crate::codec::Codec;
26use crate::handler::protocol::Protocol;
27use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};
28
29use futures::channel::mpsc;
30use futures::{channel::oneshot, prelude::*};
31use libp2p_swarm::handler::{
32    ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
33    ListenUpgradeError,
34};
35use libp2p_swarm::{
36    handler::{ConnectionHandler, ConnectionHandlerEvent, StreamUpgradeError},
37    SubstreamProtocol,
38};
39use smallvec::SmallVec;
40use std::{
41    collections::VecDeque,
42    fmt, io,
43    sync::{
44        atomic::{AtomicU64, Ordering},
45        Arc,
46    },
47    task::{Context, Poll},
48    time::Duration,
49};
50
51/// A connection handler for a request response [`Behaviour`](super::Behaviour) protocol.
52pub struct Handler<TCodec>
53where
54    TCodec: Codec,
55{
56    /// The supported inbound protocols.
57    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
58    /// The request/response message codec.
59    codec: TCodec,
60    /// Queue of events to emit in `poll()`.
61    pending_events: VecDeque<Event<TCodec>>,
62    /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`.
63    pending_outbound: VecDeque<OutboundMessage<TCodec>>,
64
65    requested_outbound: VecDeque<OutboundMessage<TCodec>>,
66    /// A channel for receiving inbound requests.
67    inbound_receiver: mpsc::Receiver<(
68        InboundRequestId,
69        TCodec::Request,
70        oneshot::Sender<TCodec::Response>,
71    )>,
72    /// The [`mpsc::Sender`] for the above receiver. Cloned for each inbound request.
73    inbound_sender: mpsc::Sender<(
74        InboundRequestId,
75        TCodec::Request,
76        oneshot::Sender<TCodec::Response>,
77    )>,
78
79    inbound_request_id: Arc<AtomicU64>,
80
81    worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85enum RequestId {
86    Inbound(InboundRequestId),
87    Outbound(OutboundRequestId),
88}
89
90impl<TCodec> Handler<TCodec>
91where
92    TCodec: Codec + Send + Clone + 'static,
93{
94    pub(super) fn new(
95        inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
96        codec: TCodec,
97        substream_timeout: Duration,
98        inbound_request_id: Arc<AtomicU64>,
99        max_concurrent_streams: usize,
100    ) -> Self {
101        let (inbound_sender, inbound_receiver) = mpsc::channel(0);
102        Self {
103            inbound_protocols,
104            codec,
105            pending_outbound: VecDeque::new(),
106            requested_outbound: Default::default(),
107            inbound_receiver,
108            inbound_sender,
109            pending_events: VecDeque::new(),
110            inbound_request_id,
111            worker_streams: futures_bounded::FuturesMap::new(
112                substream_timeout,
113                max_concurrent_streams,
114            ),
115        }
116    }
117
118    /// Returns the next inbound request ID.
119    fn next_inbound_request_id(&mut self) -> InboundRequestId {
120        InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed))
121    }
122
123    fn on_fully_negotiated_inbound(
124        &mut self,
125        FullyNegotiatedInbound {
126            protocol: (mut stream, protocol),
127            info: (),
128        }: FullyNegotiatedInbound<
129            <Self as ConnectionHandler>::InboundProtocol,
130            <Self as ConnectionHandler>::InboundOpenInfo,
131        >,
132    ) {
133        let mut codec = self.codec.clone();
134        let request_id = self.next_inbound_request_id();
135        let mut sender = self.inbound_sender.clone();
136
137        let recv = async move {
138            // A channel for notifying the inbound upgrade when the
139            // response is sent.
140            let (rs_send, rs_recv) = oneshot::channel();
141
142            let read = codec.read_request(&protocol, &mut stream);
143            let request = read.await?;
144            sender
145                .send((request_id, request, rs_send))
146                .await
147                .expect("`ConnectionHandler` owns both ends of the channel");
148            drop(sender);
149
150            if let Ok(response) = rs_recv.await {
151                let write = codec.write_response(&protocol, &mut stream, response);
152                write.await?;
153
154                stream.close().await?;
155                Ok(Event::ResponseSent(request_id))
156            } else {
157                stream.close().await?;
158                Ok(Event::ResponseOmission(request_id))
159            }
160        };
161
162        // Inbound connections are reported to the upper layer from within the above task,
163        // so by failing to schedule it, it means the upper layer will never know about the
164        // inbound request. Because of that we do not report any inbound failure.
165        if self
166            .worker_streams
167            .try_push(RequestId::Inbound(request_id), recv.boxed())
168            .is_err()
169        {
170            tracing::warn!("Dropping inbound stream because we are at capacity")
171        }
172    }
173
174    fn on_fully_negotiated_outbound(
175        &mut self,
176        FullyNegotiatedOutbound {
177            protocol: (mut stream, protocol),
178            info: (),
179        }: FullyNegotiatedOutbound<
180            <Self as ConnectionHandler>::OutboundProtocol,
181            <Self as ConnectionHandler>::OutboundOpenInfo,
182        >,
183    ) {
184        let message = self
185            .requested_outbound
186            .pop_front()
187            .expect("negotiated a stream without a pending message");
188
189        let mut codec = self.codec.clone();
190        let request_id = message.request_id;
191
192        let send = async move {
193            let write = codec.write_request(&protocol, &mut stream, message.request);
194            write.await?;
195            stream.close().await?;
196            let read = codec.read_response(&protocol, &mut stream);
197            let response = read.await?;
198
199            Ok(Event::Response {
200                request_id,
201                response,
202            })
203        };
204
205        if self
206            .worker_streams
207            .try_push(RequestId::Outbound(request_id), send.boxed())
208            .is_err()
209        {
210            self.pending_events.push_back(Event::OutboundStreamFailed {
211                request_id: message.request_id,
212                error: io::Error::new(io::ErrorKind::Other, "max sub-streams reached"),
213            });
214        }
215    }
216
217    fn on_dial_upgrade_error(
218        &mut self,
219        DialUpgradeError { error, info: () }: DialUpgradeError<
220            <Self as ConnectionHandler>::OutboundOpenInfo,
221            <Self as ConnectionHandler>::OutboundProtocol,
222        >,
223    ) {
224        let message = self
225            .requested_outbound
226            .pop_front()
227            .expect("negotiated a stream without a pending message");
228
229        match error {
230            StreamUpgradeError::Timeout => {
231                self.pending_events
232                    .push_back(Event::OutboundTimeout(message.request_id));
233            }
234            StreamUpgradeError::NegotiationFailed => {
235                // The remote merely doesn't support the protocol(s) we requested.
236                // This is no reason to close the connection, which may
237                // successfully communicate with other protocols already.
238                // An event is reported to permit user code to react to the fact that
239                // the remote peer does not support the requested protocol(s).
240                self.pending_events
241                    .push_back(Event::OutboundUnsupportedProtocols(message.request_id));
242            }
243            StreamUpgradeError::Apply(e) => void::unreachable(e),
244            StreamUpgradeError::Io(e) => {
245                self.pending_events.push_back(Event::OutboundStreamFailed {
246                    request_id: message.request_id,
247                    error: e,
248                });
249            }
250        }
251    }
252    fn on_listen_upgrade_error(
253        &mut self,
254        ListenUpgradeError { error, .. }: ListenUpgradeError<
255            <Self as ConnectionHandler>::InboundOpenInfo,
256            <Self as ConnectionHandler>::InboundProtocol,
257        >,
258    ) {
259        void::unreachable(error)
260    }
261}
262
263/// The events emitted by the [`Handler`].
264pub enum Event<TCodec>
265where
266    TCodec: Codec,
267{
268    /// A request has been received.
269    Request {
270        request_id: InboundRequestId,
271        request: TCodec::Request,
272        sender: oneshot::Sender<TCodec::Response>,
273    },
274    /// A response has been received.
275    Response {
276        request_id: OutboundRequestId,
277        response: TCodec::Response,
278    },
279    /// A response to an inbound request has been sent.
280    ResponseSent(InboundRequestId),
281    /// A response to an inbound request was omitted as a result
282    /// of dropping the response `sender` of an inbound `Request`.
283    ResponseOmission(InboundRequestId),
284    /// An outbound request timed out while sending the request
285    /// or waiting for the response.
286    OutboundTimeout(OutboundRequestId),
287    /// An outbound request failed to negotiate a mutually supported protocol.
288    OutboundUnsupportedProtocols(OutboundRequestId),
289    OutboundStreamFailed {
290        request_id: OutboundRequestId,
291        error: io::Error,
292    },
293    /// An inbound request timed out while waiting for the request
294    /// or sending the response.
295    InboundTimeout(InboundRequestId),
296    InboundStreamFailed {
297        request_id: InboundRequestId,
298        error: io::Error,
299    },
300}
301
302impl<TCodec: Codec> fmt::Debug for Event<TCodec> {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        match self {
305            Event::Request {
306                request_id,
307                request: _,
308                sender: _,
309            } => f
310                .debug_struct("Event::Request")
311                .field("request_id", request_id)
312                .finish(),
313            Event::Response {
314                request_id,
315                response: _,
316            } => f
317                .debug_struct("Event::Response")
318                .field("request_id", request_id)
319                .finish(),
320            Event::ResponseSent(request_id) => f
321                .debug_tuple("Event::ResponseSent")
322                .field(request_id)
323                .finish(),
324            Event::ResponseOmission(request_id) => f
325                .debug_tuple("Event::ResponseOmission")
326                .field(request_id)
327                .finish(),
328            Event::OutboundTimeout(request_id) => f
329                .debug_tuple("Event::OutboundTimeout")
330                .field(request_id)
331                .finish(),
332            Event::OutboundUnsupportedProtocols(request_id) => f
333                .debug_tuple("Event::OutboundUnsupportedProtocols")
334                .field(request_id)
335                .finish(),
336            Event::OutboundStreamFailed { request_id, error } => f
337                .debug_struct("Event::OutboundStreamFailed")
338                .field("request_id", &request_id)
339                .field("error", &error)
340                .finish(),
341            Event::InboundTimeout(request_id) => f
342                .debug_tuple("Event::InboundTimeout")
343                .field(request_id)
344                .finish(),
345            Event::InboundStreamFailed { request_id, error } => f
346                .debug_struct("Event::InboundStreamFailed")
347                .field("request_id", &request_id)
348                .field("error", &error)
349                .finish(),
350        }
351    }
352}
353
354pub struct OutboundMessage<TCodec: Codec> {
355    pub(crate) request_id: OutboundRequestId,
356    pub(crate) request: TCodec::Request,
357    pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>,
358}
359
360impl<TCodec> fmt::Debug for OutboundMessage<TCodec>
361where
362    TCodec: Codec,
363{
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        f.debug_struct("OutboundMessage").finish_non_exhaustive()
366    }
367}
368
369impl<TCodec> ConnectionHandler for Handler<TCodec>
370where
371    TCodec: Codec + Send + Clone + 'static,
372{
373    type FromBehaviour = OutboundMessage<TCodec>;
374    type ToBehaviour = Event<TCodec>;
375    type InboundProtocol = Protocol<TCodec::Protocol>;
376    type OutboundProtocol = Protocol<TCodec::Protocol>;
377    type OutboundOpenInfo = ();
378    type InboundOpenInfo = ();
379
380    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
381        SubstreamProtocol::new(
382            Protocol {
383                protocols: self.inbound_protocols.clone(),
384            },
385            (),
386        )
387    }
388
389    fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
390        self.pending_outbound.push_back(request);
391    }
392
393    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
394    fn poll(
395        &mut self,
396        cx: &mut Context<'_>,
397    ) -> Poll<ConnectionHandlerEvent<Protocol<TCodec::Protocol>, (), Self::ToBehaviour>> {
398        match self.worker_streams.poll_unpin(cx) {
399            Poll::Ready((_, Ok(Ok(event)))) => {
400                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
401            }
402            Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
403                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
404                    Event::InboundStreamFailed {
405                        request_id: id,
406                        error: e,
407                    },
408                ));
409            }
410            Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
411                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
412                    Event::OutboundStreamFailed {
413                        request_id: id,
414                        error: e,
415                    },
416                ));
417            }
418            Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
419                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
420                    Event::InboundTimeout(id),
421                ));
422            }
423            Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
424                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
425                    Event::OutboundTimeout(id),
426                ));
427            }
428            Poll::Pending => {}
429        }
430
431        // Drain pending events that were produced by `worker_streams`.
432        if let Some(event) = self.pending_events.pop_front() {
433            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
434        } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
435            self.pending_events.shrink_to_fit();
436        }
437
438        // Check for inbound requests.
439        if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) {
440            // We received an inbound request.
441
442            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
443                request_id: id,
444                request: rq,
445                sender: rs_sender,
446            }));
447        }
448
449        // Emit outbound requests.
450        if let Some(request) = self.pending_outbound.pop_front() {
451            let protocols = request.protocols.clone();
452            self.requested_outbound.push_back(request);
453
454            return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
455                protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
456            });
457        }
458
459        debug_assert!(self.pending_outbound.is_empty());
460
461        if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
462            self.pending_outbound.shrink_to_fit();
463        }
464
465        Poll::Pending
466    }
467
468    fn on_connection_event(
469        &mut self,
470        event: ConnectionEvent<
471            Self::InboundProtocol,
472            Self::OutboundProtocol,
473            Self::InboundOpenInfo,
474            Self::OutboundOpenInfo,
475        >,
476    ) {
477        match event {
478            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
479                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
480            }
481            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
482                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
483            }
484            ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
485                self.on_dial_upgrade_error(dial_upgrade_error)
486            }
487            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
488                self.on_listen_upgrade_error(listen_upgrade_error)
489            }
490            _ => {}
491        }
492    }
493}