121 lines
4.1 KiB
Python
121 lines
4.1 KiB
Python
import asyncio
|
|
import sys
|
|
import unittest
|
|
from microdot import Microdot, Request
|
|
from microdot.websocket import with_websocket, WebSocket, WebSocketError
|
|
from microdot.test_client import TestClient
|
|
|
|
|
|
class TestWebSocket(unittest.TestCase):
    """Tests for microdot's WebSocket support (``microdot.websocket``)."""

    @classmethod
    def setUpClass(cls):
        """Select the event loop that every test in this class runs on."""
        # Where the asyncio implementation supports it, install a fresh event
        # loop for this class; then use whatever loop is current.
        if hasattr(asyncio, 'set_event_loop'):
            asyncio.set_event_loop(asyncio.new_event_loop())
        cls.loop = asyncio.get_event_loop()

    def _run(self, coro):
        """Run *coro* to completion on the class event loop.

        Returns:
            Whatever the coroutine returns.
        """
        return self.loop.run_until_complete(coro)

    @unittest.skipIf(sys.implementation.name == 'micropython',
                     'no support for async generators in MicroPython')
    def test_websocket_echo(self):
        """Echo text and binary messages of several sizes through a
        WebSocket route. Also check that a handler which raises still
        lets ``client.websocket()`` return ``None``."""
        # The largest message below is 65537 bytes, so the limit is raised to
        # exactly that for this test. The previous value is restored in
        # ``finally`` so a failing assertion cannot leak the override into
        # other tests.
        saved_max_message_length = WebSocket.max_message_length
        WebSocket.max_message_length = 65537
        try:
            app = Microdot()

            @app.route('/echo')
            @with_websocket
            async def index(req, ws):
                # Echo every received message back until the client is done.
                while True:
                    data = await ws.receive()
                    await ws.send(data)

            @app.route('/divzero')
            @with_websocket
            async def divzero(req, ws):
                # Deliberately raise ZeroDivisionError inside the handler.
                1 / 0

            results = []

            # Client-side script: each ``yield`` sends a message, and the
            # value sent back into the generator is the server's reply.
            async def ws():
                data = yield 'hello'
                results.append(data)
                data = yield b'bye'
                results.append(data)
                data = yield b'*' * 300
                results.append(data)
                data = yield b'+' * 65537
                results.append(data)

            client = TestClient(app)
            res = self._run(client.websocket('/echo', ws))
            self.assertIsNone(res)
            self.assertEqual(results,
                             ['hello', b'bye', b'*' * 300, b'+' * 65537])

            res = self._run(client.websocket('/divzero', ws))
            self.assertIsNone(res)
        finally:
            WebSocket.max_message_length = saved_max_message_length

    @unittest.skipIf(sys.implementation.name == 'micropython',
                     'no support for async generators in MicroPython')
    def test_websocket_large_message(self):
        """A message longer than ``Request.max_body_length`` is not echoed
        back to the client."""
        # Shrink the limit below the 16-character message sent below.
        # Restore it in ``finally`` so a failure cannot leak the override.
        saved_max_body_length = Request.max_body_length
        Request.max_body_length = 10
        try:
            app = Microdot()

            @app.route('/echo')
            @with_websocket
            async def index(req, ws):
                data = await ws.receive()
                await ws.send(data)

            results = []

            async def ws():
                data = yield '0123456789abcdef'
                results.append(data)

            client = TestClient(app)
            res = self._run(client.websocket('/echo', ws))
            self.assertIsNone(res)
            # No reply was ever delivered back to the client script.
            self.assertEqual(results, [])
        finally:
            Request.max_body_length = saved_max_body_length

    def test_bad_websocket_request(self):
        """Plain GET requests with missing or invalid WebSocket handshake
        headers are rejected with a 400 response."""
        app = Microdot()

        @app.route('/echo')
        @with_websocket
        def index(req, ws):
            return 'hello'

        client = TestClient(app)
        res = self._run(client.get('/echo'))
        self.assertEqual(res.status_code, 400)
        res = self._run(client.get('/echo', headers={'Connection': 'Upgrade'}))
        self.assertEqual(res.status_code, 400)
        res = self._run(client.get('/echo', headers={'Connection': 'foo'}))
        self.assertEqual(res.status_code, 400)
        res = self._run(client.get('/echo', headers={'Upgrade': 'websocket'}))
        self.assertEqual(res.status_code, 400)
        res = self._run(client.get('/echo', headers={'Upgrade': 'bar'}))
        self.assertEqual(res.status_code, 400)
        # Even with both upgrade headers present, the request is rejected.
        res = self._run(client.get('/echo', headers={'Connection': 'Upgrade',
                                                     'Upgrade': 'websocket'}))
        self.assertEqual(res.status_code, 400)
        res = self._run(client.get(
            '/echo', headers={'Sec-WebSocket-Key': 'xxx'}))
        self.assertEqual(res.status_code, 400)

    def test_process_websocket_frame(self):
        """``_process_websocket_frame`` maps each opcode to a
        ``(reply_opcode, payload)`` pair. TEXT payloads are decoded to
        ``str`` and BINARY payloads stay ``bytes``. A PING is answered with
        a PONG carrying the same payload, a PONG yields ``(None, None)``,
        and a CLOSE raises ``WebSocketError``."""
        # The frame processor is exercised directly, without a request.
        ws = WebSocket(None)
        # NOTE(review): presumably marks the socket as already closed so that
        # handling CLOSE does not try to reply over the (None) request.
        # Confirm against WebSocket's implementation.
        ws.closed = True

        self.assertEqual(ws._process_websocket_frame(WebSocket.TEXT, b'foo'),
                         (None, 'foo'))
        self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
                         (None, b'foo'))
        self.assertRaises(WebSocketError, ws._process_websocket_frame,
                          WebSocket.CLOSE, b'')
        self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
                         (WebSocket.PONG, b'foo'))
        self.assertEqual(ws._process_websocket_frame(WebSocket.PONG, b'foo'),
                         (None, None))