libp2p_swarm_derive/
lib.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21#![recursion_limit = "256"]
22#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
23
24mod syn_ext;
25
26use crate::syn_ext::RequireStrLit;
27use heck::ToUpperCamelCase;
28use proc_macro::TokenStream;
29use quote::quote;
30use syn::punctuated::Punctuated;
31use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Meta, Token};
32
33/// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See
34/// the trait documentation for better description.
35#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
36pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
37    let ast = parse_macro_input!(input as DeriveInput);
38    build(&ast).unwrap_or_else(|e| e.to_compile_error().into())
39}
40
41/// The actual implementation.
42fn build(ast: &DeriveInput) -> syn::Result<TokenStream> {
43    match ast.data {
44        Data::Struct(ref s) => build_struct(ast, s),
45        Data::Enum(_) => Err(syn::Error::new_spanned(
46            ast,
47            "Cannot derive `NetworkBehaviour` on enums",
48        )),
49        Data::Union(_) => Err(syn::Error::new_spanned(
50            ast,
51            "Cannot derive `NetworkBehaviour` on union",
52        )),
53    }
54}
55
56/// The version for structs
57fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> syn::Result<TokenStream> {
58    let name = &ast.ident;
59    let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
60    let BehaviourAttributes {
61        prelude_path,
62        user_specified_out_event,
63    } = parse_attributes(ast)?;
64
65    let multiaddr = quote! { #prelude_path::Multiaddr };
66    let trait_to_impl = quote! { #prelude_path::NetworkBehaviour };
67    let either_ident = quote! { #prelude_path::Either };
68    let network_behaviour_action = quote! { #prelude_path::ToSwarm };
69    let connection_handler = quote! { #prelude_path::ConnectionHandler };
70    let proto_select_ident = quote! { #prelude_path::ConnectionHandlerSelect };
71    let peer_id = quote! { #prelude_path::PeerId };
72    let connection_id = quote! { #prelude_path::ConnectionId };
73    let from_swarm = quote! { #prelude_path::FromSwarm };
74    let t_handler = quote! { #prelude_path::THandler };
75    let t_handler_in_event = quote! { #prelude_path::THandlerInEvent };
76    let t_handler_out_event = quote! { #prelude_path::THandlerOutEvent };
77    let endpoint = quote! { #prelude_path::Endpoint };
78    let connection_denied = quote! { #prelude_path::ConnectionDenied };
79
80    // Build the generics.
81    let impl_generics = {
82        let tp = ast.generics.type_params();
83        let lf = ast.generics.lifetimes();
84        let cst = ast.generics.const_params();
85        quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>}
86    };
87
88    let (out_event_name, out_event_definition, out_event_from_clauses) = {
89        // If we find a `#[behaviour(to_swarm = "Foo")]` attribute on the
90        // struct, we set `Foo` as the out event. If not, the `ToSwarm` is
91        // generated.
92        match user_specified_out_event {
93            // User provided `ToSwarm`.
94            Some(name) => {
95                let definition = None;
96                let from_clauses = data_struct
97                    .fields
98                    .iter()
99                    .map(|field| {
100                        let ty = &field.ty;
101                        quote! {#name: From< <#ty as #trait_to_impl>::ToSwarm >}
102                    })
103                    .collect::<Vec<_>>();
104                (name, definition, from_clauses)
105            }
106            // User did not provide `ToSwarm`. Generate it.
107            None => {
108                let enum_name_str = ast.ident.to_string() + "Event";
109                let enum_name: syn::Type =
110                    syn::parse_str(&enum_name_str).expect("ident + `Event` is a valid type");
111                let definition = {
112                    let fields = data_struct.fields.iter().map(|field| {
113                        let variant: syn::Variant = syn::parse_str(
114                            &field
115                                .ident
116                                .clone()
117                                .expect("Fields of NetworkBehaviour implementation to be named.")
118                                .to_string()
119                                .to_upper_camel_case(),
120                        )
121                        .expect("uppercased field name to be a valid enum variant");
122                        let ty = &field.ty;
123                        (variant, ty)
124                    });
125
126                    let enum_variants = fields
127                        .clone()
128                        .map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::ToSwarm)});
129
130                    let visibility = &ast.vis;
131
132                    let additional = fields
133                        .clone()
134                        .map(|(_variant, tp)| quote! { #tp : #trait_to_impl })
135                        .collect::<Vec<_>>();
136
137                    let additional_debug = fields
138                        .clone()
139                        .map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::ToSwarm : ::core::fmt::Debug })
140                        .collect::<Vec<_>>();
141
142                    let where_clause = {
143                        if let Some(where_clause) = where_clause {
144                            if where_clause.predicates.trailing_punct() {
145                                Some(quote! {#where_clause #(#additional),* })
146                            } else {
147                                Some(quote! {#where_clause, #(#additional),*})
148                            }
149                        } else if additional.is_empty() {
150                            None
151                        } else {
152                            Some(quote! {where #(#additional),*})
153                        }
154                    };
155
156                    let where_clause_debug = where_clause
157                        .as_ref()
158                        .map(|where_clause| quote! {#where_clause, #(#additional_debug),*});
159
160                    let match_variants = fields.map(|(variant, _ty)| variant);
161                    let msg = format!("`NetworkBehaviour::ToSwarm` produced by {name}.");
162
163                    Some(quote! {
164                        #[doc = #msg]
165                        #visibility enum #enum_name #impl_generics
166                            #where_clause
167                        {
168                            #(#enum_variants),*
169                        }
170
171                        impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug {
172                            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
173                                match &self {
174                                    #(#enum_name::#match_variants(event) => {
175                                        write!(f, "{}: {:?}", #enum_name_str, event)
176                                    }),*
177                                }
178                            }
179                        }
180                    })
181                };
182                let from_clauses = vec![];
183                (enum_name, definition, from_clauses)
184            }
185        }
186    };
187
188    // Build the `where ...` clause of the trait implementation.
189    let where_clause = {
190        let additional = data_struct
191            .fields
192            .iter()
193            .map(|field| {
194                let ty = &field.ty;
195                quote! {#ty: #trait_to_impl}
196            })
197            .chain(out_event_from_clauses)
198            .collect::<Vec<_>>();
199
200        if let Some(where_clause) = where_clause {
201            if where_clause.predicates.trailing_punct() {
202                Some(quote! {#where_clause #(#additional),* })
203            } else {
204                Some(quote! {#where_clause, #(#additional),*})
205            }
206        } else {
207            Some(quote! {where #(#additional),*})
208        }
209    };
210
211    // Build the list of statements to put in the body of `on_swarm_event()`.
212    let on_swarm_event_stmts = {
213        data_struct
214            .fields
215            .iter()
216            .enumerate()
217            .map(|(field_n, field)| match field.ident {
218                Some(ref i) => quote! {
219                    self.#i.on_swarm_event(event);
220                },
221                None => quote! {
222                    self.#field_n.on_swarm_event(event);
223                },
224            })
225    };
226
227    // Build the list of variants to put in the body of `on_connection_handler_event()`.
228    //
229    // The event type is a construction of nested `#either_ident`s of the events of the children.
230    // We call `on_connection_handler_event` on the corresponding child.
231    let on_node_event_stmts =
232        data_struct
233            .fields
234            .iter()
235            .enumerate()
236            .enumerate()
237            .map(|(enum_n, (field_n, field))| {
238                let mut elem = if enum_n != 0 {
239                    quote! { #either_ident::Right(ev) }
240                } else {
241                    quote! { ev }
242                };
243
244                for _ in 0..data_struct.fields.len() - 1 - enum_n {
245                    elem = quote! { #either_ident::Left(#elem) };
246                }
247
248                Some(match field.ident {
249                    Some(ref i) => quote! { #elem => {
250                    #trait_to_impl::on_connection_handler_event(&mut self.#i, peer_id, connection_id, ev) }},
251                    None => quote! { #elem => {
252                    #trait_to_impl::on_connection_handler_event(&mut self.#field_n, peer_id, connection_id, ev) }},
253                })
254            });
255
256    // The [`ConnectionHandler`] associated type.
257    let connection_handler_ty = {
258        let mut ph_ty = None;
259        for field in data_struct.fields.iter() {
260            let ty = &field.ty;
261            let field_info = quote! { #t_handler<#ty> };
262            match ph_ty {
263                Some(ev) => ph_ty = Some(quote! { #proto_select_ident<#ev, #field_info> }),
264                ref mut ev @ None => *ev = Some(field_info),
265            }
266        }
267        // ph_ty = Some(quote! )
268        ph_ty.unwrap_or(quote! {()}) // TODO: `!` instead
269    };
270
271    // The content of `handle_pending_inbound_connection`.
272    let handle_pending_inbound_connection_stmts =
273        data_struct
274            .fields
275            .iter()
276            .enumerate()
277            .map(|(field_n, field)| {
278                match field.ident {
279                    Some(ref i) => quote! {
280                        #trait_to_impl::handle_pending_inbound_connection(&mut self.#i, connection_id, local_addr, remote_addr)?;
281                    },
282                    None => quote! {
283                        #trait_to_impl::handle_pending_inbound_connection(&mut self.#field_n, connection_id, local_addr, remote_addr)?;
284                    }
285                }
286            });
287
288    // The content of `handle_established_inbound_connection`.
289    let handle_established_inbound_connection = {
290        let mut out_handler = None;
291
292        for (field_n, field) in data_struct.fields.iter().enumerate() {
293            let field_name = match field.ident {
294                Some(ref i) => quote! { self.#i },
295                None => quote! { self.#field_n },
296            };
297
298            let builder = quote! {
299                #field_name.handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)?
300            };
301
302            match out_handler {
303                Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
304                ref mut h @ None => *h = Some(builder),
305            }
306        }
307
308        out_handler.unwrap_or(quote! {()}) // TODO: See test `empty`.
309    };
310
311    // The content of `handle_pending_outbound_connection`.
312    let handle_pending_outbound_connection = {
313        let extend_stmts =
314            data_struct
315                .fields
316                .iter()
317                .enumerate()
318                .map(|(field_n, field)| {
319                    match field.ident {
320                        Some(ref i) => quote! {
321                            combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#i, connection_id, maybe_peer, addresses, effective_role)?);
322                        },
323                        None => quote! {
324                            combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#field_n, connection_id, maybe_peer, addresses, effective_role)?);
325                        }
326                    }
327                });
328
329        quote! {
330            let mut combined_addresses = vec![];
331
332            #(#extend_stmts)*
333
334            Ok(combined_addresses)
335        }
336    };
337
338    // The content of `handle_established_outbound_connection`.
339    let handle_established_outbound_connection = {
340        let mut out_handler = None;
341
342        for (field_n, field) in data_struct.fields.iter().enumerate() {
343            let field_name = match field.ident {
344                Some(ref i) => quote! { self.#i },
345                None => quote! { self.#field_n },
346            };
347
348            let builder = quote! {
349                #field_name.handle_established_outbound_connection(connection_id, peer, addr, role_override)?
350            };
351
352            match out_handler {
353                Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
354                ref mut h @ None => *h = Some(builder),
355            }
356        }
357
358        out_handler.unwrap_or(quote! {()}) // TODO: See test `empty`.
359    };
360
361    // List of statements to put in `poll()`.
362    //
363    // We poll each child one by one and wrap around the output.
364    let poll_stmts = data_struct
365        .fields
366        .iter()
367        .enumerate()
368        .map(|(field_n, field)| {
369            let field = field
370                .ident
371                .clone()
372                .expect("Fields of NetworkBehaviour implementation to be named.");
373
374            let mut wrapped_event = if field_n != 0 {
375                quote! { #either_ident::Right(event) }
376            } else {
377                quote! { event }
378            };
379            for _ in 0..data_struct.fields.len() - 1 - field_n {
380                wrapped_event = quote! { #either_ident::Left(#wrapped_event) };
381            }
382
383            // If the `NetworkBehaviour`'s `ToSwarm` is generated by the derive macro, wrap the sub
384            // `NetworkBehaviour` `ToSwarm` in the variant of the generated `ToSwarm`. If the
385            // `NetworkBehaviour`'s `ToSwarm` is provided by the user, use the corresponding `From`
386            // implementation.
387            let map_out_event = if out_event_definition.is_some() {
388                let event_variant: syn::Variant =
389                    syn::parse_str(&field.to_string().to_upper_camel_case())
390                        .expect("uppercased field name to be a valid enum variant name");
391                quote! { #out_event_name::#event_variant }
392            } else {
393                quote! { |e| e.into() }
394            };
395
396            let map_in_event = quote! { |event| #wrapped_event };
397
398            quote! {
399                match #trait_to_impl::poll(&mut self.#field, cx) {
400                    std::task::Poll::Ready(e) => return std::task::Poll::Ready(e.map_out(#map_out_event).map_in(#map_in_event)),
401                    std::task::Poll::Pending => {},
402                }
403            }
404        });
405
406    let out_event_reference = if out_event_definition.is_some() {
407        quote! { #out_event_name #ty_generics }
408    } else {
409        quote! { #out_event_name }
410    };
411
412    // Now the magic happens.
413    let final_quote = quote! {
414        #out_event_definition
415
416        impl #impl_generics #trait_to_impl for #name #ty_generics
417        #where_clause
418        {
419            type ConnectionHandler = #connection_handler_ty;
420            type ToSwarm = #out_event_reference;
421
422            #[allow(clippy::needless_question_mark)]
423            fn handle_pending_inbound_connection(
424                &mut self,
425                connection_id: #connection_id,
426                local_addr: &#multiaddr,
427                remote_addr: &#multiaddr,
428            ) -> Result<(), #connection_denied> {
429                #(#handle_pending_inbound_connection_stmts)*
430
431                Ok(())
432            }
433
434            #[allow(clippy::needless_question_mark)]
435            fn handle_established_inbound_connection(
436                &mut self,
437                connection_id: #connection_id,
438                peer: #peer_id,
439                local_addr: &#multiaddr,
440                remote_addr: &#multiaddr,
441            ) -> Result<#t_handler<Self>, #connection_denied> {
442                Ok(#handle_established_inbound_connection)
443            }
444
445            #[allow(clippy::needless_question_mark)]
446            fn handle_pending_outbound_connection(
447                &mut self,
448                connection_id: #connection_id,
449                maybe_peer: Option<#peer_id>,
450                addresses: &[#multiaddr],
451                effective_role: #endpoint,
452            ) -> Result<::std::vec::Vec<#multiaddr>, #connection_denied> {
453                #handle_pending_outbound_connection
454            }
455
456            #[allow(clippy::needless_question_mark)]
457            fn handle_established_outbound_connection(
458                &mut self,
459                connection_id: #connection_id,
460                peer: #peer_id,
461                addr: &#multiaddr,
462                role_override: #endpoint,
463            ) -> Result<#t_handler<Self>, #connection_denied> {
464                Ok(#handle_established_outbound_connection)
465            }
466
467            fn on_connection_handler_event(
468                &mut self,
469                peer_id: #peer_id,
470                connection_id: #connection_id,
471                event: #t_handler_out_event<Self>
472            ) {
473                match event {
474                    #(#on_node_event_stmts),*
475                }
476            }
477
478            fn poll(&mut self, cx: &mut std::task::Context) -> std::task::Poll<#network_behaviour_action<Self::ToSwarm, #t_handler_in_event<Self>>> {
479                #(#poll_stmts)*
480                std::task::Poll::Pending
481            }
482
483            fn on_swarm_event(&mut self, event: #from_swarm) {
484                #(#on_swarm_event_stmts)*
485            }
486        }
487    };
488
489    Ok(final_quote.into())
490}
491
492struct BehaviourAttributes {
493    prelude_path: syn::Path,
494    user_specified_out_event: Option<syn::Type>,
495}
496
497/// Parses the `value` of a key=value pair in the `#[behaviour]` attribute into the requested type.
498fn parse_attributes(ast: &DeriveInput) -> syn::Result<BehaviourAttributes> {
499    let mut attributes = BehaviourAttributes {
500        prelude_path: syn::parse_quote! { ::libp2p::swarm::derive_prelude },
501        user_specified_out_event: None,
502    };
503
504    for attr in ast
505        .attrs
506        .iter()
507        .filter(|attr| attr.path().is_ident("behaviour"))
508    {
509        let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
510
511        for meta in nested {
512            if meta.path().is_ident("prelude") {
513                let value = meta.require_name_value()?.value.require_str_lit()?;
514
515                attributes.prelude_path = syn::parse_str(&value)?;
516
517                continue;
518            }
519
520            if meta.path().is_ident("to_swarm") || meta.path().is_ident("out_event") {
521                let value = meta.require_name_value()?.value.require_str_lit()?;
522
523                attributes.user_specified_out_event = Some(syn::parse_str(&value)?);
524
525                continue;
526            }
527        }
528    }
529
530    Ok(attributes)
531}