// libp2p_noise/io/framed.rs

// Copyright 2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! Provides a [`Codec`] type implementing the [`Encoder`] and [`Decoder`] traits.
//!
//! Together with an [`asynchronous_codec::Framed`], this provides a [Sink](futures::Sink)
//! and [Stream](futures::Stream) for length-delimited Noise protocol messages.

26use super::handshake::proto;
27use crate::{protocol::PublicKey, Error};
28use asynchronous_codec::{Decoder, Encoder};
29use bytes::{Buf, Bytes, BytesMut};
30use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
31use std::io;
32use std::mem::size_of;
33
34/// Max. size of a noise message.
35const MAX_NOISE_MSG_LEN: usize = 65535;
36/// Space given to the encryption buffer to hold key material.
37const EXTRA_ENCRYPT_SPACE: usize = 1024;
38/// Max. length for Noise protocol message payloads.
39pub(crate) const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - EXTRA_ENCRYPT_SPACE;
40static_assertions::const_assert! {
41    MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN
42}
43
44/// Codec holds the noise session state `S` and acts as a medium for
45/// encoding and decoding length-delimited session messages.
46pub(crate) struct Codec<S> {
47    session: S,
48
49    // We reuse write and encryption buffers across multiple messages to avoid reallocations.
50    // We cannot reuse read and decryption buffers because we cannot return borrowed data.
51    write_buffer: BytesMut,
52    encrypt_buffer: BytesMut,
53}
54
55impl<S> Codec<S> {
56    pub(crate) fn new(session: S) -> Self {
57        Codec {
58            session,
59            write_buffer: BytesMut::default(),
60            encrypt_buffer: BytesMut::default(),
61        }
62    }
63}
64
65impl Codec<snow::HandshakeState> {
66    /// Checks if the session was started in the `initiator` role.
67    pub(crate) fn is_initiator(&self) -> bool {
68        self.session.is_initiator()
69    }
70
71    /// Checks if the session was started in the `responder` role.
72    pub(crate) fn is_responder(&self) -> bool {
73        !self.session.is_initiator()
74    }
75
76    /// Converts the underlying Noise session from the [`snow::HandshakeState`] to a
77    /// [`snow::TransportState`] once the handshake is complete, including the static
78    /// DH [`PublicKey`] of the remote if received.
79    ///
80    /// If the Noise protocol session state does not permit transitioning to
81    /// transport mode because the handshake is incomplete, an error is returned.
82    ///
83    /// An error is also returned if the remote's static DH key is not present or
84    /// cannot be parsed, as that indicates a fatal handshake error for the noise
85    /// `XX` pattern, which is the only handshake protocol libp2p currently supports.
86    pub(crate) fn into_transport(self) -> Result<(PublicKey, Codec<snow::TransportState>), Error> {
87        let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| {
88            Error::Io(io::Error::new(
89                io::ErrorKind::Other,
90                "expect key to always be present at end of XX session",
91            ))
92        })?;
93
94        let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?;
95        let codec = Codec::new(self.session.into_transport_mode()?);
96
97        Ok((dh_remote_pubkey, codec))
98    }
99}
100
101impl Encoder for Codec<snow::HandshakeState> {
102    type Error = io::Error;
103    type Item<'a> = &'a proto::NoiseHandshakePayload;
104
105    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
106        let item_size = item.get_size();
107
108        self.write_buffer.resize(item_size, 0);
109        let mut writer = Writer::new(&mut self.write_buffer[..item_size]);
110        item.write_message(&mut writer)
111            .expect("Protobuf encoding to succeed");
112
113        encrypt(
114            &self.write_buffer[..item_size],
115            dst,
116            &mut self.encrypt_buffer,
117            |item, buffer| self.session.write_message(item, buffer),
118        )?;
119
120        Ok(())
121    }
122}
123
124impl Decoder for Codec<snow::HandshakeState> {
125    type Error = io::Error;
126    type Item = proto::NoiseHandshakePayload;
127
128    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
129        let cleartext = match decrypt(src, |ciphertext, decrypt_buffer| {
130            self.session.read_message(ciphertext, decrypt_buffer)
131        })? {
132            None => return Ok(None),
133            Some(cleartext) => cleartext,
134        };
135
136        let mut reader = BytesReader::from_bytes(&cleartext[..]);
137        let pb =
138            proto::NoiseHandshakePayload::from_reader(&mut reader, &cleartext).map_err(|_| {
139                io::Error::new(
140                    io::ErrorKind::InvalidData,
141                    "Failed decoding handshake payload",
142                )
143            })?;
144
145        Ok(Some(pb))
146    }
147}
148
149impl Encoder for Codec<snow::TransportState> {
150    type Error = io::Error;
151    type Item<'a> = &'a [u8];
152
153    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
154        encrypt(item, dst, &mut self.encrypt_buffer, |item, buffer| {
155            self.session.write_message(item, buffer)
156        })
157    }
158}
159
160impl Decoder for Codec<snow::TransportState> {
161    type Error = io::Error;
162    type Item = Bytes;
163
164    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
165        decrypt(src, |ciphertext, decrypt_buffer| {
166            self.session.read_message(ciphertext, decrypt_buffer)
167        })
168    }
169}
170
171/// Encrypts the given cleartext to `dst`.
172///
173/// This is a standalone function to allow us reusing the `encrypt_buffer` and to use to across different session states of the noise protocol.
174fn encrypt(
175    cleartext: &[u8],
176    dst: &mut BytesMut,
177    encrypt_buffer: &mut BytesMut,
178    encrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
179) -> io::Result<()> {
180    tracing::trace!("Encrypting {} bytes", cleartext.len());
181
182    encrypt_buffer.resize(cleartext.len() + EXTRA_ENCRYPT_SPACE, 0);
183    let n = encrypt_fn(cleartext, encrypt_buffer).map_err(into_io_error)?;
184
185    tracing::trace!("Outgoing ciphertext has {n} bytes");
186
187    encode_length_prefixed(&encrypt_buffer[..n], dst);
188
189    Ok(())
190}
191
192/// Encrypts the given ciphertext.
193///
194/// This is a standalone function so we can use it across different session states of the noise protocol.
195/// In case `ciphertext` does not contain enough bytes to decrypt the entire frame, `Ok(None)` is returned.
196fn decrypt(
197    ciphertext: &mut BytesMut,
198    decrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
199) -> io::Result<Option<Bytes>> {
200    let ciphertext = match decode_length_prefixed(ciphertext)? {
201        Some(b) => b,
202        None => return Ok(None),
203    };
204
205    tracing::trace!("Incoming ciphertext has {} bytes", ciphertext.len());
206
207    let mut decrypt_buffer = BytesMut::zeroed(ciphertext.len());
208    let n = decrypt_fn(&ciphertext, &mut decrypt_buffer).map_err(into_io_error)?;
209
210    tracing::trace!("Decrypted cleartext has {n} bytes");
211
212    Ok(Some(decrypt_buffer.split_to(n).freeze()))
213}
214
215fn into_io_error(err: snow::Error) -> io::Error {
216    io::Error::new(io::ErrorKind::InvalidData, err)
217}
218
219const U16_LENGTH: usize = size_of::<u16>();
220
221fn encode_length_prefixed(src: &[u8], dst: &mut BytesMut) {
222    dst.reserve(U16_LENGTH + src.len());
223    dst.extend_from_slice(&(src.len() as u16).to_be_bytes());
224    dst.extend_from_slice(src);
225}
226
227fn decode_length_prefixed(src: &mut BytesMut) -> Result<Option<Bytes>, io::Error> {
228    if src.len() < size_of::<u16>() {
229        return Ok(None);
230    }
231
232    let mut len_bytes = [0u8; U16_LENGTH];
233    len_bytes.copy_from_slice(&src[..U16_LENGTH]);
234    let len = u16::from_be_bytes(len_bytes) as usize;
235
236    if src.len() - U16_LENGTH >= len {
237        // Skip the length header we already read.
238        src.advance(U16_LENGTH);
239        Ok(Some(src.split_to(len).freeze()))
240    } else {
241        Ok(None)
242    }
243}