1use libp2p_core::{Endpoint, Multiaddr};
65use libp2p_identity::PeerId;
66use libp2p_swarm::{
67 dummy, CloseConnection, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
68 THandlerInEvent, THandlerOutEvent, ToSwarm,
69};
70use std::collections::{HashSet, VecDeque};
71use std::fmt;
72use std::task::{Context, Poll, Waker};
73use void::Void;
74
75#[derive(Default, Debug)]
77pub struct Behaviour<S> {
78 state: S,
79 close_connections: VecDeque<PeerId>,
80 waker: Option<Waker>,
81}
82
83#[derive(Default)]
85pub struct AllowedPeers {
86 peers: HashSet<PeerId>,
87}
88
89#[derive(Default)]
91pub struct BlockedPeers {
92 peers: HashSet<PeerId>,
93}
94
95impl Behaviour<AllowedPeers> {
96 pub fn allow_peer(&mut self, peer: PeerId) {
98 self.state.peers.insert(peer);
99 if let Some(waker) = self.waker.take() {
100 waker.wake()
101 }
102 }
103
104 pub fn disallow_peer(&mut self, peer: PeerId) {
108 self.state.peers.remove(&peer);
109 self.close_connections.push_back(peer);
110 if let Some(waker) = self.waker.take() {
111 waker.wake()
112 }
113 }
114}
115
116impl Behaviour<BlockedPeers> {
117 pub fn block_peer(&mut self, peer: PeerId) {
121 self.state.peers.insert(peer);
122 self.close_connections.push_back(peer);
123 if let Some(waker) = self.waker.take() {
124 waker.wake()
125 }
126 }
127
128 pub fn unblock_peer(&mut self, peer: PeerId) {
130 self.state.peers.remove(&peer);
131 if let Some(waker) = self.waker.take() {
132 waker.wake()
133 }
134 }
135}
136
137#[derive(Debug)]
139pub struct NotAllowed {
140 peer: PeerId,
141}
142
143impl fmt::Display for NotAllowed {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 write!(f, "peer {} is not in the allow list", self.peer)
146 }
147}
148
149impl std::error::Error for NotAllowed {}
150
151#[derive(Debug)]
153pub struct Blocked {
154 peer: PeerId,
155}
156
157impl fmt::Display for Blocked {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 write!(f, "peer {} is in the block list", self.peer)
160 }
161}
162
163impl std::error::Error for Blocked {}
164
/// Connection policy implemented by the list types that parametrise [`Behaviour`].
// The `'static` supertrait lets `Behaviour<S>` satisfy the `'static` requirement of the
// `NetworkBehaviour` impl below, which only bounds `S: Enforce`.
trait Enforce: 'static {
    /// Returns `Ok(())` if a connection with `peer` is permitted, otherwise the
    /// [`ConnectionDenied`] reason.
    fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied>;
}
168
169impl Enforce for AllowedPeers {
170 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
171 if !self.peers.contains(peer) {
172 return Err(ConnectionDenied::new(NotAllowed { peer: *peer }));
173 }
174
175 Ok(())
176 }
177}
178
179impl Enforce for BlockedPeers {
180 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
181 if self.peers.contains(peer) {
182 return Err(ConnectionDenied::new(Blocked { peer: *peer }));
183 }
184
185 Ok(())
186 }
187}
188
189impl<S> NetworkBehaviour for Behaviour<S>
190where
191 S: Enforce,
192{
193 type ConnectionHandler = dummy::ConnectionHandler;
194 type ToSwarm = Void;
195
196 fn handle_established_inbound_connection(
197 &mut self,
198 _: ConnectionId,
199 peer: PeerId,
200 _: &Multiaddr,
201 _: &Multiaddr,
202 ) -> Result<THandler<Self>, ConnectionDenied> {
203 self.state.enforce(&peer)?;
204
205 Ok(dummy::ConnectionHandler)
206 }
207
208 fn handle_pending_outbound_connection(
209 &mut self,
210 _: ConnectionId,
211 peer: Option<PeerId>,
212 _: &[Multiaddr],
213 _: Endpoint,
214 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
215 if let Some(peer) = peer {
216 self.state.enforce(&peer)?;
217 }
218
219 Ok(vec![])
220 }
221
222 fn handle_established_outbound_connection(
223 &mut self,
224 _: ConnectionId,
225 peer: PeerId,
226 _: &Multiaddr,
227 _: Endpoint,
228 ) -> Result<THandler<Self>, ConnectionDenied> {
229 self.state.enforce(&peer)?;
230
231 Ok(dummy::ConnectionHandler)
232 }
233
234 fn on_swarm_event(&mut self, _event: FromSwarm) {}
235
236 fn on_connection_handler_event(
237 &mut self,
238 _id: PeerId,
239 _: ConnectionId,
240 event: THandlerOutEvent<Self>,
241 ) {
242 void::unreachable(event)
243 }
244
245 fn poll(
246 &mut self,
247 cx: &mut Context<'_>,
248 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
249 if let Some(peer) = self.close_connections.pop_front() {
250 return Poll::Ready(ToSwarm::CloseConnection {
251 peer_id: peer,
252 connection: CloseConnection::All,
253 });
254 }
255
256 self.waker = Some(cx.waker().clone());
257 Poll::Pending
258 }
259}
260
/// Integration tests that drive pairs of ephemeral swarms (from `libp2p_swarm_test`) over
/// memory transports and check which connections the allow/block lists admit or close.
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_swarm::{dial_opts::DialOpts, DialError, ListenError, Swarm, SwarmEvent};
    use libp2p_swarm_test::SwarmExt;

    /// Dialing a peer we have blocked makes `Swarm::dial` itself fail with
    /// `DialError::Denied` whose cause is `Blocked`.
    #[async_std::test]
    async fn cannot_dial_blocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        dialer.behaviour_mut().block_peer(*listener.local_peer_id());

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// Blocking and then unblocking a peer leaves it dialable again.
    #[async_std::test]
    async fn can_dial_unblocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        dialer.behaviour_mut().block_peer(*listener.local_peer_id());
        dialer
            .behaviour_mut()
            .unblock_peer(*listener.local_peer_id());

        dial(&mut dialer, &listener).unwrap();
    }

    /// A listener that has blocked the dialer rejects the inbound connection with
    /// `ListenError::Denied` whose cause is `Blocked`.
    #[async_std::test]
    async fn blocked_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        listener.behaviour_mut().block_peer(*dialer.local_peer_id());
        // The dialer's own list is empty, so its dial is accepted locally.
        dial(&mut dialer, &listener).unwrap();
        // Keep the dialer swarm running in the background so the connection attempt proceeds.
        async_std::task::spawn(dialer.loop_on_next());

        let cause = listener
            .wait(|e| match e {
                SwarmEvent::IncomingConnectionError {
                    error: ListenError::Denied { cause },
                    ..
                } => Some(cause),
                _ => None,
            })
            .await;
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// Blocking a peer with an established connection closes that connection on both sides.
    #[async_std::test]
    async fn connections_get_closed_upon_blocked() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
        listener.listen().with_memory_addr_external().await;
        dialer.connect(&mut listener).await;

        dialer.behaviour_mut().block_peer(*listener.local_peer_id());

        // Expect exactly one `ConnectionClosed` event from each swarm.
        let (
            [SwarmEvent::ConnectionClosed {
                peer_id: closed_dialer_peer,
                ..
            }],
            [SwarmEvent::ConnectionClosed {
                peer_id: closed_listener_peer,
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert_eq!(closed_dialer_peer, *listener.local_peer_id());
        assert_eq!(closed_listener_peer, *dialer.local_peer_id());
    }

    /// With an empty allow list every dial is denied with `NotAllowed`. Once the peer is
    /// allowed, the same dial is accepted.
    #[async_std::test]
    async fn cannot_dial_peer_unless_allowed() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<NotAllowed>().is_ok());

        dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
        assert!(dial(&mut dialer, &listener).is_ok());
    }

    /// Allowing and then disallowing a peer makes it undialable again.
    #[async_std::test]
    async fn cannot_dial_disallowed_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
        dialer
            .behaviour_mut()
            .disallow_peer(*listener.local_peer_id());

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<NotAllowed>().is_ok());
    }

    /// Dialing by address only (no peer id) passes the pending-outbound check. The
    /// connection is then denied with `NotAllowed` on both sides once established, because
    /// neither allow list contains the other peer.
    #[async_std::test]
    async fn not_allowed_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        listener.listen().with_memory_addr_external().await;

        dialer
            .dial(
                DialOpts::unknown_peer_id()
                    .address(listener.external_addresses().next().cloned().unwrap())
                    .build(),
            )
            .unwrap();

        // The listener's first event is not inspected; its second must be the denial.
        let (
            [SwarmEvent::OutgoingConnectionError {
                error:
                    DialError::Denied {
                        cause: outgoing_cause,
                    },
                ..
            }],
            [_, SwarmEvent::IncomingConnectionError {
                error:
                    ListenError::Denied {
                        cause: incoming_cause,
                    },
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert!(outgoing_cause.downcast::<NotAllowed>().is_ok());
        assert!(incoming_cause.downcast::<NotAllowed>().is_ok());
    }

    /// Disallowing a connected peer closes the connection on both sides.
    #[async_std::test]
    async fn connections_get_closed_upon_disallow() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
        listener.listen().with_memory_addr_external().await;
        // Both sides must allow each other for the connection to be established.
        dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
        listener.behaviour_mut().allow_peer(*dialer.local_peer_id());

        dialer.connect(&mut listener).await;

        dialer
            .behaviour_mut()
            .disallow_peer(*listener.local_peer_id());
        // Expect exactly one `ConnectionClosed` event from each swarm.
        let (
            [SwarmEvent::ConnectionClosed {
                peer_id: closed_dialer_peer,
                ..
            }],
            [SwarmEvent::ConnectionClosed {
                peer_id: closed_listener_peer,
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert_eq!(closed_dialer_peer, *listener.local_peer_id());
        assert_eq!(closed_listener_peer, *dialer.local_peer_id());
    }

    /// Dials `listener` by its peer id at all of its external addresses and returns the
    /// result of `Swarm::dial`.
    fn dial<S>(
        dialer: &mut Swarm<Behaviour<S>>,
        listener: &Swarm<Behaviour<S>>,
    ) -> Result<(), DialError>
    where
        S: Enforce,
    {
        dialer.dial(
            DialOpts::peer_id(*listener.local_peer_id())
                .addresses(listener.external_addresses().cloned().collect())
                .build(),
        )
    }
}