Skip to content

Commit d27b15b

Browse files
authored
Extracts core WebSocket upgrade logic for integration into other libraries (#45)
* Removes NegotiatedExtension * Initial restructure to expose request parsing * Extracts core WebSocket upgrade logic for servers * Rolls back CI changes * Resolves incorrect response building * Resolves PR comments * Resolves PR comments * Fixes incorrect HTTP Error documentation * Updates to new API * Adds more server upgrade functions * Exposes simple function to build a response from request headers * Restructures server handshake parsing functions * Updates incorrect documentation
1 parent 317d373 commit d27b15b

File tree

23 files changed

+936
-560
lines changed

23 files changed

+936
-560
lines changed

ratchet_core/src/builder.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
use crate::errors::Error;
1616
use crate::ext::NoExtProvider;
17-
use crate::handshake::{ProtocolRegistry, UpgradedServer};
17+
use crate::handshake::{SubprotocolRegistry, UpgradedServer};
1818
use crate::{subscribe_with, TryIntoRequest, UpgradedClient, WebSocketConfig, WebSocketStream};
1919
use ratchet_ext::ExtensionProvider;
20-
use std::borrow::Cow;
2120

2221
/// A builder to construct WebSocket clients.
2322
///
@@ -28,15 +27,15 @@ use std::borrow::Cow;
2827
pub struct WebSocketClientBuilder<E> {
2928
config: Option<WebSocketConfig>,
3029
extension: E,
31-
subprotocols: ProtocolRegistry,
30+
subprotocols: SubprotocolRegistry,
3231
}
3332

3433
impl Default for WebSocketClientBuilder<NoExtProvider> {
3534
fn default() -> Self {
3635
WebSocketClientBuilder {
3736
config: None,
3837
extension: NoExtProvider,
39-
subprotocols: ProtocolRegistry::default(),
38+
subprotocols: SubprotocolRegistry::default(),
4039
}
4140
}
4241
}
@@ -95,9 +94,9 @@ impl<E> WebSocketClientBuilder<E> {
9594
pub fn subprotocols<I>(mut self, subprotocols: I) -> Result<Self, Error>
9695
where
9796
I: IntoIterator,
98-
I::Item: Into<Cow<'static, str>>,
97+
I::Item: Into<String>,
9998
{
100-
self.subprotocols = ProtocolRegistry::new(subprotocols)?;
99+
self.subprotocols = SubprotocolRegistry::new(subprotocols)?;
101100
Ok(self)
102101
}
103102
}
@@ -110,7 +109,7 @@ impl<E> WebSocketClientBuilder<E> {
110109
#[derive(Debug)]
111110
pub struct WebSocketServerBuilder<E> {
112111
config: Option<WebSocketConfig>,
113-
subprotocols: ProtocolRegistry,
112+
subprotocols: SubprotocolRegistry,
114113
extension: E,
115114
}
116115

@@ -119,7 +118,7 @@ impl Default for WebSocketServerBuilder<NoExtProvider> {
119118
WebSocketServerBuilder {
120119
config: None,
121120
extension: NoExtProvider,
122-
subprotocols: ProtocolRegistry::default(),
121+
subprotocols: SubprotocolRegistry::default(),
123122
}
124123
}
125124
}
@@ -168,9 +167,9 @@ impl<E> WebSocketServerBuilder<E> {
168167
pub fn subprotocols<I>(mut self, subprotocols: I) -> Result<Self, Error>
169168
where
170169
I: IntoIterator,
171-
I::Item: Into<Cow<'static, str>>,
170+
I::Item: Into<String>,
172171
{
173-
self.subprotocols = ProtocolRegistry::new(subprotocols)?;
172+
self.subprotocols = SubprotocolRegistry::new(subprotocols)?;
174173
Ok(self)
175174
}
176175
}

ratchet_core/src/errors.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ use crate::protocol::{CloseCodeParseErr, OpCodeParseErr};
1616
use http::header::{HeaderName, InvalidHeaderValue};
1717
use http::status::InvalidStatusCode;
1818
use http::uri::InvalidUri;
19-
use http::StatusCode;
2019
use std::any::Any;
2120
use std::error::Error as StdError;
2221
use std::fmt::{Display, Formatter};
@@ -156,9 +155,12 @@ pub enum HttpError {
156155
#[error("Redirected: `{0}`")]
157156
Redirected(String),
158157
/// The peer returned with a status code other than 101.
159-
#[error("Status code: `{0}`")]
160-
Status(StatusCode),
161-
/// An invalid HTTP version was received in a request.
158+
#[error("Status code: `{0:?}`")]
159+
Status(u16),
160+
/// A request or response was missing its status code.
161+
#[error("Missing status code")]
162+
MissingStatus,
163+
/// An invalid HTTP version was received in a request or response.
162164
#[error("Invalid HTTP version: `{0:?}`")]
163165
HttpVersion(String),
164166
/// A request or response was missing an expected header.
@@ -272,14 +274,11 @@ pub enum CloseCause {
272274
}
273275

274276
/// WebSocket protocol errors.
275-
#[derive(Copy, Clone, Debug, Eq, PartialEq, Error)]
277+
#[derive(Clone, Debug, Eq, PartialEq, Error)]
276278
pub enum ProtocolError {
277279
/// Invalid encoding was received.
278280
#[error("Not valid UTF-8 encoding")]
279281
Encoding,
280-
/// A peer selected a protocol that was not sent.
281-
#[error("Received an unknown subprotocol")]
282-
UnknownProtocol,
283282
/// An invalid OpCode was received.
284283
#[error("Bad OpCode: `{0}`")]
285284
OpCode(OpCodeParseErr),
@@ -313,6 +312,9 @@ pub enum ProtocolError {
313312
/// An invalid control frame was received.
314313
#[error("Received an invalid control frame")]
315314
InvalidControlFrame,
315+
/// Failed to build subprotocol header.
316+
#[error("Invalid subprotocol header: `{0}`")]
317+
InvalidSubprotocolHeader(String),
316318
}
317319

318320
impl From<FromUtf8Error> for Error {

ratchet_core/src/ext.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
use crate::Error;
1616
use bytes::BytesMut;
1717
use http::{HeaderMap, HeaderValue};
18-
use httparse::Header;
1918
use ratchet_ext::{
2019
Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader,
2120
ReunitableExtension, RsvBits, SplittableExtension,
@@ -61,14 +60,14 @@ impl ExtensionProvider for NoExtProvider {
6160

6261
fn negotiate_client(
6362
&self,
64-
_headers: &[Header],
63+
_headers: &HeaderMap,
6564
) -> Result<Option<Self::Extension>, Self::Error> {
6665
Ok(None)
6766
}
6867

6968
fn negotiate_server(
7069
&self,
71-
_headers: &[Header],
70+
_headers: &HeaderMap,
7271
) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
7372
Ok(None)
7473
}

ratchet_core/src/handshake/client/encoding.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ use ratchet_ext::ExtensionProvider;
2222

2323
use crate::errors::{Error, ErrorKind, HttpError};
2424
use crate::handshake::client::Nonce;
25-
use crate::handshake::{
26-
apply_to, ProtocolRegistry, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR,
27-
};
25+
use crate::handshake::{SubprotocolRegistry, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR};
2826

2927
use base64::engine::general_purpose::STANDARD;
3028
use log::error;
@@ -83,7 +81,7 @@ pub struct ValidatedRequest {
8381
pub fn build_request<E>(
8482
request: Request<()>,
8583
extension: &E,
86-
subprotocols: &ProtocolRegistry,
84+
subprotocols: &SubprotocolRegistry,
8785
) -> Result<ValidatedRequest, Error>
8886
where
8987
E: ExtensionProvider,
@@ -176,7 +174,7 @@ where
176174
));
177175
}
178176

179-
apply_to(subprotocols, &mut headers);
177+
subprotocols.apply_to(&mut headers);
180178

181179
if headers.get(SEC_WEBSOCKET_KEY).is_some() {
182180
error!("{} should not be set", SEC_WEBSOCKET_KEY);

ratchet_core/src/handshake/client/mod.rs

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ use crate::errors::{Error, ErrorKind, HttpError};
3030
use crate::handshake::client::encoding::{build_request, encode_request};
3131
use crate::handshake::io::BufferedIo;
3232
use crate::handshake::{
33-
negotiate_response, validate_header, validate_header_value, ParseResult, ProtocolRegistry,
34-
StreamingParser, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR,
33+
validate_header, validate_header_value, ParseResult, StreamingParser, SubprotocolRegistry,
34+
TryFromWrapper, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR,
3535
};
3636
use crate::{
3737
NoExt, NoExtProvider, Role, TryIntoRequest, WebSocket, WebSocketConfig, WebSocketStream,
3838
};
39+
use http::header::LOCATION;
40+
use log::warn;
3941
use ratchet_ext::ExtensionProvider;
4042
use tokio_util::codec::Decoder;
4143

@@ -80,7 +82,7 @@ where
8082
&mut stream,
8183
request.try_into_request()?,
8284
NoExtProvider,
83-
ProtocolRegistry::default(),
85+
SubprotocolRegistry::default(),
8486
&mut read_buffer,
8587
)
8688
.await?;
@@ -98,7 +100,7 @@ pub async fn subscribe_with<S, E, R>(
98100
mut stream: S,
99101
request: R,
100102
extension: E,
101-
subprotocols: ProtocolRegistry,
103+
subprotocols: SubprotocolRegistry,
102104
) -> Result<UpgradedClient<S, E::Extension>, Error>
103105
where
104106
S: WebSocketStream,
@@ -128,7 +130,7 @@ async fn exec_client_handshake<S, E>(
128130
stream: &mut S,
129131
request: Request<()>,
130132
extension: E,
131-
subprotocols: ProtocolRegistry,
133+
subprotocols: SubprotocolRegistry,
132134
buf: &mut BytesMut,
133135
) -> Result<HandshakeResult<E::Extension>, Error>
134136
where
@@ -163,14 +165,14 @@ where
163165
struct ClientHandshake<'s, S, E> {
164166
buffered: BufferedIo<'s, S>,
165167
nonce: Nonce,
166-
subprotocols: ProtocolRegistry,
168+
subprotocols: SubprotocolRegistry,
167169
extension: &'s E,
168170
}
169171

170172
pub struct StreamingResponseParser<'b, E> {
171173
nonce: &'b Nonce,
172174
extension: &'b E,
173-
subprotocols: &'b mut ProtocolRegistry,
175+
subprotocols: &'b mut SubprotocolRegistry,
174176
}
175177

176178
impl<'b, E> Decoder for StreamingResponseParser<'b, E>
@@ -188,11 +190,11 @@ where
188190
} = self;
189191

