Skip to content

Commit 7255969

Browse files
committed
Enhance Netty pipeline with WebSocketOnlyHandler to enforce WebSocket-only traffic by rejecting non-WebSocket requests and MqttOverWSHandler to dynamically add MQTT handlers post-WebSocket handshake, ensuring protocol compliance and efficient resource management.
1 parent 376ecbf commit 7255969

File tree

8 files changed

+409
-25
lines changed

8 files changed

+409
-25
lines changed

bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/AbstractMQTTBroker.java

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import com.baidu.bifromq.baseenv.EnvProvider;
1717
import com.baidu.bifromq.baserpc.utils.NettyUtil;
18-
import com.baidu.bifromq.mqtt.handler.ByteBufToWebSocketFrameEncoder;
1918
import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
2019
import com.baidu.bifromq.mqtt.handler.ClientAddrHandler;
2120
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
@@ -24,7 +23,8 @@
2423
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
2524
import com.baidu.bifromq.mqtt.handler.condition.DirectMemPressureCondition;
2625
import com.baidu.bifromq.mqtt.handler.condition.HeapMemPressureCondition;
27-
import com.baidu.bifromq.mqtt.handler.ws.WebSocketFrameToByteBufDecoder;
26+
import com.baidu.bifromq.mqtt.handler.ws.MqttOverWSHandler;
27+
import com.baidu.bifromq.mqtt.handler.ws.WebSocketOnlyHandler;
2828
import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
2929
import com.google.common.collect.Sets;
3030
import com.google.common.util.concurrent.RateLimiter;
@@ -216,19 +216,11 @@ protected void initChannel(SocketChannel ch) {
216216
pipeline.addLast("httpDecoder", new HttpRequestDecoder());
217217
pipeline.addLast("remoteAddr", remoteAddrHandler);
218218
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
219+
pipeline.addLast("webSocketOnly", new WebSocketOnlyHandler(connBuilder.path()));
219220
pipeline.addLast("webSocketHandler", new WebSocketServerProtocolHandler(connBuilder.path(),
220221
MQTT_SUBPROTOCOL_CSV_LIST));
221-
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
222-
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
223-
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
224-
// insert PacketFilter here
225-
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(builder.maxBytesInMessage));
226-
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
227-
pipeline.addLast(ConditionalRejectHandler.NAME,
228-
new ConditionalRejectHandler(
229-
Sets.newHashSet(DirectMemPressureCondition.INSTANCE, HeapMemPressureCondition.INSTANCE),
230-
sessionContext.eventCollector));
231-
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(builder.connectTimeoutSeconds));
222+
pipeline.addLast("webSocketHandshakeListener", new MqttOverWSHandler(
223+
builder.maxBytesInMessage, builder.connectTimeoutSeconds, sessionContext.eventCollector));
232224
}));
233225
}
234226
});
@@ -247,19 +239,11 @@ protected void initChannel(SocketChannel ch) {
247239
pipeline.addLast("httpDecoder", new HttpRequestDecoder());
248240
pipeline.addLast("remoteAddr", remoteAddrHandler);
249241
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
242+
pipeline.addLast("webSocketOnly", new WebSocketOnlyHandler(connBuilder.path()));
250243
pipeline.addLast("webSocketHandler", new WebSocketServerProtocolHandler(connBuilder.path(),
251244
MQTT_SUBPROTOCOL_CSV_LIST));
252-
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
253-
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
254-
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
255-
// insert PacketFilter between Encoder
256-
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(builder.maxBytesInMessage));
257-
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
258-
pipeline.addLast(ConditionalRejectHandler.NAME,
259-
new ConditionalRejectHandler(
260-
Sets.newHashSet(DirectMemPressureCondition.INSTANCE, HeapMemPressureCondition.INSTANCE),
261-
sessionContext.eventCollector));
262-
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(builder.connectTimeoutSeconds));
245+
pipeline.addLast("webSocketHandshakeListener", new MqttOverWSHandler(
246+
builder.maxBytesInMessage, builder.connectTimeoutSeconds, sessionContext.eventCollector));
263247
}));
264248
}
265249
});
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
* See the License for the specific language governing permissions and limitations under the License.
1212
*/
1313

