1use std::{
4 convert::TryFrom,
5 io::{Read, Write},
6 net::{SocketAddr, TcpStream, ToSocketAddrs},
7 result::Result as StdResult,
8};
9
10use http::{request::Parts, HeaderName, Uri};
11use log::*;
12
13use crate::{
14 handshake::client::{generate_key, Request, Response},
15 protocol::WebSocketConfig,
16 stream::MaybeTlsStream,
17};
18
19use crate::{
20 error::{Error, Result, UrlError},
21 handshake::{client::ClientHandshake, HandshakeError},
22 protocol::WebSocket,
23 stream::{Mode, NoDelay},
24};
25
26pub fn connect_with_config<Req: IntoClientRequest>(
45 request: Req,
46 config: Option<WebSocketConfig>,
47 max_redirects: u8,
48) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
49 fn try_client_handshake(
50 request: Request,
51 config: Option<WebSocketConfig>,
52 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
53 let uri = request.uri();
54 let mode = uri_mode(uri)?;
55
56 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
57 if let Mode::Tls = mode {
58 return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
59 }
60
61 let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
62 let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
63 let port = uri.port_u16().unwrap_or(match mode {
64 Mode::Plain => 80,
65 Mode::Tls => 443,
66 });
67 let addrs = (host, port).to_socket_addrs()?;
68 let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
69 NoDelay::set_nodelay(&mut stream, true)?;
70
71 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
72 let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
73 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
74 let client = crate::tls::client_tls_with_config(request, stream, config, None);
75
76 client.map_err(|e| match e {
77 HandshakeError::Failure(f) => f,
78 HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
79 })
80 }
81
82 fn create_request(parts: &Parts, uri: &Uri) -> Request {
83 let mut builder =
84 Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
85 *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
86 builder.body(()).expect("Failed to create `Request`")
87 }
88
89 let (parts, _) = request.into_client_request()?.into_parts();
90 let mut uri = parts.uri.clone();
91
92 for attempt in 0..(max_redirects + 1) {
93 let request = create_request(&parts, &uri);
94
95 match try_client_handshake(request, config) {
96 Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
97 if let Some(location) = res.headers().get("Location") {
98 uri = location.to_str()?.parse::<Uri>()?;
99 debug!("Redirecting to {:?}", uri);
100 continue;
101 } else {
102 warn!("No `Location` found in redirect");
103 return Err(Error::Http(res));
104 }
105 }
106 other => return other,
107 }
108 }
109
110 unreachable!("Bug in a redirect handling logic")
111}
112
113pub fn connect<Req: IntoClientRequest>(
126 request: Req,
127) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
128 connect_with_config(request, None, 3)
129}
130
131fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
132 for addr in addrs {
133 debug!("Trying to contact {} at {}...", uri, addr);
134 if let Ok(stream) = TcpStream::connect(addr) {
135 return Ok(stream);
136 }
137 }
138 Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
139}
140
141pub fn uri_mode(uri: &Uri) -> Result<Mode> {
146 match uri.scheme_str() {
147 Some("ws") => Ok(Mode::Plain),
148 Some("wss") => Ok(Mode::Tls),
149 _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
150 }
151}
152
153pub fn client_with_config<Stream, Req>(
160 request: Req,
161 stream: Stream,
162 config: Option<WebSocketConfig>,
163) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
164where
165 Stream: Read + Write,
166 Req: IntoClientRequest,
167{
168 ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
169}
170
171pub fn client<Stream, Req>(
177 request: Req,
178 stream: Stream,
179) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
180where
181 Stream: Read + Write,
182 Req: IntoClientRequest,
183{
184 client_with_config(request, stream, None)
185}
186
187pub trait IntoClientRequest {
196 fn into_client_request(self) -> Result<Request>;
198}
199
200impl<'a> IntoClientRequest for &'a str {
201 fn into_client_request(self) -> Result<Request> {
202 self.parse::<Uri>()?.into_client_request()
203 }
204}
205
206impl<'a> IntoClientRequest for &'a String {
207 fn into_client_request(self) -> Result<Request> {
208 <&str as IntoClientRequest>::into_client_request(self)
209 }
210}
211
212impl IntoClientRequest for String {
213 fn into_client_request(self) -> Result<Request> {
214 <&str as IntoClientRequest>::into_client_request(&self)
215 }
216}
217
218impl<'a> IntoClientRequest for &'a Uri {
219 fn into_client_request(self) -> Result<Request> {
220 self.clone().into_client_request()
221 }
222}
223
224impl IntoClientRequest for Uri {
225 fn into_client_request(self) -> Result<Request> {
226 let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
227 let host = authority
228 .find('@')
229 .map(|idx| authority.split_at(idx + 1).1)
230 .unwrap_or_else(|| authority);
231
232 if host.is_empty() {
233 return Err(Error::Url(UrlError::EmptyHostName));
234 }
235
236 let req = Request::builder()
237 .method("GET")
238 .header("Host", host)
239 .header("Connection", "Upgrade")
240 .header("Upgrade", "websocket")
241 .header("Sec-WebSocket-Version", "13")
242 .header("Sec-WebSocket-Key", generate_key())
243 .uri(self)
244 .body(())?;
245 Ok(req)
246 }
247}
248
/// A borrowed `url::Url` is converted through its string form.
#[cfg(feature = "url")]
impl<'a> IntoClientRequest for &'a url::Url {
    fn into_client_request(self) -> Result<Request> {
        let text: &str = self.as_str();
        text.into_client_request()
    }
}
255
/// An owned `url::Url` is converted through its string form.
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    fn into_client_request(self) -> Result<Request> {
        <&str as IntoClientRequest>::into_client_request(self.as_str())
    }
}
262
263impl IntoClientRequest for Request {
264 fn into_client_request(self) -> Result<Request> {
265 Ok(self)
266 }
267}
268
269impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
270 fn into_client_request(self) -> Result<Request> {
271 use crate::handshake::headers::FromHttparse;
272 Request::from_httparse(self)
273 }
274}
275
276#[derive(Debug, Clone)]
294pub struct ClientRequestBuilder {
295 uri: Uri,
296 additional_headers: Vec<(String, String)>,
298 subprotocols: Vec<String>,
300}
301
302impl ClientRequestBuilder {
303 #[must_use]
305 pub const fn new(uri: Uri) -> Self {
306 Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
307 }
308
309 pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
311 where
312 K: Into<String>,
313 V: Into<String>,
314 {
315 self.additional_headers.push((key.into(), value.into()));
316 self
317 }
318
319 pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
321 where
322 P: Into<String>,
323 {
324 self.subprotocols.push(protocol.into());
325 self
326 }
327}
328
329impl IntoClientRequest for ClientRequestBuilder {
330 fn into_client_request(self) -> Result<Request> {
331 let mut request = self.uri.into_client_request()?;
332 let headers = request.headers_mut();
333 for (k, v) in self.additional_headers {
334 let key = HeaderName::try_from(k)?;
335 let value = v.parse()?;
336 headers.append(key, value);
337 }
338 if !self.subprotocols.is_empty() {
339 let protocols = self.subprotocols.join(", ").parse()?;
340 headers.append("Sec-WebSocket-Protocol", protocols);
341 }
342 Ok(request)
343 }
344}