1use libp2p_core::{ConnectedPoint, Endpoint, Multiaddr};
22use libp2p_identity::PeerId;
23use libp2p_swarm::{
24 behaviour::{ConnectionEstablished, DialFailure, ListenFailure},
25 dummy, ConnectionClosed, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
26 THandlerInEvent, THandlerOutEvent, ToSwarm,
27};
28use std::collections::{HashMap, HashSet};
29use std::fmt;
30use std::task::{Context, Poll};
31use void::Void;
32
33pub struct Behaviour {
62 limits: ConnectionLimits,
63
64 pending_inbound_connections: HashSet<ConnectionId>,
65 pending_outbound_connections: HashSet<ConnectionId>,
66 established_inbound_connections: HashSet<ConnectionId>,
67 established_outbound_connections: HashSet<ConnectionId>,
68 established_per_peer: HashMap<PeerId, HashSet<ConnectionId>>,
69}
70
71impl Behaviour {
72 pub fn new(limits: ConnectionLimits) -> Self {
73 Self {
74 limits,
75 pending_inbound_connections: Default::default(),
76 pending_outbound_connections: Default::default(),
77 established_inbound_connections: Default::default(),
78 established_outbound_connections: Default::default(),
79 established_per_peer: Default::default(),
80 }
81 }
82
83 pub fn limits_mut(&mut self) -> &mut ConnectionLimits {
86 &mut self.limits
87 }
88}
89
90fn check_limit(limit: Option<u32>, current: usize, kind: Kind) -> Result<(), ConnectionDenied> {
91 let limit = limit.unwrap_or(u32::MAX);
92 let current = current as u32;
93
94 if current >= limit {
95 return Err(ConnectionDenied::new(Exceeded { limit, kind }));
96 }
97
98 Ok(())
99}
100
101#[derive(Debug, Clone, Copy)]
103pub struct Exceeded {
104 limit: u32,
105 kind: Kind,
106}
107
108impl Exceeded {
109 pub fn limit(&self) -> u32 {
110 self.limit
111 }
112}
113
114impl fmt::Display for Exceeded {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 write!(
117 f,
118 "connection limit exceeded: at most {} {} are allowed",
119 self.limit, self.kind
120 )
121 }
122}
123
/// Identifies which connection counter an [`Exceeded`] error refers to.
#[derive(Debug, Clone, Copy)]
enum Kind {
    PendingIncoming,
    PendingOutgoing,
    EstablishedIncoming,
    EstablishedOutgoing,
    EstablishedPerPeer,
    EstablishedTotal,
}

impl fmt::Display for Kind {
    /// Writes a plural, human-readable name of the counter (used inside `Exceeded`'s message).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::PendingIncoming => "pending incoming connections",
            Self::PendingOutgoing => "pending outgoing connections",
            Self::EstablishedIncoming => "established incoming connections",
            Self::EstablishedOutgoing => "established outgoing connections",
            Self::EstablishedPerPeer => "established connections per peer",
            Self::EstablishedTotal => "established connections",
        };

        f.write_str(description)
    }
}
146
// `Exceeded` wraps no underlying error, so the trait's default methods are used as-is.
impl std::error::Error for Exceeded {}
148
/// Upper bounds on the number of connections, enforced by [`Behaviour`].
///
/// Every limit is optional and defaults to `None`, meaning no limit is enforced for that
/// kind of connection. A limit of `n` admits at most `n` connections of that kind at a
/// time; the next one is denied.
#[derive(Debug, Clone, Default)]
pub struct ConnectionLimits {
    /// Maximum number of inbound connections that are not yet established.
    max_pending_incoming: Option<u32>,
    /// Maximum number of outbound connections (dials) that are not yet established.
    max_pending_outgoing: Option<u32>,
    /// Maximum number of established inbound connections.
    max_established_incoming: Option<u32>,
    /// Maximum number of established outbound connections.
    max_established_outgoing: Option<u32>,
    /// Maximum number of established connections to a single peer.
    max_established_per_peer: Option<u32>,
    /// Maximum number of established connections, inbound and outbound combined.
    max_established_total: Option<u32>,
}

