Skip to content

Commit dd80ac5

Browse files
committed
feat: support snappy
1 parent 5019579 commit dd80ac5

File tree

5 files changed

+59
-10
lines changed

5 files changed

+59
-10
lines changed

config_c.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{
22
"server": {
33
"url": "ws://192.168.9.224:18888/websocket_path",
4-
"password": "helloworld"
4+
"password": "helloworld",
5+
"compress": true
56
},
67
"client_name": "windows10_sql",
78
"log_file": "/var/log/nt/nt.log"

entity/client_config_entity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class _ServerEntity(TypedDict):
1313
https: bool
1414
password: str
1515
path: str
16+
compress: bool
1617

1718

1819
class ClientConfigEntity(TypedDict):

run_client.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010
from optparse import OptionParser
1111
from threading import Thread
1212
from typing import List, Set, Dict
13+
try:
14+
import snappy
15+
has_snappy = True
16+
except ModuleNotFoundError:
17+
has_snappy = False
1318

1419
from tornado import ioloop
1520

1621
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
1722

1823
from common.speed_limit import SpeedLimiter
19-
from common.websocket import WebSocketException
24+
from common.websocket import WebSocketException, ABNF, WebSocketConnectionClosedException
2025

2126
from client.clear_nonce_task import ClearNonceTask
2227
from client.heart_beat_task import HeatBeatTask
@@ -113,12 +118,23 @@ def __init__(self, ws: websocket.WebSocketApp, tcp_forward_client, heart_beat_ta
113118
self.ws.on_message = self.on_message
114119
self.ws.on_close = self.on_close
115120
self.ws.on_open = self.on_open
121+
self.ws.send = self.send
116122
self.forward_client: TcpForwardClient = tcp_forward_client
117123
self.heart_beat_task = heart_beat_task
118124
self.config_data: ClientConfigEntity = config_data
125+
self.compress_support: bool = config_data['server']['compress']
126+
127+
def send(self, data, opcode=ABNF.OPCODE_TEXT):
128+
if opcode == ABNF.OPCODE_BINARY and self.compress_support:
129+
data = snappy.snappy.compress(data)
130+
if not self.ws.sock or self.ws.sock.send(data, opcode) == 0:
131+
raise WebSocketConnectionClosedException(
132+
"Connection is already closed.")
119133

120134
def on_message(self, ws, message: bytes):
121135
try:
136+
if self.compress_support:
137+
message = snappy.snappy.uncompress(message)
122138
message_data: MessageEntity = NatSerialization.loads(message, ContextUtils.get_password())
123139
start_time = time.time()
124140
time_ = message_data['type_']
@@ -222,7 +238,17 @@ def main():
222238
else:
223239
url += 'ws://'
224240
url += f"{server_config['host']}:{str(server_config['port'])}{server_config['path']}"
241+
config_data['server'].setdefault('compress', False)
242+
compress_support = config_data['server']['compress']
243+
assert isinstance(compress_support, bool)
244+
if compress_support and not has_snappy:
245+
raise Exception('snappy is not installed')
225246
LoggerFactory.get_logger().info(f'start open {url}')
247+
if compress_support:
248+
if '?' in url: # 补充 compress_support 参数
249+
url += '&c=' + json.dumps(compress_support)
250+
else:
251+
url += '?c=' + json.dumps(compress_support)
226252
ws = websocket.WebSocketApp(url)
227253
forward_client = TcpForwardClient(ws)
228254
heart_beat_task = HeatBeatTask(ws)

server/websocket_handler.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import asyncio
22
import json
33
import logging
4-
import socket
54
import time
65
import traceback
76
from asyncio import Lock
8-
from collections import defaultdict
9-
from threading import Thread
10-
from typing import List, Dict, Set, Tuple
7+
from json import JSONDecodeError
8+
from typing import List, Dict, Set
9+
10+
try:
11+
import snappy
12+
has_snappy = True
13+
except ModuleNotFoundError:
14+
has_snappy = False
1115

12-
from tornado.ioloop import IOLoop
1316
from tornado.websocket import WebSocketHandler
1417

1518
from common.nat_serialization import NatSerialization
@@ -33,19 +36,32 @@ class MyWebSocketaHandler(WebSocketHandler):
3336
names: Set[str]
3437
recv_time: float = None
3538

39+
compress_support: bool = False # 是否支持snappy压缩
40+
3641
# handler_to_recv_time: Dict['MyWebSocketaHandler', float] = {}
3742
client_name_to_handler: Dict[str, 'MyWebSocketaHandler'] = {}
3843
lock = Lock()
3944

4045
def open(self, *args: str, **kwargs: str):
4146
self.client_name = None
4247
self.version = None
43-
LoggerFactory.get_logger().info('new open websocket')
48+
try:
49+
self.compress_support = json.loads(self.get_argument('c', 'false'))
50+
except JSONDecodeError:
51+
self.compress_support = False
52+
if self.compress_support and not has_snappy:
53+
msg = 'python-snappy is not installed on the server'
54+
LoggerFactory.get_logger().info(msg)
55+
self.close(reason=msg)
56+
LoggerFactory.get_logger().info(f'new open websocket, compress_support: {self.compress_support}')
4457

4558
async def write_message(self, message, binary=False):
4659
start_time = time.time()
4760
try:
48-
await (super(MyWebSocketaHandler, self).write_message(bytes(message), binary))
61+
byte_message = bytes(message)
62+
if self.compress_support:
63+
byte_message = snappy.snappy.compress(byte_message)
64+
await (super(MyWebSocketaHandler, self).write_message(byte_message, binary))
4965
if LoggerFactory.get_logger().isEnabledFor(logging.DEBUG):
5066
LoggerFactory.get_logger().debug(f'write message cost time {time.time() - start_time}, len: {len(message)}')
5167
return
@@ -55,6 +71,8 @@ async def write_message(self, message, binary=False):
5571
raise
5672

5773
def on_message(self, m_bytes):
74+
if self.compress_support:
75+
m_bytes = snappy.snappy.uncompress(m_bytes)
5876
asyncio.ensure_future(self.on_message_async(m_bytes))
5977

6078
async def on_message_async(self, message):

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@
2727
nt_server = {package_name}.run_server:main
2828
""",
2929
packages=l,
30-
install_requires=['tornado', 'typing_extensions']
30+
install_requires=['tornado', 'typing_extensions'],
31+
extras_require={
32+
"snappy": ["python-snappy"],
33+
},
3134
)

0 commit comments

Comments
 (0)