// libp2p_kad/handler.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::behaviour::Mode;
22use crate::protocol::{
23    KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
24};
25use crate::record::{self, Record};
26use crate::QueryId;
27use either::Either;
28use futures::channel::oneshot;
29use futures::prelude::*;
30use futures::stream::SelectAll;
31use libp2p_core::{upgrade, ConnectedPoint};
32use libp2p_identity::PeerId;
33use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound};
34use libp2p_swarm::{
35    ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
36    SupportedProtocols,
37};
38use std::collections::VecDeque;
39use std::task::Waker;
40use std::time::Duration;
41use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll};
42
/// Upper bound on concurrently active substreams per connection, applied both
/// to the outbound request set (`FuturesTupleSet` capacity) and to the number
/// of inbound substreams accepted before evicting or dropping.
const MAX_NUM_STREAMS: usize = 32;
44
/// Protocol handler that manages substreams for the Kademlia protocol
/// on a single connection with a peer.
///
/// The handler will automatically open a Kademlia substream with the remote for each request we
/// make.
///
/// It also handles requests made by the remote.
pub struct Handler {
    /// Configuration of the wire protocol.
    protocol_config: ProtocolConfig,

    /// In client mode, we don't accept inbound substreams.
    mode: Mode,

    /// Next unique ID of a connection.
    next_connec_unique_id: UniqueConnecId,

    /// List of active outbound streams.
    ///
    /// Each entry resolves to the (optional) response for the request sent on
    /// that stream, tagged with the [`QueryId`] that originated the request.
    outbound_substreams:
        futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,

    /// Contains one [`oneshot::Sender`] per outbound stream that we have requested.
    pending_streams:
        VecDeque<oneshot::Sender<Result<KadOutStreamSink<Stream>, StreamUpgradeError<io::Error>>>>,

    /// List of outbound substreams that are waiting to become active next.
    /// Contains the request we want to send, and the user data if we expect an answer.
    pending_messages: VecDeque<(KadRequestMsg, QueryId)>,

    /// List of active inbound substreams with the state they are in.
    inbound_substreams: SelectAll<InboundSubstreamState>,

    /// The connected endpoint of the connection that the handler
    /// is associated with.
    endpoint: ConnectedPoint,

    /// The [`PeerId`] of the remote.
    remote_peer_id: PeerId,

    /// The current state of protocol confirmation.
    ///
    /// `None` until we learn — via a negotiated substream or a
    /// `RemoteProtocolsChange` event — whether the remote supports our protocol.
    protocol_status: Option<ProtocolStatus>,

    /// The protocols the remote is known to support, as updated from
    /// `ConnectionEvent::RemoteProtocolsChange` notifications.
    remote_supported_protocols: SupportedProtocols,
}
89
/// The states of protocol confirmation that a connection
/// handler transitions through.
#[derive(Debug, Copy, Clone, PartialEq)]
struct ProtocolStatus {
    /// Whether the remote node supports one of our kademlia protocols.
    supported: bool,
    /// Whether we reported the state to the behaviour.
    ///
    /// `poll` emits `ProtocolConfirmed`/`ProtocolNotSupported` once for an
    /// unreported status and then flips this to `true`.
    reported: bool,
}
99
/// State of an active inbound substream.
enum InboundSubstreamState {
    /// Waiting for a request from the remote.
    WaitingMessage {
        /// Whether it is the first message to be awaited on this stream.
        first: bool,
        // Ties responses from the behaviour back to this substream (see `RequestId`).
        connection_id: UniqueConnecId,
        substream: KadInStreamSink<Stream>,
    },
    /// Waiting for the behaviour to send a [`HandlerIn`] event containing the response.
    WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
    /// Waiting to send an answer back to the remote.
    PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
    /// Waiting to flush an answer back to the remote.
    PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
    /// The substream is being closed.
    Closing(KadInStreamSink<Stream>),
    /// The substream was cancelled in favor of a new one.
    Cancelled,