190192
let mut headers = [httparse::EMPTY_HEADER; 32];
191-
let mut response = Response::new(&mut headers);
193+
let response = Response::new(&mut headers);
192194

193-
match try_parse_response(buf, &mut response, nonce, extension, subprotocols)? {
195+
match try_parse_response(buf, response, nonce, extension, subprotocols)? {
194196
ParseResult::Complete(result, count) => Ok(Some((result, count))),
195-
ParseResult::Partial => {
197+
ParseResult::Partial(response) => {
196198
check_partial_response(&response)?;
197199
Ok(None)
198200
}
@@ -207,7 +209,7 @@ where
207209
{
208210
pub fn new(
209211
socket: &'s mut S,
210-
subprotocols: ProtocolRegistry,
212+
subprotocols: SubprotocolRegistry,
211213
extension: &'s E,
212214
buf: &'s mut BytesMut,
213215
) -> ClientHandshake<'s, S, E> {
@@ -303,84 +305,92 @@ fn check_partial_response(response: &Response) -> Result<(), Error> {
303305
Ok(())
304306
}
305307
Some(code) => match StatusCode::try_from(code) {
306-
Ok(code) => Err(Error::with_cause(ErrorKind::Http, HttpError::Status(code))),
308+
Ok(code) => Err(Error::with_cause(
309+
ErrorKind::Http,
310+
HttpError::Status(code.as_u16()),
311+
)),
307312
Err(_) => Err(Error::with_cause(ErrorKind::Http, BAD_STATUS_CODE)),
308313
},
309314
None => Ok(()),
310315
}
311316
}
312317

313-
fn try_parse_response<'l, E>(
314-
buffer: &'l [u8],
315-
response: &mut Response<'_, 'l>,
318+
fn try_parse_response<'b, E>(
319+
buffer: &'b [u8],
320+
mut response: Response<'b, 'b>,
316321
expected_nonce: &Nonce,
317322
extension: E,
318-
subprotocols: &mut ProtocolRegistry,
319-
) -> Result<ParseResult<HandshakeResult<E::Extension>>, Error>
323+
subprotocols: &mut SubprotocolRegistry,
324+
) -> Result<ParseResult<Response<'b, 'b>, HandshakeResult<E::Extension>>, Error>
320325
where
321326
E: ExtensionProvider,
322327
{
323328
match response.parse(buffer) {
324-
Ok(Status::Complete(count)) => {
325-
parse_response(response, expected_nonce, extension, subprotocols)
326-
.map(|r| ParseResult::Complete(r, count))
327-
}
328-
Ok(Status::Partial) => Ok(ParseResult::Partial),
329+
Ok(Status::Complete(count)) => parse_response(
330+
TryFromWrapper(response).try_into()?,
331+
expected_nonce,
332+
extension,
333+
subprotocols,
334+
)
335+
.map(|r| ParseResult::Complete(r, count)),
336+
Ok(Status::Partial) => Ok(ParseResult::Partial(response)),
329337
Err(e) => Err(e.into()),
330338
}
331339
}
332340

