diff --git a/src/microdot/test_client.py b/src/microdot/test_client.py index 1530ee6..a6d7141 100644 --- a/src/microdot/test_client.py +++ b/src/microdot/test_client.py @@ -292,6 +292,8 @@ class TestClient: async def awrite(self, data): if self.started: h = WebSocket._parse_frame_header(data[0:2]) + if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]: + return if h[3] < 0: data = data[2 - h[3]:] else: diff --git a/src/microdot/websocket.py b/src/microdot/websocket.py index c7b6034..925f7dc 100644 --- a/src/microdot/websocket.py +++ b/src/microdot/websocket.py @@ -1,7 +1,12 @@ import binascii import hashlib -from microdot import Response -from microdot.microdot import MUTED_SOCKET_ERRORS +from microdot import Request, Response +from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception + + +class WebSocketError(Exception): + """Exception raised when an error occurs in a WebSocket connection.""" + pass class WebSocket: @@ -17,6 +22,18 @@ class WebSocket: PING = 9 PONG = 10 + #: Specify the maximum message size that can be received when calling the + #: ``receive()`` method. Messages with payloads that are larger than this + #: size will be rejected and the connection closed. Set to 0 to disable + #: the size check (be aware of potential security issues if you do this), + #: or to -1 to use the value set in + #: ``Request.max_body_length``. The default is -1. + #: + #: Example:: + #: + #: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages + max_message_length = -1 + def __init__(self, request): self.request = request self.closed = False @@ -86,7 +103,7 @@ class WebSocket: fin = header[0] & 0x80 opcode = header[0] & 0x0f if fin == 0 or opcode == cls.CONT: # pragma: no cover - raise OSError(32, 'Continuation frames not supported') + raise WebSocketError('Continuation frames not supported') has_mask = header[1] & 0x80 length = header[1] & 0x7f if length == 126: @@ -101,7 +118,7 @@ class WebSocket: elif opcode == self.BINARY: pass elif opcode == self.CLOSE: - raise OSError(32, 'Websocket connection closed') + raise WebSocketError('Websocket connection closed') elif opcode == self.PING: return self.PONG, payload elif opcode == self.PONG: # pragma: no branch @@ -128,7 +145,7 @@ class WebSocket: async def _read_frame(self): header = await self.request.sock[0].read(2) if len(header) != 2: # pragma: no cover - raise OSError(32, 'Websocket connection closed') + raise WebSocketError('Websocket connection closed') fin, opcode, has_mask, length = self._parse_frame_header(header) if length == -2: length = await self.request.sock[0].read(2) @@ -136,6 +153,10 @@ class WebSocket: elif length == -8: length = await self.request.sock[0].read(8) length = int.from_bytes(length, 'big') + max_allowed_length = Request.max_body_length \ + if self.max_message_length == -1 else self.max_message_length + if length > max_allowed_length: + raise WebSocketError('Message too large') if has_mask: # pragma: no cover mask = await self.request.sock[0].read(4) payload = await self.request.sock[0].read(length) @@ -175,11 +196,19 @@ def websocket_wrapper(f, upgrade_function): ws = await upgrade_function(request) try: await f(request, ws, *args, **kwargs) - await ws.close() # pragma: no cover except OSError as exc: if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover raise - return '' + except WebSocketError: + pass + except Exception as exc: + print_exception(exc) + finally: # pragma: no cover + try: + await ws.close() + except Exception: + pass + return Response.already_handled return wrapper diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 4d2a507..9c20682 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,8 +1,8 @@ import asyncio import sys import unittest -from microdot import Microdot -from microdot.websocket import with_websocket, WebSocket +from microdot import Microdot, Request +from microdot.websocket import with_websocket, WebSocket, WebSocketError from microdot.test_client import TestClient @@ -17,6 +17,7 @@ class TestWebSocket(unittest.TestCase): return self.loop.run_until_complete(coro) def test_websocket_echo(self): + WebSocket.max_message_length = 65537 app = Microdot() @app.route('/echo') @@ -26,34 +27,10 @@ class TestWebSocket(unittest.TestCase): data = await ws.receive() await ws.send(data) - results = [] - - def ws(): - data = yield 'hello' - results.append(data) - data = yield b'bye' - results.append(data) - data = yield b'*' * 300 - results.append(data) - data = yield b'+' * 65537 - results.append(data) - - client = TestClient(app) - res = self._run(client.websocket('/echo', ws)) - self.assertIsNone(res) - self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) - - @unittest.skipIf(sys.implementation.name == 'micropython', - 'no support for async generators in MicroPython') - def test_websocket_echo_async_client(self): - app = Microdot() - - @app.route('/echo') + @app.route('/divzero') @with_websocket - async def index(req, ws): - while True: - data = await ws.receive() - await ws.send(data) + async def divzero(req, ws): + 1 / 0 results = [] @@ -72,6 +49,35 @@ class TestWebSocket(unittest.TestCase): self.assertIsNone(res) self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) + res = self._run(client.websocket('/divzero', ws)) + self.assertIsNone(res) + WebSocket.max_message_length = -1 + + @unittest.skipIf(sys.implementation.name == 'micropython', + 'no support for async generators in MicroPython') + def test_websocket_large_message(self): + saved_max_body_length = Request.max_body_length + Request.max_body_length = 10 + app = Microdot() + + @app.route('/echo') + @with_websocket + async def index(req, ws): + data = await ws.receive() + await ws.send(data) + + results = [] + + async def ws(): + data = yield '0123456789abcdef' + results.append(data) + + client = TestClient(app) + res = self._run(client.websocket('/echo', ws)) + self.assertIsNone(res) + self.assertEqual(results, []) + Request.max_body_length = saved_max_body_length + def test_bad_websocket_request(self): app = Microdot() @@ -106,7 +112,7 @@ class TestWebSocket(unittest.TestCase): (None, 'foo')) self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'), (None, b'foo')) - self.assertRaises(OSError, ws._process_websocket_frame, + self.assertRaises(WebSocketError, ws._process_websocket_frame, WebSocket.CLOSE, b'') self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'), (WebSocket.PONG, b'foo'))