14-
package com.baidu.bifromq.mqtt.handler;
14+
package com.baidu.bifromq.mqtt.handler.ws;
1515

1616
import io.netty.buffer.ByteBuf;
1717
import io.netty.channel.ChannelHandlerContext;
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
* Unless required by applicable law or agreed to in writing,
9+
* software distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and limitations under the License.
12+
*/
13+
14+
package com.baidu.bifromq.mqtt.handler.ws;
15+
16+
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
17+
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler;
18+
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
19+
import com.baidu.bifromq.mqtt.handler.condition.DirectMemPressureCondition;
20+
import com.baidu.bifromq.mqtt.handler.condition.HeapMemPressureCondition;
21+
import com.baidu.bifromq.plugin.eventcollector.IEventCollector;
22+
import com.google.common.collect.Sets;
23+
import io.netty.channel.ChannelHandlerContext;
24+
import io.netty.channel.ChannelInboundHandlerAdapter;
25+
import io.netty.channel.ChannelPipeline;
26+
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
27+
import io.netty.handler.codec.mqtt.MqttDecoder;
28+
import io.netty.handler.codec.mqtt.MqttEncoder;
29+
30+
/**
31+
* A handler that adds MQTT handlers to the pipeline after the WebSocket handshake is complete.
32+
*/
33+
public class MqttOverWSHandler extends ChannelInboundHandlerAdapter {
34+
private final int maxMQTTConnectPacketSize;
35+
private final int connectTimeoutSeconds;
36+
private final IEventCollector eventCollector;
37+
38+
public MqttOverWSHandler(int maxMQTTConnectPacketSize, int connectTimeoutSeconds, IEventCollector eventCollector) {
39+
this.maxMQTTConnectPacketSize = maxMQTTConnectPacketSize;
40+
this.connectTimeoutSeconds = connectTimeoutSeconds;
41+
this.eventCollector = eventCollector;
42+
}
43+
44+
@Override
45+
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
46+
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
47+
ChannelPipeline pipeline = ctx.pipeline();
48+
// Handshake complete, add MQTT handlers.
49+
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
50+
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
51+
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
52+
// insert PacketFilter between Encoder
53+
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(maxMQTTConnectPacketSize));
54+
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
55+
pipeline.addLast(ConditionalRejectHandler.NAME,
56+
new ConditionalRejectHandler(Sets.newHashSet(DirectMemPressureCondition.INSTANCE,
57+
HeapMemPressureCondition.INSTANCE), eventCollector));
58+
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(connectTimeoutSeconds));
59+
// Remove the handshake listener after adding MQTT handlers.
60+
ctx.pipeline().remove(this);
61+
} else {
62+
super.userEventTriggered(ctx, evt);
63+
}
64+
}
65+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
* Unless required by applicable law or agreed to in writing,
9+
* software distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and limitations under the License.
12+
*/
13+
14+
package com.baidu.bifromq.mqtt.handler.ws;
15+
16+
import io.netty.channel.ChannelFutureListener;
17+
import io.netty.channel.ChannelHandlerContext;
18+
import io.netty.channel.SimpleChannelInboundHandler;
19+
import io.netty.handler.codec.http.DefaultFullHttpResponse;
20+
import io.netty.handler.codec.http.FullHttpRequest;
21+
import io.netty.handler.codec.http.FullHttpResponse;
22+
import io.netty.handler.codec.http.HttpHeaderNames;
23+
import io.netty.handler.codec.http.HttpResponseStatus;
24+
25+
/**
26+
* A simple handler that rejects all requests that are not WebSocket upgrade requests.
27+
*/
28+
public class WebSocketOnlyHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
29+
private final String websocketPath;
30+
31+
public WebSocketOnlyHandler(String websocketPath) {
32+
super(false);
33+
this.websocketPath = websocketPath;
34+
}
35+
36+
@Override
37+
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
38+
if (!req.uri().equals(websocketPath)
39+
||
40+
!req.headers().get(HttpHeaderNames.UPGRADE, "").equalsIgnoreCase("websocket")) {
41+
FullHttpResponse response =
42+
new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.BAD_REQUEST);
43+
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
44+
} else {
45+
// Proceed with the pipeline setup for WebSocket.
46+
ctx.pipeline().remove(this); // Remove the validator after it's used.
47+
ctx.fireChannelRead(req); // Pass the request further if it's valid.
48+
}
49+
}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
* Unless required by applicable law or agreed to in writing,
9+
* software distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and limitations under the License.
12+
*/
13+
14+
package com.baidu.bifromq.mqtt.handler.ws;
15+
16+
import static org.testng.Assert.assertEquals;
17+
import static org.testng.Assert.assertFalse;
18+
import static org.testng.Assert.assertNotNull;
19+
import static org.testng.Assert.assertTrue;
20+
21+
import io.netty.buffer.ByteBuf;
22+
import io.netty.buffer.ByteBufUtil;
23+
import io.netty.buffer.Unpooled;
24+
import io.netty.channel.embedded.EmbeddedChannel;
25+
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
26+
import org.testng.annotations.BeforeMethod;
27+
import org.testng.annotations.Test;
28+
29+
public class ByteBufToWebSocketFrameEncoderTest {
30+
private EmbeddedChannel channel;
31+
32+
@BeforeMethod
33+
public void setUp() {
34+
// Initialize channel with the encoder before each test
35+
channel = new EmbeddedChannel(new ByteBufToWebSocketFrameEncoder());
36+
}
37+
38+
@Test
39+
public void testEncode() {
40+
// Creating a test ByteBuf with sample data
41+
ByteBuf input = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5});
42+
43+
// Write the ByteBuf to the channel
44+
input.retain();
45+
assertTrue(channel.writeOutbound(input.duplicate()));
46+
47+
// Read the encoded output from the channel
48+
BinaryWebSocketFrame frame = channel.readOutbound();
49+
50+
assertNotNull(frame);
51+
assertEquals(input.readerIndex(), frame.content().readerIndex());
52+
assertEquals(input.writerIndex(), frame.content().writerIndex());
53+
assertTrue(ByteBufUtil.equals(input, frame.content()));
54+
55+
// Cleanup
56+
frame.release();
57+
58+
assertFalse(channel.finish());
59+
}
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
* Unless required by applicable law or agreed to in writing,
9+
* software distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and limitations under the License.
12+
*/
13+
14+
package com.baidu.bifromq.mqtt.handler.ws;
15+
16+
import static org.mockito.Mockito.mock;
17+
import static org.testng.Assert.assertNotNull;
18+
import static org.testng.Assert.assertNull;
19+
20+
import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
21+
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
22+
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler;
23+
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
24+
import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
25+
import com.baidu.bifromq.plugin.eventcollector.IEventCollector;
26+
import io.netty.channel.Channel;
27+
import io.netty.channel.ChannelInitializer;
28+
import io.netty.channel.embedded.EmbeddedChannel;
29+
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
30+
import io.netty.handler.codec.mqtt.MqttDecoder;
31+
import io.netty.handler.codec.mqtt.MqttEncoder;
32+
import java.net.InetSocketAddress;
33+
import org.testng.annotations.BeforeMethod;
34+
import org.testng.annotations.Test;
35+
36+
public class MqttOverWSHandlerTest {
37+
private EmbeddedChannel channel;
38+
private MQTTSessionContext sessionContext;
39+
private IEventCollector eventCollector;
40+
41+
@BeforeMethod
42+
public void setUp() {
43+
eventCollector = mock(IEventCollector.class);
44+
// Initialize channel with the MqttOverWSHandler
45+
sessionContext = MQTTSessionContext.builder()
46+
.eventCollector(eventCollector)
47+
.build();
48+
channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() {
49+
@Override
50+
protected void initChannel(Channel ch) {
51+
ch.attr(ChannelAttrs.MQTT_SESSION_CTX).set(sessionContext);
52+
ch.attr(ChannelAttrs.PEER_ADDR).set(new InetSocketAddress("127.0.0.1", 8080));
53+
ch.pipeline().addLast(new MqttOverWSHandler(65536, 30, eventCollector));
54+
}
55+
});
56+
}
57+
58+
@Test
59+
public void testMqttHandlerAdditionAfterHandshakeComplete() {
60+
// Simulate a WebSocket handshake completion event
61+
channel.pipeline()
62+
.fireUserEventTriggered(new WebSocketServerProtocolHandler.HandshakeComplete(null, null, null));
63+
64+
// Check if all handlers are added
65+
assertNotNull(channel.pipeline().get(WebSocketFrameToByteBufDecoder.class));
66+
assertNotNull(channel.pipeline().get(ByteBufToWebSocketFrameEncoder.class));
67+
assertNotNull(channel.pipeline().get(MqttEncoder.class));
68+
assertNotNull(channel.pipeline().get(MqttDecoder.class));
69+
assertNotNull(channel.pipeline().get(MQTTMessageDebounceHandler.class));
70+
assertNotNull(channel.pipeline().get(ConditionalRejectHandler.class));
71+
assertNotNull(channel.pipeline().get(MQTTPreludeHandler.class));
72+
73+
// Check that the MqttOverWSHandler itself has been removed from the pipeline
74+
assertNull(channel.pipeline().get(MqttOverWSHandler.class));
75+
}
76+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
* Unless required by applicable law or agreed to in writing,
9+
* software distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and limitations under the License.
12+
*/
13+
14+
package com.baidu.bifromq.mqtt.handler.ws;
15+
16+
import static org.testng.Assert.assertEquals;
17+
import static org.testng.Assert.assertFalse;
18+
import static org.testng.Assert.assertNotNull;
19+
import static org.testng.Assert.assertTrue;
20+
21+
import io.netty.buffer.ByteBuf;
22+
import io.netty.buffer.ByteBufUtil;
23+
import io.netty.buffer.Unpooled;
24+
import io.netty.channel.embedded.EmbeddedChannel;
25+
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
26+
import org.testng.annotations.BeforeMethod;
27+
import org.testng.annotations.Test;
28+
29+
public class WebSocketFrameToByteBufDecoderTest {
30+
private EmbeddedChannel channel;
31+
32+
@BeforeMethod
33+
public void setUp() {
34+
// Initialize channel with the decoder before each test
35+
channel = new EmbeddedChannel(new WebSocketFrameToByteBufDecoder());
36+
}
37+
38+
@Test
39+
public void testDecode() {
40+
// Creating a BinaryWebSocketFrame with sample data
41+
ByteBuf originalContent = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5});
42+
BinaryWebSocketFrame frame = new BinaryWebSocketFrame(originalContent);
43+
44+
// Write the frame to the channel
45+
assertTrue(channel.writeInbound(frame));
46+
47+
// Read the decoded output from the channel
48+
ByteBuf decoded = channel.readInbound();
49+
50+
assertNotNull(decoded);
51+
assertEquals(originalContent.readerIndex(), decoded.readerIndex());
52+
assertEquals(originalContent.writerIndex(), decoded.writerIndex());
53+
assertTrue(ByteBufUtil.equals(originalContent, decoded));
54+
55+
// Cleanup
56+
decoded.release();
57+
58+
assertFalse(channel.finish());
59+
}
60+
}

0 commit comments

Comments
 (0)