@@ -30,12 +30,14 @@ use crate::errors::{Error, ErrorKind, HttpError};
30
30
use crate :: handshake:: client:: encoding:: { build_request, encode_request} ;
31
31
use crate :: handshake:: io:: BufferedIo ;
32
32
use crate :: handshake:: {
33
- negotiate_response , validate_header, validate_header_value, ParseResult , ProtocolRegistry ,
34
- StreamingParser , ACCEPT_KEY , BAD_STATUS_CODE , UPGRADE_STR , WEBSOCKET_STR ,
33
+ validate_header, validate_header_value, ParseResult , StreamingParser , SubprotocolRegistry ,
34
+ TryFromWrapper , ACCEPT_KEY , BAD_STATUS_CODE , UPGRADE_STR , WEBSOCKET_STR ,
35
35
} ;
36
36
use crate :: {
37
37
NoExt , NoExtProvider , Role , TryIntoRequest , WebSocket , WebSocketConfig , WebSocketStream ,
38
38
} ;
39
+ use http:: header:: LOCATION ;
40
+ use log:: warn;
39
41
use ratchet_ext:: ExtensionProvider ;
40
42
use tokio_util:: codec:: Decoder ;
41
43
80
82
& mut stream,
81
83
request. try_into_request ( ) ?,
82
84
NoExtProvider ,
83
- ProtocolRegistry :: default ( ) ,
85
+ SubprotocolRegistry :: default ( ) ,
84
86
& mut read_buffer,
85
87
)
86
88
. await ?;
@@ -98,7 +100,7 @@ pub async fn subscribe_with<S, E, R>(
98
100
mut stream : S ,
99
101
request : R ,
100
102
extension : E ,
101
- subprotocols : ProtocolRegistry ,
103
+ subprotocols : SubprotocolRegistry ,
102
104
) -> Result < UpgradedClient < S , E :: Extension > , Error >
103
105
where
104
106
S : WebSocketStream ,
@@ -128,7 +130,7 @@ async fn exec_client_handshake<S, E>(
128
130
stream : & mut S ,
129
131
request : Request < ( ) > ,
130
132
extension : E ,
131
- subprotocols : ProtocolRegistry ,
133
+ subprotocols : SubprotocolRegistry ,
132
134
buf : & mut BytesMut ,
133
135
) -> Result < HandshakeResult < E :: Extension > , Error >
134
136
where
@@ -163,14 +165,14 @@ where
163
165
struct ClientHandshake < ' s , S , E > {
164
166
buffered : BufferedIo < ' s , S > ,
165
167
nonce : Nonce ,
166
- subprotocols : ProtocolRegistry ,
168
+ subprotocols : SubprotocolRegistry ,
167
169
extension : & ' s E ,
168
170
}
169
171
170
172
pub struct StreamingResponseParser < ' b , E > {
171
173
nonce : & ' b Nonce ,
172
174
extension : & ' b E ,
173
- subprotocols : & ' b mut ProtocolRegistry ,
175
+ subprotocols : & ' b mut SubprotocolRegistry ,
174
176
}
175
177
176
178
impl < ' b , E > Decoder for StreamingResponseParser < ' b , E >
@@ -188,11 +190,11 @@ where
188
190
} = self ;
189
191
190
192
let mut headers = [ httparse:: EMPTY_HEADER ; 32 ] ;
191
- let mut response = Response :: new ( & mut headers) ;
193
+ let response = Response :: new ( & mut headers) ;
192
194
193
- match try_parse_response ( buf, & mut response, nonce, extension, subprotocols) ? {
195
+ match try_parse_response ( buf, response, nonce, extension, subprotocols) ? {
194
196
ParseResult :: Complete ( result, count) => Ok ( Some ( ( result, count) ) ) ,
195
- ParseResult :: Partial => {
197
+ ParseResult :: Partial ( response ) => {
196
198
check_partial_response ( & response) ?;
197
199
Ok ( None )
198
200
}
@@ -207,7 +209,7 @@ where
207
209
{
208
210
pub fn new (
209
211
socket : & ' s mut S ,
210
- subprotocols : ProtocolRegistry ,
212
+ subprotocols : SubprotocolRegistry ,
211
213
extension : & ' s E ,
212
214
buf : & ' s mut BytesMut ,
213
215
) -> ClientHandshake < ' s , S , E > {
@@ -303,84 +305,92 @@ fn check_partial_response(response: &Response) -> Result<(), Error> {
303
305
Ok ( ( ) )
304
306
}
305
307
Some ( code) => match StatusCode :: try_from ( code) {
306
- Ok ( code) => Err ( Error :: with_cause ( ErrorKind :: Http , HttpError :: Status ( code) ) ) ,
308
+ Ok ( code) => Err ( Error :: with_cause (
309
+ ErrorKind :: Http ,
310
+ HttpError :: Status ( code. as_u16 ( ) ) ,
311
+ ) ) ,
307
312
Err ( _) => Err ( Error :: with_cause ( ErrorKind :: Http , BAD_STATUS_CODE ) ) ,
308
313
} ,
309
314
None => Ok ( ( ) ) ,
310
315
}
311
316
}
312
317
313
- fn try_parse_response < ' l , E > (
314
- buffer : & ' l [ u8 ] ,
315
- response : & mut Response < ' _ , ' l > ,
318
+ fn try_parse_response < ' b , E > (
319
+ buffer : & ' b [ u8 ] ,
320
+ mut response : Response < ' b , ' b > ,
316
321
expected_nonce : & Nonce ,
317
322
extension : E ,
318
- subprotocols : & mut ProtocolRegistry ,
319
- ) -> Result < ParseResult < HandshakeResult < E :: Extension > > , Error >
323
+ subprotocols : & mut SubprotocolRegistry ,
324
+ ) -> Result < ParseResult < Response < ' b , ' b > , HandshakeResult < E :: Extension > > , Error >
320
325
where
321
326
E : ExtensionProvider ,
322
327
{
323
328
match response. parse ( buffer) {
324
- Ok ( Status :: Complete ( count) ) => {
325
- parse_response ( response, expected_nonce, extension, subprotocols)
326
- . map ( |r| ParseResult :: Complete ( r, count) )
327
- }
328
- Ok ( Status :: Partial ) => Ok ( ParseResult :: Partial ) ,
329
+ Ok ( Status :: Complete ( count) ) => parse_response (
330
+ TryFromWrapper ( response) . try_into ( ) ?,
331
+ expected_nonce,
332
+ extension,
333
+ subprotocols,
334
+ )
335
+ . map ( |r| ParseResult :: Complete ( r, count) ) ,
336
+ Ok ( Status :: Partial ) => Ok ( ParseResult :: Partial ( response) ) ,
329
337
Err ( e) => Err ( e. into ( ) ) ,
330
338
}
331
339
}
332
340
333
341
fn parse_response < E > (
334
- response : & Response ,
342
+ response : http :: Response < ( ) > ,
335
343
expected_nonce : & Nonce ,
336
344
extension : E ,
337
- subprotocols : & mut ProtocolRegistry ,
345
+ subprotocols : & SubprotocolRegistry ,
338
346
) -> Result < HandshakeResult < E :: Extension > , Error >
339
347
where
340
348
E : ExtensionProvider ,
341
349
{
342
- match response. version {
350
+ if response. version ( ) < Version :: HTTP_11 {
343
351
// rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher
344
- Some ( 1 ) => { }
345
- _ => {
346
- return Err ( Error :: with_cause (
347
- ErrorKind :: Http ,
348
- HttpError :: HttpVersion ( format ! ( "{:?}" , Version :: HTTP_10 ) ) ,
349
- ) )
350
- }
352
+ return Err ( Error :: with_cause (
353
+ ErrorKind :: Http ,
354
+ HttpError :: HttpVersion ( format ! ( "{:?}" , Version :: HTTP_10 ) ) ,
355
+ ) ) ;
351
356
}
352
357
353
- let raw_status_code = response. code . ok_or_else ( || Error :: new ( ErrorKind :: Http ) ) ?;
354
- let status_code = StatusCode :: from_u16 ( raw_status_code) ?;
358
+ let status_code = response. status ( ) ;
355
359
match status_code {
356
360
c if c == StatusCode :: SWITCHING_PROTOCOLS => { }
357
361
c if c. is_redirection ( ) => {
358
- return match response. headers . iter ( ) . find ( |h| h . name == header :: LOCATION ) {
359
- Some ( header ) => {
362
+ return match response. headers ( ) . get ( LOCATION ) {
363
+ Some ( value ) => {
360
364
// the value _should_ be valid UTF-8
361
- let location = String :: from_utf8 ( header . value . to_vec ( ) )
365
+ let location = String :: from_utf8 ( value. as_bytes ( ) . to_vec ( ) )
362
366
. map_err ( |_| Error :: new ( ErrorKind :: Http ) ) ?;
363
367
Err ( Error :: with_cause (
364
368
ErrorKind :: Http ,
365
369
HttpError :: Redirected ( location) ,
366
370
) )
367
371
}
368
- None => Err ( Error :: with_cause ( ErrorKind :: Http , HttpError :: Status ( c) ) ) ,
372
+ None => {
373
+ warn ! ( "Received a redirection status code with no location header" ) ;
374
+ Err ( Error :: with_cause (
375
+ ErrorKind :: Http ,
376
+ HttpError :: Status ( c. as_u16 ( ) ) ,
377
+ ) )
378
+ }
369
379
} ;
370
380
}
371
381
status_code => {
372
382
return Err ( Error :: with_cause (
373
383
ErrorKind :: Http ,
374
- HttpError :: Status ( status_code) ,
384
+ HttpError :: Status ( status_code. as_u16 ( ) ) ,
375
385
) )
376
386
}
377
387
}
378
388
379
- validate_header_value ( response. headers , header:: UPGRADE , WEBSOCKET_STR ) ?;
380
- validate_header_value ( response. headers , header:: CONNECTION , UPGRADE_STR ) ?;
389
+ validate_header_value ( response. headers ( ) , header:: UPGRADE , WEBSOCKET_STR ) ?;
390
+ validate_header_value ( response. headers ( ) , header:: CONNECTION , UPGRADE_STR ) ?;
381
391
382
392
validate_header (
383
- response. headers ,
393
+ response. headers ( ) ,
384
394
header:: SEC_WEBSOCKET_ACCEPT ,
385
395
|_name, actual| {
386
396
let mut digest = Sha1 :: new ( ) ;
@@ -397,9 +407,9 @@ where
397
407
) ?;
398
408
399
409
Ok ( HandshakeResult {
400
- subprotocol : negotiate_response ( subprotocols, response) ?,
410
+ subprotocol : subprotocols. validate_accepted_subprotocol ( response. headers ( ) ) ?,
401
411
extension : extension
402
- . negotiate_client ( response. headers )
412
+ . negotiate_client ( response. headers ( ) )
403
413
. map_err ( |e| Error :: with_cause ( ErrorKind :: Extension , e) ) ?,
404
414
} )
405
415
}
0 commit comments