    /// Transient placeholder installed via `std::mem::replace` while a state
    /// transition is in progress; observing it outside a transition is a bug
    /// (the state machine treats it as `unreachable!`).
    Poisoned {
        phantom: PhantomData<QueryId>,
    },
}
124
125impl InboundSubstreamState {
126    fn try_answer_with(
127        &mut self,
128        id: RequestId,
129        msg: KadResponseMsg,
130    ) -> Result<(), KadResponseMsg> {
131        match std::mem::replace(
132            self,
133            InboundSubstreamState::Poisoned {
134                phantom: PhantomData,
135            },
136        ) {
137            InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
138                if conn_id == id.connec_unique_id =>
139            {
140                *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
141
142                if let Some(waker) = waker.take() {
143                    waker.wake();
144                }
145
146                Ok(())
147            }
148            other => {
149                *self = other;
150
151                Err(msg)
152            }
153        }
154    }
155
156    fn close(&mut self) {
157        match std::mem::replace(
158            self,
159            InboundSubstreamState::Poisoned {
160                phantom: PhantomData,
161            },
162        ) {
163            InboundSubstreamState::WaitingMessage { substream, .. }
164            | InboundSubstreamState::WaitingBehaviour(_, substream, _)
165            | InboundSubstreamState::PendingSend(_, substream, _)
166            | InboundSubstreamState::PendingFlush(_, substream)
167            | InboundSubstreamState::Closing(substream) => {
168                *self = InboundSubstreamState::Closing(substream);
169            }
170            InboundSubstreamState::Cancelled => {
171                *self = InboundSubstreamState::Cancelled;
172            }
173            InboundSubstreamState::Poisoned { .. } => unreachable!(),
174        }
175    }
176}
177
/// Event produced by the Kademlia handler.
///
/// Sent from the handler to the behaviour (it is the handler's
/// [`ConnectionHandler::ToBehaviour`] type).
#[derive(Debug)]
pub enum HandlerEvent {
    /// The configured protocol name has been confirmed by the peer through
    /// a successfully negotiated substream or by learning the supported protocols of the remote.
    ProtocolConfirmed { endpoint: ConnectedPoint },
    /// The configured protocol name(s) are not or no longer supported by the peer on the provided
    /// connection and it should be removed from the routing table.
    ProtocolNotSupported { endpoint: ConnectedPoint },

    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
    /// returned is not specified, but should be around 20.
    FindNodeReq {
        /// The key for which to locate the closest nodes.
        key: Vec<u8>,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to an `HandlerIn::FindNodeReq`.
    FindNodeRes {
        /// Results of the request.
        closer_peers: Vec<KadPeer>,
        /// The user data passed to the `FindNodeReq`.
        query_id: QueryId,
    },

    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
    /// this key.
    GetProvidersReq {
        /// The key for which providers are requested.
        key: record::Key,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to an `HandlerIn::GetProvidersReq`.
    GetProvidersRes {
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// Known providers for this key.
        provider_peers: Vec<KadPeer>,
        /// The user data passed to the `GetProvidersReq`.
        query_id: QueryId,
    },

    /// An error happened when performing a query.
    QueryError {
        /// The error that happened.
        error: HandlerQueryErr,
        /// The user data passed to the query.
        query_id: QueryId,
    },

    /// The peer announced itself as a provider of a key.
    AddProvider {
        /// The key for which the peer is a provider of the associated value.
        key: record::Key,
        /// The peer that is the provider of the value for `key`.
        provider: KadPeer,
    },

    /// Request to get a value from the dht records
    GetRecord {
        /// Key for which we should look in the dht
        key: record::Key,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a `HandlerIn::GetRecord`.
    GetRecordRes {
        /// The result is present if the key has been found
        record: Option<Record>,
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// The user data passed to the `GetValue`.
        query_id: QueryId,
    },

    /// Request to put a value in the dht records
    PutRecord {
        record: Record,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a request to store a record.
    PutRecordRes {
        /// The key of the stored record.
        key: record::Key,
        /// The value of the stored record.
        value: Vec<u8>,
        /// The user data passed to the `PutValue`.
        query_id: QueryId,
    },
}
275
/// Error that can happen when requesting an RPC query.
///
/// Surfaced to the behaviour via [`HandlerEvent::QueryError`].
#[derive(Debug)]
pub enum HandlerQueryErr {
    /// Received an answer that doesn't correspond to the request.
    UnexpectedMessage,
    /// I/O error in the substream.
    Io(io::Error),
}
284
285impl fmt::Display for HandlerQueryErr {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        match self {
288            HandlerQueryErr::UnexpectedMessage => {
289                write!(
290                    f,
291                    "Remote answered our Kademlia RPC query with the wrong message type"
292                )
293            }
294            HandlerQueryErr::Io(err) => {
295                write!(f, "I/O error during a Kademlia RPC query: {err}")
296            }
297        }
298    }
299}
300
301impl error::Error for HandlerQueryErr {
302    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
303        match self {
304            HandlerQueryErr::UnexpectedMessage => None,
305            HandlerQueryErr::Io(err) => Some(err),
306        }
307    }
308}
309
/// Event to send to the handler.
///
/// Sent from the behaviour to the handler (it is the handler's
/// [`ConnectionHandler::FromBehaviour`] type).
#[derive(Debug)]
pub enum HandlerIn {
    /// Resets the (sub)stream associated with the given request ID,
    /// thus signaling an error to the remote.
    ///
    /// Explicitly resetting the (sub)stream associated with a request
    /// can be used as an alternative to letting requests simply time
    /// out on the remote peer, thus potentially avoiding some delay
    /// for the query on the remote.
    Reset(RequestId),

    /// Change the connection to the specified mode.
    ReconfigureMode { new_mode: Mode },

    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
    /// returned is not specified, but should be around 20.
    FindNodeReq {
        /// Identifier of the node.
        key: Vec<u8>,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `FindNodeReq`.
    FindNodeRes {
        /// Results of the request.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
    /// this key.
    GetProvidersReq {
        /// Identifier being searched.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetProvidersReq`.
    GetProvidersRes {
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// Known providers for this key.
        provider_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Indicates that this provider is known for this key.
    ///
    /// The API of the handler doesn't expose any event that allows you to know whether this
    /// succeeded.
    AddProvider {
        /// Key for which we should add providers.
        key: record::Key,
        /// Known provider for this key.
        provider: KadPeer,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Request to retrieve a record from the DHT.
    GetRecord {
        /// The key of the record.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetRecord` request.
    GetRecordRes {
        /// The value that might have been found in our storage.
        record: Option<Record>,
        /// Nodes that are closer to the key we were searching for.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Put a value into the dht records.
    PutRecord {
        record: Record,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `PutRecord`.
    PutRecordRes {
        /// Key of the value that was put.
        key: record::Key,
        /// Value that was put.
        value: Vec<u8>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },
}
413
/// Unique identifier for a request. Must be passed back in order to answer a request from
/// the remote.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct RequestId {
    /// Unique identifier for an incoming connection.
    ///
    /// Matches the response from the behaviour back to the inbound substream
    /// that carries the request (see `InboundSubstreamState::try_answer_with`).
    connec_unique_id: UniqueConnecId,
}
421
/// Unique identifier for a connection.
///
/// Allocated per inbound substream from the handler's monotonically
/// increasing `next_connec_unique_id` counter.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct UniqueConnecId(u64);
425
426impl Handler {
427    pub fn new(
428        protocol_config: ProtocolConfig,
429        endpoint: ConnectedPoint,
430        remote_peer_id: PeerId,
431        mode: Mode,
432    ) -> Self {
433        match &endpoint {
434            ConnectedPoint::Dialer { .. } => {
435                tracing::debug!(
436                    peer=%remote_peer_id,
437                    mode=%mode,
438                    "New outbound connection"
439                );
440            }
441            ConnectedPoint::Listener { .. } => {
442                tracing::debug!(
443                    peer=%remote_peer_id,
444                    mode=%mode,
445                    "New inbound connection"
446                );
447            }
448        }
449
450        Handler {
451            protocol_config,
452            mode,
453            endpoint,
454            remote_peer_id,
455            next_connec_unique_id: UniqueConnecId(0),
456            inbound_substreams: Default::default(),
457            outbound_substreams: futures_bounded::FuturesTupleSet::new(
458                Duration::from_secs(10),
459                MAX_NUM_STREAMS,
460            ),
461            pending_streams: Default::default(),
462            pending_messages: Default::default(),
463            protocol_status: None,
464            remote_supported_protocols: Default::default(),
465        }
466    }
467
468    fn on_fully_negotiated_outbound(
469        &mut self,
470        FullyNegotiatedOutbound {
471            protocol: stream,
472            info: (),
473        }: FullyNegotiatedOutbound<
474            <Self as ConnectionHandler>::OutboundProtocol,
475            <Self as ConnectionHandler>::OutboundOpenInfo,
476        >,
477    ) {
478        if let Some(sender) = self.pending_streams.pop_front() {
479            let _ = sender.send(Ok(stream));
480        }
481
482        if self.protocol_status.is_none() {
483            // Upon the first successfully negotiated substream, we know that the
484            // remote is configured with the same protocol name and we want
485            // the behaviour to add this peer to the routing table, if possible.
486            self.protocol_status = Some(ProtocolStatus {
487                supported: true,
488                reported: false,
489            });
490        }
491    }
492
493    fn on_fully_negotiated_inbound(
494        &mut self,
495        FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
496            <Self as ConnectionHandler>::InboundProtocol,
497            <Self as ConnectionHandler>::InboundOpenInfo,
498        >,
499    ) {
500        // If `self.allow_listening` is false, then we produced a `DeniedUpgrade` and `protocol`
501        // is a `Void`.
502        let protocol = match protocol {
503            future::Either::Left(p) => p,
504            future::Either::Right(p) => void::unreachable(p),
505        };
506
507        if self.protocol_status.is_none() {
508            // Upon the first successfully negotiated substream, we know that the
509            // remote is configured with the same protocol name and we want
510            // the behaviour to add this peer to the routing table, if possible.
511            self.protocol_status = Some(ProtocolStatus {
512                supported: true,
513                reported: false,
514            });
515        }
516
517        if self.inbound_substreams.len() == MAX_NUM_STREAMS {
518            if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
519                matches!(
520                    s,
521                    // An inbound substream waiting to be reused.
522                    InboundSubstreamState::WaitingMessage { first: false, .. }
523                )
524            }) {
525                *s = InboundSubstreamState::Cancelled;
526                tracing::debug!(
527                    peer=?self.remote_peer_id,
528                    "New inbound substream to peer exceeds inbound substream limit. \
529                    Removed older substream waiting to be reused."
530                )
531            } else {
532                tracing::warn!(
533                    peer=?self.remote_peer_id,
534                    "New inbound substream to peer exceeds inbound substream limit. \
535                     No older substream waiting to be reused. Dropping new substream."
536                );
537                return;
538            }
539        }
540
541        let connec_unique_id = self.next_connec_unique_id;
542        self.next_connec_unique_id.0 += 1;
543        self.inbound_substreams
544            .push(InboundSubstreamState::WaitingMessage {
545                first: true,
546                connection_id: connec_unique_id,
547                substream: protocol,
548            });
549    }
550
551    /// Takes the given [`KadRequestMsg`] and composes it into an outbound request-response protocol handshake using a [`oneshot::channel`].
552    fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
553        let (sender, receiver) = oneshot::channel();
554
555        self.pending_streams.push_back(sender);
556        let result = self.outbound_substreams.try_push(
557            async move {
558                let mut stream = receiver
559                    .await
560                    .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
561                    .map_err(|e| match e {
562                        StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
563                        StreamUpgradeError::Apply(e) => e,
564                        StreamUpgradeError::NegotiationFailed => io::Error::new(
565                            io::ErrorKind::ConnectionRefused,
566                            "protocol not supported",
567                        ),
568                        StreamUpgradeError::Io(e) => e,
569                    })?;
570
571                let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
572
573                stream.send(msg).await?;
574                stream.close().await?;
575
576                if !has_answer {
577                    return Ok(None);
578                }
579
580                let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
581
582                Ok(Some(msg))
583            },
584            id,
585        );
586
587        debug_assert!(
588            result.is_ok(),
589            "Expected to not create more streams than allowed"
590        );
591    }
592}
593
impl ConnectionHandler for Handler {
    type FromBehaviour = HandlerIn;
    type ToBehaviour = HandlerEvent;
    // `Either::Right(DeniedUpgrade)` is used to refuse inbound substreams in
    // client mode; see `listen_protocol`.
    type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
    type OutboundProtocol = ProtocolConfig;
    type OutboundOpenInfo = ();
    type InboundOpenInfo = ();

    /// Only accept inbound Kademlia substreams when acting as a server.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
        match self.mode {
            Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
            Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
        }
    }

    /// Translates behaviour events into queued outbound requests, answers to
    /// pending inbound requests, stream resets, or a mode change.
    fn on_behaviour_event(&mut self, message: HandlerIn) {
        match message {
            HandlerIn::Reset(request_id) => {
                // Close the inbound substream that is still waiting for the
                // behaviour to answer `request_id`, signalling an error to the
                // remote instead of letting the request time out there.
                if let Some(state) = self
                    .inbound_substreams
                    .iter_mut()
                    .find(|state| match state {
                        InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
                            conn_id == &request_id.connec_unique_id
                        }
                        _ => false,
                    })
                {
                    state.close();
                }
            }
            // Outbound requests are queued; `poll` opens a stream for each.
            HandlerIn::FindNodeReq { key, query_id } => {
                let msg = KadRequestMsg::FindNode { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::FindNodeRes {
                closer_peers,
                request_id,
            } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
            HandlerIn::GetProvidersReq { key, query_id } => {
                let msg = KadRequestMsg::GetProviders { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetProvidersRes {
                closer_peers,
                provider_peers,
                request_id,
            } => self.answer_pending_request(
                request_id,
                KadResponseMsg::GetProviders {
                    closer_peers,
                    provider_peers,
                },
            ),
            HandlerIn::AddProvider {
                key,
                provider,
                query_id,
            } => {
                let msg = KadRequestMsg::AddProvider { key, provider };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetRecord { key, query_id } => {
                let msg = KadRequestMsg::GetValue { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::PutRecord { record, query_id } => {
                let msg = KadRequestMsg::PutValue { record };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetRecordRes {
                record,
                closer_peers,
                request_id,
            } => {
                self.answer_pending_request(
                    request_id,
                    KadResponseMsg::GetValue {
                        record,
                        closer_peers,
                    },
                );
            }
            HandlerIn::PutRecordRes {
                key,
                request_id,
                value,
            } => {
                self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
            }
            HandlerIn::ReconfigureMode { new_mode } => {
                let peer = self.remote_peer_id;

                match &self.endpoint {
                    ConnectedPoint::Dialer { .. } => {
                        tracing::debug!(
                            %peer,
                            mode=%new_mode,
                            "Changed mode on outbound connection"
                        )
                    }
                    ConnectedPoint::Listener { local_addr, .. } => {
                        tracing::debug!(
                            %peer,
                            mode=%new_mode,
                            local_address=%local_addr,
                            "Changed mode on inbound connection assuming that one of our external addresses routes to the local address")
                    }
                }

                // Takes effect for future substreams via `listen_protocol`.
                self.mode = new_mode;
            }
        }
    }

    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
    > {
        loop {
            // 1. Report a newly determined protocol status to the behaviour,
            //    exactly once per status change.
            match &mut self.protocol_status {
                Some(status) if !status.reported => {
                    status.reported = true;
                    let event = if status.supported {
                        HandlerEvent::ProtocolConfirmed {
                            endpoint: self.endpoint.clone(),
                        }
                    } else {
                        HandlerEvent::ProtocolNotSupported {
                            endpoint: self.endpoint.clone(),
                        }
                    };

                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
                }
                _ => {}
            }

            // 2. Drive the outbound request futures.
            match self.outbound_substreams.poll_unpin(cx) {
                Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        process_kad_response(response, query_id),
                    ))
                }
                // A fire-and-forget request completed; nothing to report.
                Poll::Ready((Ok(Ok(None)), _)) => {
                    continue;
                }
                Poll::Ready((Ok(Err(e)), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        HandlerEvent::QueryError {
                            error: HandlerQueryErr::Io(e),
                            query_id,
                        },
                    ))
                }
                // The future exceeded the timeout configured in `Handler::new`.
                Poll::Ready((Err(_timeout), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        HandlerEvent::QueryError {
                            error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),
                            query_id,
                        },
                    ))
                }
                Poll::Pending => {}
            }

            // 3. Drive the inbound substream state machines; they emit
            //    `ConnectionHandlerEvent`s directly.
            if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
                return Poll::Ready(event);
            }

            // 4. Request a new outbound stream for the next queued message,
            //    if we are below the stream limit.
            if self.outbound_substreams.len() < MAX_NUM_STREAMS {
                if let Some((msg, id)) = self.pending_messages.pop_front() {
                    self.queue_new_stream(id, msg);
                    return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                        protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
                    });
                }
            }

            return Poll::Pending;
        }
    }

    /// Dispatches swarm connection events: stream negotiation results, dial
    /// upgrade errors, and changes to the remote's supported protocols.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    ) {
        match event {
            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
            }
            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
            }
            ConnectionEvent::DialUpgradeError(ev) => {
                // Fail the oldest pending outbound request with the upgrade error.
                if let Some(sender) = self.pending_streams.pop_front() {
                    let _ = sender.send(Err(ev.error));
                }
            }
            ConnectionEvent::RemoteProtocolsChange(change) => {
                let dirty = self.remote_supported_protocols.on_protocols_change(change);

                // Only recompute when the supported-protocol set actually changed.
                if dirty {
                    let remote_supports_our_kademlia_protocols = self
                        .remote_supported_protocols
                        .iter()
                        .any(|p| self.protocol_config.protocol_names().contains(p));

                    self.protocol_status = Some(compute_new_protocol_status(
                        remote_supports_our_kademlia_protocols,
                        self.protocol_status,
                    ))
                }
            }
            _ => {}
        }
    }
}
820
821fn compute_new_protocol_status(
822    now_supported: bool,
823    current_status: Option<ProtocolStatus>,
824) -> ProtocolStatus {
825    let current_status = match current_status {
826        None => {
827            return ProtocolStatus {
828                supported: now_supported,
829                reported: false,
830            }
831        }
832        Some(current) => current,
833    };
834
835    if now_supported == current_status.supported {
836        return ProtocolStatus {
837            supported: now_supported,
838            reported: true,
839        };
840    }
841
842    if now_supported {
843        tracing::debug!("Remote now supports our kademlia protocol");
844    } else {
845        tracing::debug!("Remote no longer supports our kademlia protocol");
846    }
847
848    ProtocolStatus {
849        supported: now_supported,
850        reported: false,
851    }
852}
853
854impl Handler {
855    fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
856        for state in self.inbound_substreams.iter_mut() {
857            match state.try_answer_with(request_id, msg) {
858                Ok(()) => return,
859                Err(m) => {
860                    msg = m;
861                }
862            }
863        }
864
865        debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
866    }
867}
868
impl futures::Stream for InboundSubstreamState {
    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent>;

