Add a limit to WebSocket message size (Fixes #193)
This commit is contained in:
@@ -292,6 +292,8 @@ class TestClient:
|
|||||||
async def awrite(self, data):
|
async def awrite(self, data):
|
||||||
if self.started:
|
if self.started:
|
||||||
h = WebSocket._parse_frame_header(data[0:2])
|
h = WebSocket._parse_frame_header(data[0:2])
|
||||||
|
if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]:
|
||||||
|
return
|
||||||
if h[3] < 0:
|
if h[3] < 0:
|
||||||
data = data[2 - h[3]:]
|
data = data[2 - h[3]:]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
import binascii
|
import binascii
|
||||||
import hashlib
|
import hashlib
|
||||||
from microdot import Response
|
from microdot import Request, Response
|
||||||
from microdot.microdot import MUTED_SOCKET_ERRORS
|
from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketError(Exception):
|
||||||
|
"""Exception raised when an error occurs in a WebSocket connection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WebSocket:
|
class WebSocket:
|
||||||
@@ -17,6 +22,18 @@ class WebSocket:
|
|||||||
PING = 9
|
PING = 9
|
||||||
PONG = 10
|
PONG = 10
|
||||||
|
|
||||||
|
#: Specify the maximum message size that can be received when calling the
|
||||||
|
#: ``receive()`` method. Messages with payloads that are larger than this
|
||||||
|
#: size will be rejected and the connection closed. Set to 0 to disable
|
||||||
|
#: the size check (be aware of potential security issues if you do this),
|
||||||
|
#: or to -1 to use the value set in
|
||||||
|
#: ``Request.max_body_length``. The default is -1.
|
||||||
|
#:
|
||||||
|
#: Example::
|
||||||
|
#:
|
||||||
|
#: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages
|
||||||
|
max_message_length = -1
|
||||||
|
|
||||||
def __init__(self, request):
|
def __init__(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.closed = False
|
self.closed = False
|
||||||
@@ -86,7 +103,7 @@ class WebSocket:
|
|||||||
fin = header[0] & 0x80
|
fin = header[0] & 0x80
|
||||||
opcode = header[0] & 0x0f
|
opcode = header[0] & 0x0f
|
||||||
if fin == 0 or opcode == cls.CONT: # pragma: no cover
|
if fin == 0 or opcode == cls.CONT: # pragma: no cover
|
||||||
raise OSError(32, 'Continuation frames not supported')
|
raise WebSocketError('Continuation frames not supported')
|
||||||
has_mask = header[1] & 0x80
|
has_mask = header[1] & 0x80
|
||||||
length = header[1] & 0x7f
|
length = header[1] & 0x7f
|
||||||
if length == 126:
|
if length == 126:
|
||||||
@@ -101,7 +118,7 @@ class WebSocket:
|
|||||||
elif opcode == self.BINARY:
|
elif opcode == self.BINARY:
|
||||||
pass
|
pass
|
||||||
elif opcode == self.CLOSE:
|
elif opcode == self.CLOSE:
|
||||||
raise OSError(32, 'Websocket connection closed')
|
raise WebSocketError('Websocket connection closed')
|
||||||
elif opcode == self.PING:
|
elif opcode == self.PING:
|
||||||
return self.PONG, payload
|
return self.PONG, payload
|
||||||
elif opcode == self.PONG: # pragma: no branch
|
elif opcode == self.PONG: # pragma: no branch
|
||||||
@@ -128,7 +145,7 @@ class WebSocket:
|
|||||||
async def _read_frame(self):
|
async def _read_frame(self):
|
||||||
header = await self.request.sock[0].read(2)
|
header = await self.request.sock[0].read(2)
|
||||||
if len(header) != 2: # pragma: no cover
|
if len(header) != 2: # pragma: no cover
|
||||||
raise OSError(32, 'Websocket connection closed')
|
raise WebSocketError('Websocket connection closed')
|
||||||
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
||||||
if length == -2:
|
if length == -2:
|
||||||
length = await self.request.sock[0].read(2)
|
length = await self.request.sock[0].read(2)
|
||||||
@@ -136,6 +153,10 @@ class WebSocket:
|
|||||||
elif length == -8:
|
elif length == -8:
|
||||||
length = await self.request.sock[0].read(8)
|
length = await self.request.sock[0].read(8)
|
||||||
length = int.from_bytes(length, 'big')
|
length = int.from_bytes(length, 'big')
|
||||||
|
max_allowed_length = Request.max_body_length \
|
||||||
|
if self.max_message_length == -1 else self.max_message_length
|
||||||
|
if length > max_allowed_length:
|
||||||
|
raise WebSocketError('Message too large')
|
||||||
if has_mask: # pragma: no cover
|
if has_mask: # pragma: no cover
|
||||||
mask = await self.request.sock[0].read(4)
|
mask = await self.request.sock[0].read(4)
|
||||||
payload = await self.request.sock[0].read(length)
|
payload = await self.request.sock[0].read(length)
|
||||||
@@ -175,11 +196,19 @@ def websocket_wrapper(f, upgrade_function):
|
|||||||
ws = await upgrade_function(request)
|
ws = await upgrade_function(request)
|
||||||
try:
|
try:
|
||||||
await f(request, ws, *args, **kwargs)
|
await f(request, ws, *args, **kwargs)
|
||||||
await ws.close() # pragma: no cover
|
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
|
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
|
||||||
raise
|
raise
|
||||||
return ''
|
except WebSocketError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
print_exception(exc)
|
||||||
|
finally: # pragma: no cover
|
||||||
|
try:
|
||||||
|
await ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return Response.already_handled
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from microdot import Microdot
|
from microdot import Microdot, Request
|
||||||
from microdot.websocket import with_websocket, WebSocket
|
from microdot.websocket import with_websocket, WebSocket, WebSocketError
|
||||||
from microdot.test_client import TestClient
|
from microdot.test_client import TestClient
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ class TestWebSocket(unittest.TestCase):
|
|||||||
return self.loop.run_until_complete(coro)
|
return self.loop.run_until_complete(coro)
|
||||||
|
|
||||||
def test_websocket_echo(self):
|
def test_websocket_echo(self):
|
||||||
|
WebSocket.max_message_length = 65537
|
||||||
app = Microdot()
|
app = Microdot()
|
||||||
|
|
||||||
@app.route('/echo')
|
@app.route('/echo')
|
||||||
@@ -26,34 +27,10 @@ class TestWebSocket(unittest.TestCase):
|
|||||||
data = await ws.receive()
|
data = await ws.receive()
|
||||||
await ws.send(data)
|
await ws.send(data)
|
||||||
|
|
||||||
results = []
|
@app.route('/divzero')
|
||||||
|
|
||||||
def ws():
|
|
||||||
data = yield 'hello'
|
|
||||||
results.append(data)
|
|
||||||
data = yield b'bye'
|
|
||||||
results.append(data)
|
|
||||||
data = yield b'*' * 300
|
|
||||||
results.append(data)
|
|
||||||
data = yield b'+' * 65537
|
|
||||||
results.append(data)
|
|
||||||
|
|
||||||
client = TestClient(app)
|
|
||||||
res = self._run(client.websocket('/echo', ws))
|
|
||||||
self.assertIsNone(res)
|
|
||||||
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.implementation.name == 'micropython',
|
|
||||||
'no support for async generators in MicroPython')
|
|
||||||
def test_websocket_echo_async_client(self):
|
|
||||||
app = Microdot()
|
|
||||||
|
|
||||||
@app.route('/echo')
|
|
||||||
@with_websocket
|
@with_websocket
|
||||||
async def index(req, ws):
|
async def divzero(req, ws):
|
||||||
while True:
|
1 / 0
|
||||||
data = await ws.receive()
|
|
||||||
await ws.send(data)
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -72,6 +49,35 @@ class TestWebSocket(unittest.TestCase):
|
|||||||
self.assertIsNone(res)
|
self.assertIsNone(res)
|
||||||
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
|
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
|
||||||
|
|
||||||
|
res = self._run(client.websocket('/divzero', ws))
|
||||||
|
self.assertIsNone(res)
|
||||||
|
WebSocket.max_message_length = -1
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.implementation.name == 'micropython',
|
||||||
|
'no support for async generators in MicroPython')
|
||||||
|
def test_websocket_large_message(self):
|
||||||
|
saved_max_body_length = Request.max_body_length
|
||||||
|
Request.max_body_length = 10
|
||||||
|
app = Microdot()
|
||||||
|
|
||||||
|
@app.route('/echo')
|
||||||
|
@with_websocket
|
||||||
|
async def index(req, ws):
|
||||||
|
data = await ws.receive()
|
||||||
|
await ws.send(data)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def ws():
|
||||||
|
data = yield '0123456789abcdef'
|
||||||
|
results.append(data)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
res = self._run(client.websocket('/echo', ws))
|
||||||
|
self.assertIsNone(res)
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
Request.max_body_length = saved_max_body_length
|
||||||
|
|
||||||
def test_bad_websocket_request(self):
|
def test_bad_websocket_request(self):
|
||||||
app = Microdot()
|
app = Microdot()
|
||||||
|
|
||||||
@@ -106,7 +112,7 @@ class TestWebSocket(unittest.TestCase):
|
|||||||
(None, 'foo'))
|
(None, 'foo'))
|
||||||
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
|
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
|
||||||
(None, b'foo'))
|
(None, b'foo'))
|
||||||
self.assertRaises(OSError, ws._process_websocket_frame,
|
self.assertRaises(WebSocketError, ws._process_websocket_frame,
|
||||||
WebSocket.CLOSE, b'')
|
WebSocket.CLOSE, b'')
|
||||||
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
|
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
|
||||||
(WebSocket.PONG, b'foo'))
|
(WebSocket.PONG, b'foo'))
|
||||||
|
|||||||
Reference in New Issue
Block a user