// The builder methods consume `self` and return the updated value, so discarding the
// result would silently drop the configured limit; `#[must_use]` turns that into a warning.
impl ConnectionLimits {
    /// Sets the maximum number of inbound connections that may be pending at the same time.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_pending_incoming(mut self, limit: Option<u32>) -> Self {
        self.max_pending_incoming = limit;
        self
    }

    /// Sets the maximum number of outbound connections (dials) that may be pending at the
    /// same time.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_pending_outgoing(mut self, limit: Option<u32>) -> Self {
        self.max_pending_outgoing = limit;
        self
    }

    /// Sets the maximum number of established inbound connections.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_established_incoming(mut self, limit: Option<u32>) -> Self {
        self.max_established_incoming = limit;
        self
    }

    /// Sets the maximum number of established outbound connections.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_established_outgoing(mut self, limit: Option<u32>) -> Self {
        self.max_established_outgoing = limit;
        self
    }

    /// Sets the maximum number of established connections in total, counting inbound and
    /// outbound connections together.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_established(mut self, limit: Option<u32>) -> Self {
        self.max_established_total = limit;
        self
    }

    /// Sets the maximum number of established connections to any single peer.
    ///
    /// `None` removes the limit.
    #[must_use]
    pub fn with_max_established_per_peer(mut self, limit: Option<u32>) -> Self {
        self.max_established_per_peer = limit;
        self
    }
}
203
204impl NetworkBehaviour for Behaviour {
205 type ConnectionHandler = dummy::ConnectionHandler;
206 type ToSwarm = Void;
207
208 fn handle_pending_inbound_connection(
209 &mut self,
210 connection_id: ConnectionId,
211 _: &Multiaddr,
212 _: &Multiaddr,
213 ) -> Result<(), ConnectionDenied> {
214 check_limit(
215 self.limits.max_pending_incoming,
216 self.pending_inbound_connections.len(),
217 Kind::PendingIncoming,
218 )?;
219
220 self.pending_inbound_connections.insert(connection_id);
221
222 Ok(())
223 }
224
225 fn handle_established_inbound_connection(
226 &mut self,
227 connection_id: ConnectionId,
228 peer: PeerId,
229 _: &Multiaddr,
230 _: &Multiaddr,
231 ) -> Result<THandler<Self>, ConnectionDenied> {
232 self.pending_inbound_connections.remove(&connection_id);
233
234 check_limit(
235 self.limits.max_established_incoming,
236 self.established_inbound_connections.len(),
237 Kind::EstablishedIncoming,
238 )?;
239 check_limit(
240 self.limits.max_established_per_peer,
241 self.established_per_peer
242 .get(&peer)
243 .map(|connections| connections.len())
244 .unwrap_or(0),
245 Kind::EstablishedPerPeer,
246 )?;
247 check_limit(
248 self.limits.max_established_total,
249 self.established_inbound_connections.len()
250 + self.established_outbound_connections.len(),
251 Kind::EstablishedTotal,
252 )?;
253
254 Ok(dummy::ConnectionHandler)
255 }
256
257 fn handle_pending_outbound_connection(
258 &mut self,
259 connection_id: ConnectionId,
260 _: Option<PeerId>,
261 _: &[Multiaddr],
262 _: Endpoint,
263 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
264 check_limit(
265 self.limits.max_pending_outgoing,
266 self.pending_outbound_connections.len(),
267 Kind::PendingOutgoing,
268 )?;
269
270 self.pending_outbound_connections.insert(connection_id);
271
272 Ok(vec![])
273 }
274
275 fn handle_established_outbound_connection(
276 &mut self,
277 connection_id: ConnectionId,
278 peer: PeerId,
279 _: &Multiaddr,
280 _: Endpoint,
281 ) -> Result<THandler<Self>, ConnectionDenied> {
282 self.pending_outbound_connections.remove(&connection_id);
283
284 check_limit(
285 self.limits.max_established_outgoing,
286 self.established_outbound_connections.len(),
287 Kind::EstablishedOutgoing,
288 )?;
289 check_limit(
290 self.limits.max_established_per_peer,
291 self.established_per_peer
292 .get(&peer)
293 .map(|connections| connections.len())
294 .unwrap_or(0),
295 Kind::EstablishedPerPeer,
296 )?;
297 check_limit(
298 self.limits.max_established_total,
299 self.established_inbound_connections.len()
300 + self.established_outbound_connections.len(),
301 Kind::EstablishedTotal,
302 )?;
303
304 Ok(dummy::ConnectionHandler)
305 }
306
307 fn on_swarm_event(&mut self, event: FromSwarm) {
308 match event {
309 FromSwarm::ConnectionClosed(ConnectionClosed {
310 peer_id,
311 connection_id,
312 ..
313 }) => {
314 self.established_inbound_connections.remove(&connection_id);
315 self.established_outbound_connections.remove(&connection_id);
316 self.established_per_peer
317 .entry(peer_id)
318 .or_default()
319 .remove(&connection_id);
320 }
321 FromSwarm::ConnectionEstablished(ConnectionEstablished {
322 peer_id,
323 endpoint,
324 connection_id,
325 ..
326 }) => {
327 match endpoint {
328 ConnectedPoint::Listener { .. } => {
329 self.established_inbound_connections.insert(connection_id);
330 }
331 ConnectedPoint::Dialer { .. } => {
332 self.established_outbound_connections.insert(connection_id);
333 }
334 }
335
336 self.established_per_peer
337 .entry(peer_id)
338 .or_default()
339 .insert(connection_id);
340 }
341 FromSwarm::DialFailure(DialFailure { connection_id, .. }) => {
342 self.pending_outbound_connections.remove(&connection_id);
343 }
344 FromSwarm::ListenFailure(ListenFailure { connection_id, .. }) => {
345 self.pending_inbound_connections.remove(&connection_id);
346 }
347 _ => {}
348 }
349 }
350
351 fn on_connection_handler_event(
352 &mut self,
353 _id: PeerId,
354 _: ConnectionId,
355 event: THandlerOutEvent<Self>,
356 ) {
357 void::unreachable(event)
358 }
359
360 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
361 Poll::Pending
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_swarm::{
        behaviour::toggle::Toggle, dial_opts::DialOpts, dial_opts::PeerCondition, DialError,
        ListenError, Swarm, SwarmEvent,
    };
    use libp2p_swarm_test::SwarmExt;
    use quickcheck::*;

    /// Dials up to `max_pending_outgoing` must succeed. The next dial must be denied with an
    /// [`Exceeded`] cause carrying that limit, and must not be counted as pending.
    #[test]
    fn max_outgoing() {
        use rand::Rng;

        let outgoing_limit = rand::thread_rng().gen_range(1..10);

        let mut network = Swarm::new_ephemeral(|_| {
            Behaviour::new(
                ConnectionLimits::default().with_max_pending_outgoing(Some(outgoing_limit)),
            )
        });

        let addr: Multiaddr = "/memory/1234".parse().unwrap();
        let target = PeerId::random();

        // Fill every pending outbound slot; each of these dials must be accepted.
        for _ in 0..outgoing_limit {
            network
                .dial(
                    DialOpts::peer_id(target)
                        .condition(PeerCondition::Always)
                        .addresses(vec![addr.clone()])
                        .build(),
                )
                .expect("Unexpected connection limit.");
        }

        // One dial beyond the limit must be rejected with the limit as the cause.
        match network
            .dial(
                DialOpts::peer_id(target)
                    .condition(PeerCondition::Always)
                    .addresses(vec![addr])
                    .build(),
            )
            .expect_err("Unexpected dialing success.")
        {
            DialError::Denied { cause } => {
                let exceeded = cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");

                assert_eq!(exceeded.limit(), outgoing_limit);
            }
            e => panic!("Unexpected error: {e:?}"),
        }

        // Only the admitted dials are pending; the denied one was not counted.
        let info = network.network_info();
        assert_eq!(info.num_peers(), 0);
        assert_eq!(
            info.connection_counters().num_pending_outgoing(),
            outgoing_limit
        );
    }

    /// Property: after `limit` inbound connections are established on `swarm1`, the next
    /// inbound connection is denied with an [`Exceeded`] cause carrying `limit`.
    #[test]
    fn max_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });
            let mut swarm2 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });

            async_std::task::block_on(async {
                let (listen_addr, _) = swarm1.listen().with_memory_addr_external().await;

                // Use up all established inbound slots on swarm1.
                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                // One more connection attempt, which swarm1 is expected to deny.
                swarm2.dial(listen_addr).unwrap();

                // NOTE(review): presumably keeps polling swarm2 in the background so its dial
                // progresses while we wait on swarm1 below.
                async_std::task::spawn(swarm2.loop_on_next());

                let cause = swarm1
                    .wait(|event| match event {
                        SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } => Some(cause),
                        _ => None,
                    })
                    .await;

                assert_eq!(cause.downcast::<Exceeded>().unwrap().limit, limit);
            });
        }

        /// Randomly generated limit in `1..10` for the property above.
        #[derive(Debug, Clone)]
        struct Limit(u32);

        impl Arbitrary for Limit {
            fn arbitrary(g: &mut Gen) -> Self {
                Self(g.gen_range(1..10))
            }
        }

        quickcheck(prop as fn(_));
    }

    /// When another behaviour in the same swarm denies an established connection, the
    /// denial cause must be that behaviour's error, and the connection-limits behaviour
    /// must not count the denied connection as established.
    #[test]
    fn support_other_behaviour_denying_connection() {
        let mut swarm1 = Swarm::new_ephemeral(|_| {
            Behaviour::new_with_connection_denier(ConnectionLimits::default())
        });
        let mut swarm2 = Swarm::new_ephemeral(|_| Behaviour::new(ConnectionLimits::default()));

        async_std::task::block_on(async {
            let (listen_addr, _) = swarm1.listen().await;
            swarm2.dial(listen_addr).unwrap();
            async_std::task::spawn(swarm2.loop_on_next());

            let cause = swarm1
                .wait(|event| match event {
                    SwarmEvent::IncomingConnectionError {
                        error: ListenError::Denied { cause },
                        ..
                    } => Some(cause),
                    _ => None,
                })
                .await;

            // The cause is the `io::Error` produced by `ConnectionDenier`, not a limit error.
            cause.downcast::<std::io::Error>().unwrap();

            assert_eq!(
                0,
                swarm1
                    .behaviour_mut()
                    .limits
                    .established_inbound_connections
                    .len(),
                "swarm1 connection limit behaviour to not count denied established connection as established connection"
            )
        });
    }

    /// Test composition: the connection-limits behaviour, optionally combined with a
    /// behaviour that denies every established connection.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour {
        limits: super::Behaviour,
        connection_denier: Toggle<ConnectionDenier>,
    }

    impl Behaviour {
        /// Only the connection-limits behaviour is active.
        fn new(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: None.into(),
            }
        }
        /// Connection limits plus a [`ConnectionDenier`] that rejects every established connection.
        fn new_with_connection_denier(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: Some(ConnectionDenier {}).into(),
            }
        }
    }

    /// Test behaviour that denies every established connection, inbound or outbound, with an
    /// `io::Error`.
    struct ConnectionDenier {}

    impl NetworkBehaviour for ConnectionDenier {
        type ConnectionHandler = dummy::ConnectionHandler;
        type ToSwarm = Void;

        fn handle_established_inbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _local_addr: &Multiaddr,
            _remote_addr: &Multiaddr,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn handle_established_outbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _addr: &Multiaddr,
            _role_override: Endpoint,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn on_swarm_event(&mut self, _event: FromSwarm) {}

        fn on_connection_handler_event(
            &mut self,
            _peer_id: PeerId,
            _connection_id: ConnectionId,
            event: THandlerOutEvent<Self>,
        ) {
            void::unreachable(event)
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
            Poll::Pending
        }
    }
}