1#![recursion_limit = "256"]
22#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
23
24mod syn_ext;
25
26use crate::syn_ext::RequireStrLit;
27use heck::ToUpperCamelCase;
28use proc_macro::TokenStream;
29use quote::quote;
30use syn::punctuated::Punctuated;
31use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Meta, Token};
32
33#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
36pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
37 let ast = parse_macro_input!(input as DeriveInput);
38 build(&ast).unwrap_or_else(|e| e.to_compile_error().into())
39}
40
41fn build(ast: &DeriveInput) -> syn::Result<TokenStream> {
43 match ast.data {
44 Data::Struct(ref s) => build_struct(ast, s),
45 Data::Enum(_) => Err(syn::Error::new_spanned(
46 ast,
47 "Cannot derive `NetworkBehaviour` on enums",
48 )),
49 Data::Union(_) => Err(syn::Error::new_spanned(
50 ast,
51 "Cannot derive `NetworkBehaviour` on union",
52 )),
53 }
54}
55
56fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> syn::Result<TokenStream> {
58 let name = &ast.ident;
59 let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
60 let BehaviourAttributes {
61 prelude_path,
62 user_specified_out_event,
63 } = parse_attributes(ast)?;
64
65 let multiaddr = quote! { #prelude_path::Multiaddr };
66 let trait_to_impl = quote! { #prelude_path::NetworkBehaviour };
67 let either_ident = quote! { #prelude_path::Either };
68 let network_behaviour_action = quote! { #prelude_path::ToSwarm };
69 let connection_handler = quote! { #prelude_path::ConnectionHandler };
70 let proto_select_ident = quote! { #prelude_path::ConnectionHandlerSelect };
71 let peer_id = quote! { #prelude_path::PeerId };
72 let connection_id = quote! { #prelude_path::ConnectionId };
73 let from_swarm = quote! { #prelude_path::FromSwarm };
74 let t_handler = quote! { #prelude_path::THandler };
75 let t_handler_in_event = quote! { #prelude_path::THandlerInEvent };
76 let t_handler_out_event = quote! { #prelude_path::THandlerOutEvent };
77 let endpoint = quote! { #prelude_path::Endpoint };
78 let connection_denied = quote! { #prelude_path::ConnectionDenied };
79
80 let impl_generics = {
82 let tp = ast.generics.type_params();
83 let lf = ast.generics.lifetimes();
84 let cst = ast.generics.const_params();
85 quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>}
86 };
87
88 let (out_event_name, out_event_definition, out_event_from_clauses) = {
89 match user_specified_out_event {
93 Some(name) => {
95 let definition = None;
96 let from_clauses = data_struct
97 .fields
98 .iter()
99 .map(|field| {
100 let ty = &field.ty;
101 quote! {#name: From< <#ty as #trait_to_impl>::ToSwarm >}
102 })
103 .collect::<Vec<_>>();
104 (name, definition, from_clauses)
105 }
106 None => {
108 let enum_name_str = ast.ident.to_string() + "Event";
109 let enum_name: syn::Type =
110 syn::parse_str(&enum_name_str).expect("ident + `Event` is a valid type");
111 let definition = {
112 let fields = data_struct.fields.iter().map(|field| {
113 let variant: syn::Variant = syn::parse_str(
114 &field
115 .ident
116 .clone()
117 .expect("Fields of NetworkBehaviour implementation to be named.")
118 .to_string()
119 .to_upper_camel_case(),
120 )
121 .expect("uppercased field name to be a valid enum variant");
122 let ty = &field.ty;
123 (variant, ty)
124 });
125
126 let enum_variants = fields
127 .clone()
128 .map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::ToSwarm)});
129
130 let visibility = &ast.vis;
131
132 let additional = fields
133 .clone()
134 .map(|(_variant, tp)| quote! { #tp : #trait_to_impl })
135 .collect::<Vec<_>>();
136
137 let additional_debug = fields
138 .clone()
139 .map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::ToSwarm : ::core::fmt::Debug })
140 .collect::<Vec<_>>();
141
142 let where_clause = {
143 if let Some(where_clause) = where_clause {
144 if where_clause.predicates.trailing_punct() {
145 Some(quote! {#where_clause #(#additional),* })
146 } else {
147 Some(quote! {#where_clause, #(#additional),*})
148 }
149 } else if additional.is_empty() {
150 None
151 } else {
152 Some(quote! {where #(#additional),*})
153 }
154 };
155
156 let where_clause_debug = where_clause
157 .as_ref()
158 .map(|where_clause| quote! {#where_clause, #(#additional_debug),*});
159
160 let match_variants = fields.map(|(variant, _ty)| variant);
161 let msg = format!("`NetworkBehaviour::ToSwarm` produced by {name}.");
162
163 Some(quote! {
164 #[doc = #msg]
165 #visibility enum #enum_name #impl_generics
166 #where_clause
167 {
168 #(#enum_variants),*
169 }
170
171 impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
173 match &self {
174 #(#enum_name::#match_variants(event) => {
175 write!(f, "{}: {:?}", #enum_name_str, event)
176 }),*
177 }
178 }
179 }
180 })
181 };
182 let from_clauses = vec![];
183 (enum_name, definition, from_clauses)
184 }
185 }
186 };
187
188 let where_clause = {
190 let additional = data_struct
191 .fields
192 .iter()
193 .map(|field| {
194 let ty = &field.ty;
195 quote! {#ty: #trait_to_impl}
196 })
197 .chain(out_event_from_clauses)
198 .collect::<Vec<_>>();
199
200 if let Some(where_clause) = where_clause {
201 if where_clause.predicates.trailing_punct() {
202 Some(quote! {#where_clause #(#additional),* })
203 } else {
204 Some(quote! {#where_clause, #(#additional),*})
205 }
206 } else {
207 Some(quote! {where #(#additional),*})
208 }
209 };
210
211 let on_swarm_event_stmts = {
213 data_struct
214 .fields
215 .iter()
216 .enumerate()
217 .map(|(field_n, field)| match field.ident {
218 Some(ref i) => quote! {
219 self.#i.on_swarm_event(event);
220 },
221 None => quote! {
222 self.#field_n.on_swarm_event(event);
223 },
224 })
225 };
226
227 let on_node_event_stmts =
232 data_struct
233 .fields
234 .iter()
235 .enumerate()
236 .enumerate()
237 .map(|(enum_n, (field_n, field))| {
238 let mut elem = if enum_n != 0 {
239 quote! { #either_ident::Right(ev) }
240 } else {
241 quote! { ev }
242 };
243
244 for _ in 0..data_struct.fields.len() - 1 - enum_n {
245 elem = quote! { #either_ident::Left(#elem) };
246 }
247
248 Some(match field.ident {
249 Some(ref i) => quote! { #elem => {
250 #trait_to_impl::on_connection_handler_event(&mut self.#i, peer_id, connection_id, ev) }},
251 None => quote! { #elem => {
252 #trait_to_impl::on_connection_handler_event(&mut self.#field_n, peer_id, connection_id, ev) }},
253 })
254 });
255
256 let connection_handler_ty = {
258 let mut ph_ty = None;
259 for field in data_struct.fields.iter() {
260 let ty = &field.ty;
261 let field_info = quote! { #t_handler<#ty> };
262 match ph_ty {
263 Some(ev) => ph_ty = Some(quote! { #proto_select_ident<#ev, #field_info> }),
264 ref mut ev @ None => *ev = Some(field_info),
265 }
266 }
267 ph_ty.unwrap_or(quote! {()}) };
270
271 let handle_pending_inbound_connection_stmts =
273 data_struct
274 .fields
275 .iter()
276 .enumerate()
277 .map(|(field_n, field)| {
278 match field.ident {
279 Some(ref i) => quote! {
280 #trait_to_impl::handle_pending_inbound_connection(&mut self.#i, connection_id, local_addr, remote_addr)?;
281 },
282 None => quote! {
283 #trait_to_impl::handle_pending_inbound_connection(&mut self.#field_n, connection_id, local_addr, remote_addr)?;
284 }
285 }
286 });
287
288 let handle_established_inbound_connection = {
290 let mut out_handler = None;
291
292 for (field_n, field) in data_struct.fields.iter().enumerate() {
293 let field_name = match field.ident {
294 Some(ref i) => quote! { self.#i },
295 None => quote! { self.#field_n },
296 };
297
298 let builder = quote! {
299 #field_name.handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)?
300 };
301
302 match out_handler {
303 Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
304 ref mut h @ None => *h = Some(builder),
305 }
306 }
307
308 out_handler.unwrap_or(quote! {()}) };
310
311 let handle_pending_outbound_connection = {
313 let extend_stmts =
314 data_struct
315 .fields
316 .iter()
317 .enumerate()
318 .map(|(field_n, field)| {
319 match field.ident {
320 Some(ref i) => quote! {
321 combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#i, connection_id, maybe_peer, addresses, effective_role)?);
322 },
323 None => quote! {
324 combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#field_n, connection_id, maybe_peer, addresses, effective_role)?);
325 }
326 }
327 });
328
329 quote! {
330 let mut combined_addresses = vec![];
331
332 #(#extend_stmts)*
333
334 Ok(combined_addresses)
335 }
336 };
337
338 let handle_established_outbound_connection = {
340 let mut out_handler = None;
341
342 for (field_n, field) in data_struct.fields.iter().enumerate() {
343 let field_name = match field.ident {
344 Some(ref i) => quote! { self.#i },
345 None => quote! { self.#field_n },
346 };
347
348 let builder = quote! {
349 #field_name.handle_established_outbound_connection(connection_id, peer, addr, role_override)?
350 };
351
352 match out_handler {
353 Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
354 ref mut h @ None => *h = Some(builder),
355 }
356 }
357
358 out_handler.unwrap_or(quote! {()}) };
360
361 let poll_stmts = data_struct
365 .fields
366 .iter()
367 .enumerate()
368 .map(|(field_n, field)| {
369 let field = field
370 .ident
371 .clone()
372 .expect("Fields of NetworkBehaviour implementation to be named.");
373
374 let mut wrapped_event = if field_n != 0 {
375 quote! { #either_ident::Right(event) }
376 } else {
377 quote! { event }
378 };
379 for _ in 0..data_struct.fields.len() - 1 - field_n {
380 wrapped_event = quote! { #either_ident::Left(#wrapped_event) };
381 }
382
383 let map_out_event = if out_event_definition.is_some() {
388 let event_variant: syn::Variant =
389 syn::parse_str(&field.to_string().to_upper_camel_case())
390 .expect("uppercased field name to be a valid enum variant name");
391 quote! { #out_event_name::#event_variant }
392 } else {
393 quote! { |e| e.into() }
394 };
395
396 let map_in_event = quote! { |event| #wrapped_event };
397
398 quote! {
399 match #trait_to_impl::poll(&mut self.#field, cx) {
400 std::task::Poll::Ready(e) => return std::task::Poll::Ready(e.map_out(#map_out_event).map_in(#map_in_event)),
401 std::task::Poll::Pending => {},
402 }
403 }
404 });
405
406 let out_event_reference = if out_event_definition.is_some() {
407 quote! { #out_event_name #ty_generics }
408 } else {
409 quote! { #out_event_name }
410 };
411
412 let final_quote = quote! {
414 #out_event_definition
415
416 impl #impl_generics #trait_to_impl for #name #ty_generics
417 #where_clause
418 {
419 type ConnectionHandler = #connection_handler_ty;
420 type ToSwarm = #out_event_reference;
421
422 #[allow(clippy::needless_question_mark)]
423 fn handle_pending_inbound_connection(
424 &mut self,
425 connection_id: #connection_id,
426 local_addr: &#multiaddr,
427 remote_addr: &#multiaddr,
428 ) -> Result<(), #connection_denied> {
429 #(#handle_pending_inbound_connection_stmts)*
430
431 Ok(())
432 }
433
434 #[allow(clippy::needless_question_mark)]
435 fn handle_established_inbound_connection(
436 &mut self,
437 connection_id: #connection_id,
438 peer: #peer_id,
439 local_addr: &#multiaddr,
440 remote_addr: &#multiaddr,
441 ) -> Result<#t_handler<Self>, #connection_denied> {
442 Ok(#handle_established_inbound_connection)
443 }
444
445 #[allow(clippy::needless_question_mark)]
446 fn handle_pending_outbound_connection(
447 &mut self,
448 connection_id: #connection_id,
449 maybe_peer: Option<#peer_id>,
450 addresses: &[#multiaddr],
451 effective_role: #endpoint,
452 ) -> Result<::std::vec::Vec<#multiaddr>, #connection_denied> {
453 #handle_pending_outbound_connection
454 }
455
456 #[allow(clippy::needless_question_mark)]
457 fn handle_established_outbound_connection(
458 &mut self,
459 connection_id: #connection_id,
460 peer: #peer_id,
461 addr: &#multiaddr,
462 role_override: #endpoint,
463 ) -> Result<#t_handler<Self>, #connection_denied> {
464 Ok(#handle_established_outbound_connection)
465 }
466
467 fn on_connection_handler_event(
468 &mut self,
469 peer_id: #peer_id,
470 connection_id: #connection_id,
471 event: #t_handler_out_event<Self>
472 ) {
473 match event {
474 #(#on_node_event_stmts),*
475 }
476 }
477
478 fn poll(&mut self, cx: &mut std::task::Context) -> std::task::Poll<#network_behaviour_action<Self::ToSwarm, #t_handler_in_event<Self>>> {
479 #(#poll_stmts)*
480 std::task::Poll::Pending
481 }
482
483 fn on_swarm_event(&mut self, event: #from_swarm) {
484 #(#on_swarm_event_stmts)*
485 }
486 }
487 };
488
489 Ok(final_quote.into())
490}
491
492struct BehaviourAttributes {
493 prelude_path: syn::Path,
494 user_specified_out_event: Option<syn::Type>,
495}
496
497fn parse_attributes(ast: &DeriveInput) -> syn::Result<BehaviourAttributes> {
499 let mut attributes = BehaviourAttributes {
500 prelude_path: syn::parse_quote! { ::libp2p::swarm::derive_prelude },
501 user_specified_out_event: None,
502 };
503
504 for attr in ast
505 .attrs
506 .iter()
507 .filter(|attr| attr.path().is_ident("behaviour"))
508 {
509 let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
510
511 for meta in nested {
512 if meta.path().is_ident("prelude") {
513 let value = meta.require_name_value()?.value.require_str_lit()?;
514
515 attributes.prelude_path = syn::parse_str(&value)?;
516
517 continue;
518 }
519
520 if meta.path().is_ident("to_swarm") || meta.path().is_ident("out_event") {
521 let value = meta.require_name_value()?.value.require_str_lit()?;
522
523 attributes.user_specified_out_event = Some(syn::parse_str(&value)?);
524
525 continue;
526 }
527 }
528 }
529
530 Ok(attributes)
531}