333341
fn parse_response<E>(
334-
response: &Response,
342+
response: http::Response<()>,
335343
expected_nonce: &Nonce,
336344
extension: E,
337-
subprotocols: &mut ProtocolRegistry,
345+
subprotocols: &SubprotocolRegistry,
338346
) -> Result<HandshakeResult<E::Extension>, Error>
339347
where
340348
E: ExtensionProvider,
341349
{
342-
match response.version {
350+
if response.version() < Version::HTTP_11 {
343351
// rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher
344-
Some(1) => {}
345-
_ => {
346-
return Err(Error::with_cause(
347-
ErrorKind::Http,
348-
HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)),
349-
))
350-
}
352+
return Err(Error::with_cause(
353+
ErrorKind::Http,
354+
HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)),
355+
));
351356
}
352357

353-
let raw_status_code = response.code.ok_or_else(|| Error::new(ErrorKind::Http))?;
354-
let status_code = StatusCode::from_u16(raw_status_code)?;
358+
let status_code = response.status();
355359
match status_code {
356360
c if c == StatusCode::SWITCHING_PROTOCOLS => {}
357361
c if c.is_redirection() => {
358-
return match response.headers.iter().find(|h| h.name == header::LOCATION) {
359-
Some(header) => {
362+
return match response.headers().get(LOCATION) {
363+
Some(value) => {
360364
// the value _should_ be valid UTF-8
361-
let location = String::from_utf8(header.value.to_vec())
365+
let location = String::from_utf8(value.as_bytes().to_vec())
362366
.map_err(|_| Error::new(ErrorKind::Http))?;
363367
Err(Error::with_cause(
364368
ErrorKind::Http,
365369
HttpError::Redirected(location),
366370
))
367371
}
368-
None => Err(Error::with_cause(ErrorKind::Http, HttpError::Status(c))),
372+
None => {
373+
warn!("Received a redirection status code with no location header");
374+
Err(Error::with_cause(
375+
ErrorKind::Http,
376+
HttpError::Status(c.as_u16()),
377+
))
378+
}
369379
};
370380
}
371381
status_code => {
372382
return Err(Error::with_cause(
373383
ErrorKind::Http,
374-
HttpError::Status(status_code),
384+
HttpError::Status(status_code.as_u16()),
375385
))
376386
}
377387
}
378388

379-
validate_header_value(response.headers, header::UPGRADE, WEBSOCKET_STR)?;
380-
validate_header_value(response.headers, header::CONNECTION, UPGRADE_STR)?;
389+
validate_header_value(response.headers(), header::UPGRADE, WEBSOCKET_STR)?;
390+
validate_header_value(response.headers(), header::CONNECTION, UPGRADE_STR)?;
381391

382392
validate_header(
383-
response.headers,
393+
response.headers(),
384394
header::SEC_WEBSOCKET_ACCEPT,
385395
|_name, actual| {
386396
let mut digest = Sha1::new();
@@ -397,9 +407,9 @@ where
397407
)?;
398408

399409
Ok(HandshakeResult {
400-
subprotocol: negotiate_response(subprotocols, response)?,
410+
subprotocol: subprotocols.validate_accepted_subprotocol(response.headers())?,
401411
extension: extension
402-
.negotiate_client(response.headers)
412+
.negotiate_client(response.headers())
403413
.map_err(|e| Error::with_cause(ErrorKind::Extension, e))?,
404414
})
405415
}

0 commit comments

Comments
 (0)