tungstenite/
client.rs

1//! Methods to connect to a WebSocket as a client.
2
3use std::{
4    convert::TryFrom,
5    io::{Read, Write},
6    net::{SocketAddr, TcpStream, ToSocketAddrs},
7    result::Result as StdResult,
8};
9
10use http::{request::Parts, HeaderName, Uri};
11use log::*;
12
13use crate::{
14    handshake::client::{generate_key, Request, Response},
15    protocol::WebSocketConfig,
16    stream::MaybeTlsStream,
17};
18
19use crate::{
20    error::{Error, Result, UrlError},
21    handshake::{client::ClientHandshake, HandshakeError},
22    protocol::WebSocket,
23    stream::{Mode, NoDelay},
24};
25
26/// Connect to the given WebSocket in blocking mode.
27///
28/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
29/// equal to calling `connect()` function.
30///
31/// The URL may be either ws:// or wss://.
32/// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
33/// project's [README][readme] for more information on available features.
34///
35/// This function "just works" for those who wants a simple blocking solution
36/// similar to `std::net::TcpStream`. If you want a non-blocking or other
37/// custom stream, call `client` instead.
38///
39/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
40/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
41/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
42///
43/// [readme]: https://github.com/snapview/tungstenite-rs/#features
44pub fn connect_with_config<Req: IntoClientRequest>(
45    request: Req,
46    config: Option<WebSocketConfig>,
47    max_redirects: u8,
48) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
49    fn try_client_handshake(
50        request: Request,
51        config: Option<WebSocketConfig>,
52    ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
53        let uri = request.uri();
54        let mode = uri_mode(uri)?;
55
56        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
57        if let Mode::Tls = mode {
58            return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
59        }
60
61        let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
62        let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
63        let port = uri.port_u16().unwrap_or(match mode {
64            Mode::Plain => 80,
65            Mode::Tls => 443,
66        });
67        let addrs = (host, port).to_socket_addrs()?;
68        let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
69        NoDelay::set_nodelay(&mut stream, true)?;
70
71        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
72        let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
73        #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
74        let client = crate::tls::client_tls_with_config(request, stream, config, None);
75
76        client.map_err(|e| match e {
77            HandshakeError::Failure(f) => f,
78            HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
79        })
80    }
81
82    fn create_request(parts: &Parts, uri: &Uri) -> Request {
83        let mut builder =
84            Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
85        *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
86        builder.body(()).expect("Failed to create `Request`")
87    }
88
89    let (parts, _) = request.into_client_request()?.into_parts();
90    let mut uri = parts.uri.clone();
91
92    for attempt in 0..(max_redirects + 1) {
93        let request = create_request(&parts, &uri);
94
95        match try_client_handshake(request, config) {
96            Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
97                if let Some(location) = res.headers().get("Location") {
98                    uri = location.to_str()?.parse::<Uri>()?;
99                    debug!("Redirecting to {:?}", uri);
100                    continue;
101                } else {
102                    warn!("No `Location` found in redirect");
103                    return Err(Error::Http(res));
104                }
105            }
106            other => return other,
107        }
108    }
109
110    unreachable!("Bug in a redirect handling logic")
111}
112
113/// Connect to the given WebSocket in blocking mode.
114///
115/// The URL may be either ws:// or wss://.
116/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
117///
118/// This function "just works" for those who wants a simple blocking solution
119/// similar to `std::net::TcpStream`. If you want a non-blocking or other
120/// custom stream, call `client` instead.
121///
122/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
123/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
124/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
125pub fn connect<Req: IntoClientRequest>(
126    request: Req,
127) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
128    connect_with_config(request, None, 3)
129}
130
131fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
132    for addr in addrs {
133        debug!("Trying to contact {} at {}...", uri, addr);
134        if let Ok(stream) = TcpStream::connect(addr) {
135            return Ok(stream);
136        }
137    }
138    Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
139}
140
141/// Get the mode of the given URL.
142///
143/// This function may be used to ease the creation of custom TLS streams
144/// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
145pub fn uri_mode(uri: &Uri) -> Result<Mode> {
146    match uri.scheme_str() {
147        Some("ws") => Ok(Mode::Plain),
148        Some("wss") => Ok(Mode::Tls),
149        _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
150    }
151}
152
153/// Do the client handshake over the given stream given a web socket configuration. Passing `None`
154/// as configuration is equal to calling `client()` function.
155///
156/// Use this function if you need a nonblocking handshake support or if you
157/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
158/// Any stream supporting `Read + Write` will do.
159pub fn client_with_config<Stream, Req>(
160    request: Req,
161    stream: Stream,
162    config: Option<WebSocketConfig>,
163) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
164where
165    Stream: Read + Write,
166    Req: IntoClientRequest,
167{
168    ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
169}
170
171/// Do the client handshake over the given stream.
172///
173/// Use this function if you need a nonblocking handshake support or if you
174/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
175/// Any stream supporting `Read + Write` will do.
176pub fn client<Stream, Req>(
177    request: Req,
178    stream: Stream,
179) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
180where
181    Stream: Read + Write,
182    Req: IntoClientRequest,
183{
184    client_with_config(request, stream, None)
185}
186
187/// Trait for converting various types into HTTP requests used for a client connection.
188///
189/// This trait is implemented by default for string slices, strings, `http::Uri` and
190/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
191/// simply take your request and pass it as is further without altering any headers or URLs, so
192/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
193/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
194/// the proper `http::Request<()>` for you.
195pub trait IntoClientRequest {
196    /// Convert into a `Request` that can be used for a client connection.
197    fn into_client_request(self) -> Result<Request>;
198}
199
200impl<'a> IntoClientRequest for &'a str {
201    fn into_client_request(self) -> Result<Request> {
202        self.parse::<Uri>()?.into_client_request()
203    }
204}
205
206impl<'a> IntoClientRequest for &'a String {
207    fn into_client_request(self) -> Result<Request> {
208        <&str as IntoClientRequest>::into_client_request(self)
209    }
210}
211
212impl IntoClientRequest for String {
213    fn into_client_request(self) -> Result<Request> {
214        <&str as IntoClientRequest>::into_client_request(&self)
215    }
216}
217
218impl<'a> IntoClientRequest for &'a Uri {
219    fn into_client_request(self) -> Result<Request> {
220        self.clone().into_client_request()
221    }
222}
223
224impl IntoClientRequest for Uri {
225    fn into_client_request(self) -> Result<Request> {
226        let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
227        let host = authority
228            .find('@')
229            .map(|idx| authority.split_at(idx + 1).1)
230            .unwrap_or_else(|| authority);
231
232        if host.is_empty() {
233            return Err(Error::Url(UrlError::EmptyHostName));
234        }
235
236        let req = Request::builder()
237            .method("GET")
238            .header("Host", host)
239            .header("Connection", "Upgrade")
240            .header("Upgrade", "websocket")
241            .header("Sec-WebSocket-Version", "13")
242            .header("Sec-WebSocket-Key", generate_key())
243            .uri(self)
244            .body(())?;
245        Ok(req)
246    }
247}
248
#[cfg(feature = "url")]
impl IntoClientRequest for &url::Url {
    /// Delegates to the `&str` implementation using the URL's string form.
    fn into_client_request(self) -> Result<Request> {
        <&str as IntoClientRequest>::into_client_request(self.as_str())
    }
}
255
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    /// Delegates to the `&str` implementation using the URL's string form.
    fn into_client_request(self) -> Result<Request> {
        <&str as IntoClientRequest>::into_client_request(self.as_str())
    }
}
262
263impl IntoClientRequest for Request {
264    fn into_client_request(self) -> Result<Request> {
265        Ok(self)
266    }
267}
268
269impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
270    fn into_client_request(self) -> Result<Request> {
271        use crate::handshake::headers::FromHttparse;
272        Request::from_httparse(self)
273    }
274}
275
276/// Builder for a custom [`IntoClientRequest`] with options to add
277/// custom additional headers and sub protocols.
278///
279/// # Example
280///
281/// ```rust no_run
282/// # use crate::*;
283/// use http::Uri;
284/// use tungstenite::{connect, ClientRequestBuilder};
285///
286/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
287/// let token = "my_jwt_token";
288/// let builder = ClientRequestBuilder::new(uri)
289///     .with_header("Authorization", format!("Bearer {token}"))
290///     .with_sub_protocol("my_sub_protocol");
291/// let socket = connect(builder).unwrap();
292/// ```
293#[derive(Debug, Clone)]
294pub struct ClientRequestBuilder {
295    uri: Uri,
296    /// Additional [`Request`] handshake headers
297    additional_headers: Vec<(String, String)>,
298    /// Handsake subprotocols
299    subprotocols: Vec<String>,
300}
301
302impl ClientRequestBuilder {
303    /// Initializes an empty request builder
304    #[must_use]
305    pub const fn new(uri: Uri) -> Self {
306        Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
307    }
308
309    /// Adds (`key`, `value`) as an additional header to the handshake request
310    pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
311    where
312        K: Into<String>,
313        V: Into<String>,
314    {
315        self.additional_headers.push((key.into(), value.into()));
316        self
317    }
318
319    /// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
320    pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
321    where
322        P: Into<String>,
323    {
324        self.subprotocols.push(protocol.into());
325        self
326    }
327}
328
329impl IntoClientRequest for ClientRequestBuilder {
330    fn into_client_request(self) -> Result<Request> {
331        let mut request = self.uri.into_client_request()?;
332        let headers = request.headers_mut();
333        for (k, v) in self.additional_headers {
334            let key = HeaderName::try_from(k)?;
335            let value = v.parse()?;
336            headers.append(key, value);
337        }
338        if !self.subprotocols.is_empty() {
339            let protocols = self.subprotocols.join(", ").parse()?;
340            headers.append("Sec-WebSocket-Protocol", protocols);
341        }
342        Ok(request)
343    }
344}