Add a limit to WebSocket message size (Fixes #193)

This commit is contained in:
Miguel Grinberg
2024-01-03 00:03:34 +00:00
parent b80b6b64d0
commit 5d188e8c0d
3 changed files with 74 additions and 37 deletions

View File

@@ -292,6 +292,8 @@ class TestClient:
async def awrite(self, data): async def awrite(self, data):
if self.started: if self.started:
h = WebSocket._parse_frame_header(data[0:2]) h = WebSocket._parse_frame_header(data[0:2])
if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]:
return
if h[3] < 0: if h[3] < 0:
data = data[2 - h[3]:] data = data[2 - h[3]:]
else: else:

View File

@@ -1,7 +1,12 @@
import binascii import binascii
import hashlib import hashlib
from microdot import Response from microdot import Request, Response
from microdot.microdot import MUTED_SOCKET_ERRORS from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception
class WebSocketError(Exception):
"""Exception raised when an error occurs in a WebSocket connection."""
pass
class WebSocket: class WebSocket:
@@ -17,6 +22,18 @@ class WebSocket:
PING = 9 PING = 9
PONG = 10 PONG = 10
#: Specify the maximum message size that can be received when calling the
#: ``receive()`` method. Messages with payloads that are larger than this
#: size will be rejected and the connection closed. Set to 0 to disable
#: the size check (be aware of potential security issues if you do this),
#: or to -1 to use the value set in
#: ``Request.max_body_length``. The default is -1.
#:
#: Example::
#:
#: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages
max_message_length = -1
def __init__(self, request): def __init__(self, request):
self.request = request self.request = request
self.closed = False self.closed = False
@@ -86,7 +103,7 @@ class WebSocket:
fin = header[0] & 0x80 fin = header[0] & 0x80
opcode = header[0] & 0x0f opcode = header[0] & 0x0f
if fin == 0 or opcode == cls.CONT: # pragma: no cover if fin == 0 or opcode == cls.CONT: # pragma: no cover
raise OSError(32, 'Continuation frames not supported') raise WebSocketError('Continuation frames not supported')
has_mask = header[1] & 0x80 has_mask = header[1] & 0x80
length = header[1] & 0x7f length = header[1] & 0x7f
if length == 126: if length == 126:
@@ -101,7 +118,7 @@ class WebSocket:
elif opcode == self.BINARY: elif opcode == self.BINARY:
pass pass
elif opcode == self.CLOSE: elif opcode == self.CLOSE:
raise OSError(32, 'Websocket connection closed') raise WebSocketError('Websocket connection closed')
elif opcode == self.PING: elif opcode == self.PING:
return self.PONG, payload return self.PONG, payload
elif opcode == self.PONG: # pragma: no branch elif opcode == self.PONG: # pragma: no branch
@@ -128,7 +145,7 @@ class WebSocket:
async def _read_frame(self): async def _read_frame(self):
header = await self.request.sock[0].read(2) header = await self.request.sock[0].read(2)
if len(header) != 2: # pragma: no cover if len(header) != 2: # pragma: no cover
raise OSError(32, 'Websocket connection closed') raise WebSocketError('Websocket connection closed')
fin, opcode, has_mask, length = self._parse_frame_header(header) fin, opcode, has_mask, length = self._parse_frame_header(header)
if length == -2: if length == -2:
length = await self.request.sock[0].read(2) length = await self.request.sock[0].read(2)
@@ -136,6 +153,10 @@ class WebSocket:
elif length == -8: elif length == -8:
length = await self.request.sock[0].read(8) length = await self.request.sock[0].read(8)
length = int.from_bytes(length, 'big') length = int.from_bytes(length, 'big')
max_allowed_length = Request.max_body_length \
if self.max_message_length == -1 else self.max_message_length
if length > max_allowed_length:
raise WebSocketError('Message too large')
if has_mask: # pragma: no cover if has_mask: # pragma: no cover
mask = await self.request.sock[0].read(4) mask = await self.request.sock[0].read(4)
payload = await self.request.sock[0].read(length) payload = await self.request.sock[0].read(length)
@@ -175,11 +196,19 @@ def websocket_wrapper(f, upgrade_function):
ws = await upgrade_function(request) ws = await upgrade_function(request)
try: try:
await f(request, ws, *args, **kwargs) await f(request, ws, *args, **kwargs)
await ws.close() # pragma: no cover
except OSError as exc: except OSError as exc:
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
raise raise
return '' except WebSocketError:
pass
except Exception as exc:
print_exception(exc)
finally: # pragma: no cover
try:
await ws.close()
except Exception:
pass
return Response.already_handled
return wrapper return wrapper

View File

@@ -1,8 +1,8 @@
import asyncio import asyncio
import sys import sys
import unittest import unittest
from microdot import Microdot from microdot import Microdot, Request
from microdot.websocket import with_websocket, WebSocket from microdot.websocket import with_websocket, WebSocket, WebSocketError
from microdot.test_client import TestClient from microdot.test_client import TestClient
@@ -17,6 +17,7 @@ class TestWebSocket(unittest.TestCase):
return self.loop.run_until_complete(coro) return self.loop.run_until_complete(coro)
def test_websocket_echo(self): def test_websocket_echo(self):
WebSocket.max_message_length = 65537
app = Microdot() app = Microdot()
@app.route('/echo') @app.route('/echo')
@@ -26,34 +27,10 @@ class TestWebSocket(unittest.TestCase):
data = await ws.receive() data = await ws.receive()
await ws.send(data) await ws.send(data)
results = [] @app.route('/divzero')
def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
results.append(data)
data = yield b'*' * 300
results.append(data)
data = yield b'+' * 65537
results.append(data)
client = TestClient(app)
res = self._run(client.websocket('/echo', ws))
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
@unittest.skipIf(sys.implementation.name == 'micropython',
'no support for async generators in MicroPython')
def test_websocket_echo_async_client(self):
app = Microdot()
@app.route('/echo')
@with_websocket @with_websocket
async def index(req, ws): async def divzero(req, ws):
while True: 1 / 0
data = await ws.receive()
await ws.send(data)
results = [] results = []
@@ -72,6 +49,35 @@ class TestWebSocket(unittest.TestCase):
self.assertIsNone(res) self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
res = self._run(client.websocket('/divzero', ws))
self.assertIsNone(res)
WebSocket.max_message_length = -1
@unittest.skipIf(sys.implementation.name == 'micropython',
'no support for async generators in MicroPython')
def test_websocket_large_message(self):
saved_max_body_length = Request.max_body_length
Request.max_body_length = 10
app = Microdot()
@app.route('/echo')
@with_websocket
async def index(req, ws):
data = await ws.receive()
await ws.send(data)
results = []
async def ws():
data = yield '0123456789abcdef'
results.append(data)
client = TestClient(app)
res = self._run(client.websocket('/echo', ws))
self.assertIsNone(res)
self.assertEqual(results, [])
Request.max_body_length = saved_max_body_length
def test_bad_websocket_request(self): def test_bad_websocket_request(self):
app = Microdot() app = Microdot()
@@ -106,7 +112,7 @@ class TestWebSocket(unittest.TestCase):
(None, 'foo')) (None, 'foo'))
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'), self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
(None, b'foo')) (None, b'foo'))
self.assertRaises(OSError, ws._process_websocket_frame, self.assertRaises(WebSocketError, ws._process_websocket_frame,
WebSocket.CLOSE, b'') WebSocket.CLOSE, b'')
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'), self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
(WebSocket.PONG, b'foo')) (WebSocket.PONG, b'foo'))