    /// Drives the inbound-substream state machine: decodes incoming Kademlia
    /// requests into [`HandlerEvent`]s for the behaviour, and writes/flushes
    /// responses queued via the `PendingSend`/`PendingFlush` states.
    /// Yielding `None` ends the stream (the substream is dropped).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            // Move the current state out by value so each arm can consume it.
            // `Poisoned` is only left behind if a transition below panics.
            match std::mem::replace(
                this,
                Self::Poisoned {
                    phantom: PhantomData,
                },
            ) {
                InboundSubstreamState::WaitingMessage {
                    first,
                    connection_id,
                    mut substream,
                } => match substream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
                        // PINGs are never answered; just close the substream.
                        tracing::warn!("Kademlia PING messages are unsupported");

                        *this = InboundSubstreamState::Closing(substream);
                    }
                    // Requests that expect a reply park the substream in
                    // `WaitingBehaviour` until the behaviour answers.
                    Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::FindNodeReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetProvidersReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
                        // `AddProvider` carries no response; go straight back
                        // to waiting for the next message on this substream.
                        *this = InboundSubstreamState::WaitingMessage {
                            first: false,
                            connection_id,
                            substream,
                        };
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::AddProvider { key, provider },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetRecord {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::PutRecord {
                                record,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Pending => {
                        // No message yet: restore the state unchanged
                        // (including `first`) and wait to be polled again.
                        *this = InboundSubstreamState::WaitingMessage {
                            first,
                            connection_id,
                            substream,
                        };
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        // Remote closed the substream: end this stream.
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(Err(e))) => {
                        // Decode/transport error: log at trace level and end.
                        tracing::trace!("Inbound substream error: {:?}", e);
                        return Poll::Ready(None);
                    }
                },
                InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
                    // Still waiting for the behaviour's answer. Re-register the
                    // latest waker (presumably consumed by `try_answer_with`
                    // when the response arrives — confirm against that path).
                    *this = InboundSubstreamState::WaitingBehaviour(
                        id,
                        substream,
                        Some(cx.waker().clone()),
                    );

                    return Poll::Pending;
                }
                InboundSubstreamState::PendingSend(id, mut substream, msg) => {
                    // A response is queued: wait for sink readiness, then
                    // enqueue it and move on to flushing.
                    match substream.poll_ready_unpin(cx) {
                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
                            Ok(()) => {
                                *this = InboundSubstreamState::PendingFlush(id, substream);
                            }
                            Err(_) => return Poll::Ready(None),
                        },
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingSend(id, substream, msg);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                InboundSubstreamState::PendingFlush(id, mut substream) => {
                    // Flush the queued response, then resume reading requests.
                    match substream.poll_flush_unpin(cx) {
                        Poll::Ready(Ok(())) => {
                            *this = InboundSubstreamState::WaitingMessage {
                                first: false,
                                connection_id: id,
                                substream,
                            };
                        }
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingFlush(id, substream);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
                    // Close errors are ignored; either way the stream is done.
                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
                    Poll::Pending => {
                        *this = InboundSubstreamState::Closing(stream);
                        return Poll::Pending;
                    }
                },
                // Only observable if a transition above panicked mid-swap.
                InboundSubstreamState::Poisoned { .. } => unreachable!(),
                InboundSubstreamState::Cancelled => return Poll::Ready(None),
            }
        }
    }
}
1019
1020/// Process a Kademlia message that's supposed to be a response to one of our requests.
1021fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1022    // TODO: must check that the response corresponds to the request
1023    match event {
1024        KadResponseMsg::Pong => {
1025            // We never send out pings.
1026            HandlerEvent::QueryError {
1027                error: HandlerQueryErr::UnexpectedMessage,
1028                query_id,
1029            }
1030        }
1031        KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1032            closer_peers,
1033            query_id,
1034        },
1035        KadResponseMsg::GetProviders {
1036            closer_peers,
1037            provider_peers,
1038        } => HandlerEvent::GetProvidersRes {
1039            closer_peers,
1040            provider_peers,
1041            query_id,
1042        },
1043        KadResponseMsg::GetValue {
1044            record,
1045            closer_peers,
1046        } => HandlerEvent::GetRecordRes {
1047            record,
1048            closer_peers,
1049            query_id,
1050        },
1051        KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1052            key,
1053            value,
1054            query_id,
1055        },
1056    }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use tracing_subscriber::EnvFilter;

    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                supported: bool::arbitrary(g),
                reported: bool::arbitrary(g),
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        // Property: the new status always mirrors `now_supported`, and
        // `reported` is true exactly when a previous status existed with the
        // same support flag.
        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(now_supported, current);

            assert_eq!(new.supported, now_supported);

            match current {
                None => assert!(!new.reported),
                Some(previous) => assert_eq!(new.reported, previous.supported == now_supported),
            }
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}