Add a limit to WebSocket message size (Fixes #193)
This commit is contained in:
@@ -292,6 +292,8 @@ class TestClient:
|
||||
async def awrite(self, data):
|
||||
if self.started:
|
||||
h = WebSocket._parse_frame_header(data[0:2])
|
||||
if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]:
|
||||
return
|
||||
if h[3] < 0:
|
||||
data = data[2 - h[3]:]
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from microdot import Response
|
||||
from microdot.microdot import MUTED_SOCKET_ERRORS
|
||||
from microdot import Request, Response
|
||||
from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""Exception raised when an error occurs in a WebSocket connection."""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocket:
|
||||
@@ -17,6 +22,18 @@ class WebSocket:
|
||||
PING = 9
|
||||
PONG = 10
|
||||
|
||||
#: Specify the maximum message size that can be received when calling the
|
||||
#: ``receive()`` method. Messages with payloads that are larger than this
|
||||
#: size will be rejected and the connection closed. Set to 0 to disable
|
||||
#: the size check (be aware of potential security issues if you do this),
|
||||
#: or to -1 to use the value set in
|
||||
#: ``Request.max_body_length``. The default is -1.
|
||||
#:
|
||||
#: Example::
|
||||
#:
|
||||
#: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages
|
||||
max_message_length = -1
|
||||
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
self.closed = False
|
||||
@@ -86,7 +103,7 @@ class WebSocket:
|
||||
fin = header[0] & 0x80
|
||||
opcode = header[0] & 0x0f
|
||||
if fin == 0 or opcode == cls.CONT: # pragma: no cover
|
||||
raise OSError(32, 'Continuation frames not supported')
|
||||
raise WebSocketError('Continuation frames not supported')
|
||||
has_mask = header[1] & 0x80
|
||||
length = header[1] & 0x7f
|
||||
if length == 126:
|
||||
@@ -101,7 +118,7 @@ class WebSocket:
|
||||
elif opcode == self.BINARY:
|
||||
pass
|
||||
elif opcode == self.CLOSE:
|
||||
raise OSError(32, 'Websocket connection closed')
|
||||
raise WebSocketError('Websocket connection closed')
|
||||
elif opcode == self.PING:
|
||||
return self.PONG, payload
|
||||
elif opcode == self.PONG: # pragma: no branch
|
||||
@@ -128,7 +145,7 @@ class WebSocket:
|
||||
async def _read_frame(self):
|
||||
header = await self.request.sock[0].read(2)
|
||||
if len(header) != 2: # pragma: no cover
|
||||
raise OSError(32, 'Websocket connection closed')
|
||||
raise WebSocketError('Websocket connection closed')
|
||||
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
||||
if length == -2:
|
||||
length = await self.request.sock[0].read(2)
|
||||
@@ -136,6 +153,10 @@ class WebSocket:
|
||||
elif length == -8:
|
||||
length = await self.request.sock[0].read(8)
|
||||
length = int.from_bytes(length, 'big')
|
||||
max_allowed_length = Request.max_body_length \
|
||||
if self.max_message_length == -1 else self.max_message_length
|
||||
if length > max_allowed_length:
|
||||
raise WebSocketError('Message too large')
|
||||
if has_mask: # pragma: no cover
|
||||
mask = await self.request.sock[0].read(4)
|
||||
payload = await self.request.sock[0].read(length)
|
||||
@@ -175,11 +196,19 @@ def websocket_wrapper(f, upgrade_function):
|
||||
ws = await upgrade_function(request)
|
||||
try:
|
||||
await f(request, ws, *args, **kwargs)
|
||||
await ws.close() # pragma: no cover
|
||||
except OSError as exc:
|
||||
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
|
||||
raise
|
||||
return ''
|
||||
except WebSocketError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
print_exception(exc)
|
||||
finally: # pragma: no cover
|
||||
try:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
return Response.already_handled
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from microdot import Microdot
|
||||
from microdot.websocket import with_websocket, WebSocket
|
||||
from microdot import Microdot, Request
|
||||
from microdot.websocket import with_websocket, WebSocket, WebSocketError
|
||||
from microdot.test_client import TestClient
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class TestWebSocket(unittest.TestCase):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
def test_websocket_echo(self):
|
||||
WebSocket.max_message_length = 65537
|
||||
app = Microdot()
|
||||
|
||||
@app.route('/echo')
|
||||
@@ -26,34 +27,10 @@ class TestWebSocket(unittest.TestCase):
|
||||
data = await ws.receive()
|
||||
await ws.send(data)
|
||||
|
||||
results = []
|
||||
|
||||
def ws():
|
||||
data = yield 'hello'
|
||||
results.append(data)
|
||||
data = yield b'bye'
|
||||
results.append(data)
|
||||
data = yield b'*' * 300
|
||||
results.append(data)
|
||||
data = yield b'+' * 65537
|
||||
results.append(data)
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.websocket('/echo', ws))
|
||||
self.assertIsNone(res)
|
||||
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
|
||||
|
||||
@unittest.skipIf(sys.implementation.name == 'micropython',
|
||||
'no support for async generators in MicroPython')
|
||||
def test_websocket_echo_async_client(self):
|
||||
app = Microdot()
|
||||
|
||||
@app.route('/echo')
|
||||
@app.route('/divzero')
|
||||
@with_websocket
|
||||
async def index(req, ws):
|
||||
while True:
|
||||
data = await ws.receive()
|
||||
await ws.send(data)
|
||||
async def divzero(req, ws):
|
||||
1 / 0
|
||||
|
||||
results = []
|
||||
|
||||
@@ -72,6 +49,35 @@ class TestWebSocket(unittest.TestCase):
|
||||
self.assertIsNone(res)
|
||||
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
|
||||
|
||||
res = self._run(client.websocket('/divzero', ws))
|
||||
self.assertIsNone(res)
|
||||
WebSocket.max_message_length = -1
|
||||
|
||||
@unittest.skipIf(sys.implementation.name == 'micropython',
|
||||
'no support for async generators in MicroPython')
|
||||
def test_websocket_large_message(self):
|
||||
saved_max_body_length = Request.max_body_length
|
||||
Request.max_body_length = 10
|
||||
app = Microdot()
|
||||
|
||||
@app.route('/echo')
|
||||
@with_websocket
|
||||
async def index(req, ws):
|
||||
data = await ws.receive()
|
||||
await ws.send(data)
|
||||
|
||||
results = []
|
||||
|
||||
async def ws():
|
||||
data = yield '0123456789abcdef'
|
||||
results.append(data)
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.websocket('/echo', ws))
|
||||
self.assertIsNone(res)
|
||||
self.assertEqual(results, [])
|
||||
Request.max_body_length = saved_max_body_length
|
||||
|
||||
def test_bad_websocket_request(self):
|
||||
app = Microdot()
|
||||
|
||||
@@ -106,7 +112,7 @@ class TestWebSocket(unittest.TestCase):
|
||||
(None, 'foo'))
|
||||
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
|
||||
(None, b'foo'))
|
||||
self.assertRaises(OSError, ws._process_websocket_frame,
|
||||
self.assertRaises(WebSocketError, ws._process_websocket_frame,
|
||||
WebSocket.CLOSE, b'')
|
||||
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
|
||||
(WebSocket.PONG, b'foo'))
|
||||
|
||||
Reference in New Issue
Block a user