Websocket standard and asyncio extensions (#55)

This commit is contained in:
Miguel Grinberg
2022-09-03 20:04:34 +01:00
committed by GitHub
parent ec0f9ba855
commit 2399c29c8a
27 changed files with 1077 additions and 60 deletions

View File

@@ -1,2 +1,4 @@
[run]
omit=src/utemplate/*
omit=
src/microdot_websocket_alt.py
src/microdot_asgi_websocket.py

View File

@@ -208,6 +208,80 @@ Example::
delete_session(req)
return redirect('/')
WebSocket Support
~~~~~~~~~~~~~~~~~
.. list-table::
:align: left
* - Compatibility
- | CPython & MicroPython
* - Required Microdot source files
- | `microdot.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot.py>`_
| `microdot_websocket.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot_websocket.py>`_
* - Required external dependencies
- | None
* - Examples
- | `echo.py <https://github.com/miguelgrinberg/microdot/blob/main/examples/websocket/echo.py>`_
| `echo_wsgi.py <https://github.com/miguelgrinberg/microdot/blob/main/examples/websocket/echo_wsgi.py>`_
The WebSocket extension provides a way for the application to handle WebSocket
requests. The :func:`websocket <microdot_websocket.with_websocket>` decorator
is used to mark a route handler as a WebSocket handler. The handler receives
a WebSocket object as a second argument. The WebSocket object provides
``send()`` and ``receive()`` methods to send and receive messages respectively.
Example::
@app.route('/echo')
@with_websocket
def echo(request, ws):
while True:
message = ws.receive()
ws.send(message)
.. note::
An unsupported *microsoft_websocket_alt.py* module, with the same
interface, is also provided. This module uses the native WebSocket support
in MicroPython that powers the WebREPL, and may provide slightly better
performance for MicroPython low-end boards. This module is not compatible
with CPython.
Asynchronous WebSocket
~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:align: left
* - Compatibility
- | CPython & MicroPython
* - Required Microdot source files
- | `microdot.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot.py>`_
| `microdot_asyncio.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot_asyncio.py>`_
| `microdot_websocket.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot_websocket.py>`_
| `microdot_asyncio_websocket.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot_asyncio_websocket.py>`_
* - Required external dependencies
- | CPython: None
| MicroPython: `uasyncio <https://github.com/micropython/micropython/tree/master/extmod/uasyncio>`_
* - Examples
- | `echo_async.py <https://github.com/miguelgrinberg/microdot/blob/main/examples/websocket/echo_async.py>`_
This extension has the same interface as the synchronous WebSocket extension,
but the ``receive()`` and ``send()`` methods are asynchronous.
.. note::
An unsupported *microsoft_asgi_websocket.py* module, with the same
interface, is also provided. This module must be used instead of
*microsoft_asyncio_websocket.py* when the ASGI support is used. The
`echo_asgi.py <https://github.com/miguelgrinberg/microdot/blob/main/examples/websocket/echo_asgi.py>`_
example shows how to use this module.
Test Client
~~~~~~~~~~~

View File

@@ -0,0 +1 @@
This directory contains WebSocket examples.

View File

@@ -0,0 +1,20 @@
from microdot import Microdot, send_file
from microdot_websocket import with_websocket
app = Microdot()
@app.route('/')
def index(request):
return send_file('index.html')
@app.route('/echo')
@with_websocket
def echo(request, ws):
while True:
data = ws.receive()
ws.send(data)
app.run()

View File

@@ -0,0 +1,17 @@
from microdot_asgi import Microdot, send_file
from microdot_asgi_websocket import with_websocket
app = Microdot()
@app.route('/')
def index(request):
return send_file('index.html')
@app.route('/echo')
@with_websocket
async def echo(request, ws):
while True:
data = await ws.receive()
await ws.send(data)

View File

@@ -0,0 +1,20 @@
from microdot_asyncio import Microdot, send_file
from microdot_asyncio_websocket import with_websocket
app = Microdot()
@app.route('/')
def index(request):
return send_file('index.html')
@app.route('/echo')
@with_websocket
async def echo(request, ws):
while True:
data = await ws.receive()
await ws.send(data)
app.run()

View File

@@ -0,0 +1,17 @@
from microdot_wsgi import Microdot, send_file
from microdot_websocket import with_websocket
app = Microdot()
@app.route('/')
def index(request):
return send_file('index.html')
@app.route('/echo')
@with_websocket
def echo(request, ws):
while True:
data = ws.receive()
ws.send(data)

View File

@@ -0,0 +1,35 @@
<!doctype html>
<html>
<head>
<title>Microdot WebSocket Demo</title>
</head>
<body>
<h1>Microdot WebSocket Demo</h1>
<div id="log"></div>
<br>
<form id="form">
<label for="text">Input: </label>
<input type="text" id="text" autofocus>
</form>
<script>
const log = (text, color) => {
document.getElementById('log').innerHTML += `<span style="color: ${color}">${text}</span><br>`;
};
const socket = new WebSocket('ws://' + location.host + '/echo');
socket.addEventListener('message', ev => {
log('<<< ' + ev.data, 'blue');
});
socket.addEventListener('close', ev => {
log('<<< closed');
});
document.getElementById('form').onsubmit = ev => {
ev.preventDefault();
const textField = document.getElementById('text');
log('>>> ' + textField.value, 'red');
socket.send(textField.value);
textField.value = '';
};
</script>
</body>
</html>

View File

@@ -28,7 +28,11 @@ py_modules =
microdot_utemplate
microdot_jinja
microdot_session
microdot_websocket
microdot_websocket_alt
microdot_asyncio_websocket
microdot_test_client
microdot_asyncio_test_client
microdot_wsgi
microdot_asgi
microdot_asgi_websocket

View File

@@ -51,6 +51,13 @@ except ImportError:
except ImportError: # pragma: no cover
socket = None
MUTED_SOCKET_ERRORS = [
32, # Broken pipe
54, # Connection reset by peer
104, # Connection reset by peer
128, # Operation on closed socket
]
def urldecode(string):
string = string.replace('+', ' ')
@@ -194,7 +201,7 @@ class Request():
pass
def __init__(self, app, client_addr, method, url, http_version, headers,
body=None, stream=None):
body=None, stream=None, sock=None):
#: The application instance to which this request belongs.
self.app = app
#: The address of the client, as a tuple (host, port).
@@ -240,18 +247,21 @@ class Request():
self.body_used = False
self._stream = stream
self.stream_used = False
self.sock = sock
self._json = None
self._form = None
self.after_request_handlers = []
@staticmethod
def create(app, client_stream, client_addr):
def create(app, client_stream, client_addr, client_sock=None):
"""Create a request object.
:param app: The Microdot application instance.
:param client_stream: An input stream from where the request data can
be read.
:param client_addr: The address of the client, as a tuple.
:param client_sock: The low-level socket associated with the request.
This method returns a newly created ``Request`` object.
"""
@@ -273,7 +283,7 @@ class Request():
headers[header] = value
return Request(app, client_addr, method, url, http_version, headers,
stream=client_stream)
stream=client_stream, sock=client_sock)
def _parse_urlencoded(self, urlencoded):
data = MultiDict()
@@ -396,6 +406,10 @@ class Response():
#: ``Content-Type`` header.
default_content_type = 'text/plain'
#: Special response used to signal that a response does not need to be
#: written to the client. Used to exit WebSocket connections cleanly.
already_handled = None
def __init__(self, body='', status_code=200, headers=None, reason=None):
if body is None and status_code == 200:
body = ''
@@ -482,7 +496,7 @@ class Response():
if can_flush: # pragma: no cover
stream.flush()
except OSError as exc: # pragma: no cover
if exc.errno == 32: # errno.EPIPE
if exc.errno in MUTED_SOCKET_ERRORS:
pass
else:
raise
@@ -935,15 +949,16 @@ class Microdot():
req = None
try:
req = Request.create(self, stream, addr)
req = Request.create(self, stream, addr, sock)
except Exception as exc: # pragma: no cover
print_exception(exc)
res = self.dispatch_request(req)
res.write(stream)
if res != Response.already_handled: # pragma: no branch
res.write(stream)
try:
stream.close()
except OSError as exc: # pragma: no cover
if exc.errno == 32: # errno.EPIPE
if exc.errno in MUTED_SOCKET_ERRORS:
pass
else:
raise
@@ -1026,5 +1041,6 @@ class Microdot():
abort = Microdot.abort
Response.already_handled = Response()
redirect = Response.redirect
send_file = Response.send_file

View File

@@ -50,7 +50,7 @@ class Microdot(BaseMicrodot):
async def asgi_app(self, scope, receive, send):
"""An ASGI application."""
if scope['type'] != 'http': # pragma: no cover
if scope['type'] not in ['http', 'websocket']: # pragma: no cover
return
path = scope['path']
if 'query_string' in scope and scope['query_string']:
@@ -62,7 +62,6 @@ class Microdot(BaseMicrodot):
if key.lower() == 'content-length':
content_length = int(value)
body = b''
if content_length and content_length <= Request.max_body_length:
body = b''
more = True
@@ -78,12 +77,13 @@ class Microdot(BaseMicrodot):
req = Request(
self,
(scope['client'][0], scope['client'][1]),
scope['method'],
scope.get('method', 'GET'),
path,
'HTTP/' + scope['http_version'],
headers,
body=body,
stream=stream)
stream=stream,
sock=(receive, send))
req.asgi_scope = scope
res = await self.dispatch_request(req)
@@ -97,6 +97,9 @@ class Microdot(BaseMicrodot):
for v in value:
header_list.append((name, v))
if scope['type'] != 'http': # pragma: no cover
return
await send({'type': 'http.response.start',
'status': res.status_code,
'headers': header_list})

View File

@@ -0,0 +1,86 @@
from microdot_asyncio import Response, abort
from microdot_websocket import WebSocket as BaseWebSocket
class WebSocket(BaseWebSocket):
async def handshake(self):
connect = await self.request.sock[0]()
if connect['type'] != 'websocket.connect':
abort(400)
await self.request.sock[1]({'type': 'websocket.accept'})
async def receive(self):
message = await self.request.sock[0]()
if message['type'] == 'websocket.disconnect':
raise OSError(32, 'Websocket connection closed')
elif message['type'] != 'websocket.receive':
raise OSError(32, 'Websocket message type not supported')
return message.get('bytes', message.get('text'))
async def send(self, data):
if isinstance(data, str):
await self.request.sock[1](
{'type': 'websocket.send', 'text': data})
else:
await self.request.sock[1](
{'type': 'websocket.send', 'bytes': data})
async def close(self):
if not self.closed:
self.closed = True
try:
await self.request.sock[1]({'type': 'websocket.close'})
except: # noqa E722
pass
async def websocket_upgrade(request):
"""Upgrade a request handler to a websocket connection.
This function can be called directly inside a route function to process a
WebSocket upgrade handshake, for example after the user's credentials are
verified. The function returns the websocket object::
@app.route('/echo')
async def echo(request):
if not (await authenticate_user(request)):
abort(401)
ws = await websocket_upgrade(request)
while True:
message = await ws.receive()
await ws.send(message)
"""
ws = WebSocket(request)
await ws.handshake()
@request.after_request
async def after_request(request, response):
return Response.already_handled
return ws
def with_websocket(f):
"""Decorator to make a route a WebSocket endpoint.
This decorator is used to define a route that accepts websocket
connections. The route then receives a websocket object as a second
argument that it can use to send and receive messages::
@app.route('/echo')
@with_websocket
async def echo(request, ws):
while True:
message = await ws.receive()
await ws.send(message)
"""
async def wrapper(request, *args, **kwargs):
ws = await websocket_upgrade(request)
try:
await f(request, ws, *args, **kwargs)
except OSError as exc:
if exc.errno != 32 and exc.errno != 54:
raise
await ws.close()
return ''
return wrapper

View File

@@ -21,6 +21,7 @@ from microdot import print_exception
from microdot import Request as BaseRequest
from microdot import Response as BaseResponse
from microdot import HTTPException
from microdot import MUTED_SOCKET_ERRORS
def _iscoroutine(coro):
@@ -43,22 +44,30 @@ class _AsyncBytesIO:
async def readuntil(self, separator=b'\n'): # pragma: no cover
return self.stream.readuntil(separator=separator)
async def awrite(self, data): # pragma: no cover
return self.stream.write(data)
async def aclose(self): # pragma: no cover
pass
class Request(BaseRequest):
@staticmethod
async def create(app, client_stream, client_addr):
async def create(app, client_reader, client_writer, client_addr):
"""Create a request object.
:param app: The Microdot application instance.
:param client_stream: An input stream from where the request data can
:param client_reader: An input stream from where the request data can
be read.
:param client_writer: An output stream where the response data can be
written.
:param client_addr: The address of the client, as a tuple.
This method is a coroutine. It returns a newly created ``Request``
object.
"""
# request line
line = (await Request._safe_readline(client_stream)).strip().decode()
line = (await Request._safe_readline(client_reader)).strip().decode()
if not line:
return None
method, url, http_version = line.split()
@@ -69,7 +78,7 @@ class Request(BaseRequest):
content_length = 0
while True:
line = (await Request._safe_readline(
client_stream)).strip().decode()
client_reader)).strip().decode()
if line == '':
break
header, value = line.split(':', 1)
@@ -81,14 +90,15 @@ class Request(BaseRequest):
# body
body = b''
if content_length and content_length <= Request.max_body_length:
body = await client_stream.readexactly(content_length)
body = await client_reader.readexactly(content_length)
stream = None
else:
body = b''
stream = client_stream
stream = client_reader
return Request(app, client_addr, method, url, http_version, headers,
body=body, stream=stream)
body=body, stream=stream,
sock=(client_reader, client_writer))
@property
def stream(self):
@@ -119,31 +129,33 @@ class Response(BaseResponse):
default is "OK" for responses with a 200 status code and
"N/A" for any other status codes.
"""
async def write(self, stream):
self.complete()
# status code
reason = self.reason if self.reason is not None else \
('OK' if self.status_code == 200 else 'N/A')
await stream.awrite('HTTP/1.0 {status_code} {reason}\r\n'.format(
status_code=self.status_code, reason=reason).encode())
# headers
for header, value in self.headers.items():
values = value if isinstance(value, list) else [value]
for value in values:
await stream.awrite('{header}: {value}\r\n'.format(
header=header, value=value).encode())
await stream.awrite(b'\r\n')
# body
try:
# status code
reason = self.reason if self.reason is not None else \
('OK' if self.status_code == 200 else 'N/A')
await stream.awrite('HTTP/1.0 {status_code} {reason}\r\n'.format(
status_code=self.status_code, reason=reason).encode())
# headers
for header, value in self.headers.items():
values = value if isinstance(value, list) else [value]
for value in values:
await stream.awrite('{header}: {value}\r\n'.format(
header=header, value=value).encode())
await stream.awrite(b'\r\n')
# body
async for body in self.body_iter():
if isinstance(body, str): # pragma: no cover
body = body.encode()
await stream.awrite(body)
except OSError as exc: # pragma: no cover
if exc.errno == 32 or exc.args[0] == 'Connection lost':
if exc.errno in MUTED_SOCKET_ERRORS or \
exc.args[0] == 'Connection lost':
pass
else:
raise
@@ -301,17 +313,18 @@ class Microdot(BaseMicrodot):
async def handle_request(self, reader, writer):
req = None
try:
req = await Request.create(self, reader,
req = await Request.create(self, reader, writer,
writer.get_extra_info('peername'))
except Exception as exc: # pragma: no cover
print_exception(exc)
res = await self.dispatch_request(req)
await res.write(writer)
if res != Response.already_handled: # pragma: no branch
await res.write(writer)
try:
await writer.aclose()
except OSError as exc: # pragma: no cover
if exc.errno == 32: # errno.EPIPE
if exc.errno in MUTED_SOCKET_ERRORS:
pass
else:
raise
@@ -401,5 +414,6 @@ class Microdot(BaseMicrodot):
abort = Microdot.abort
Response.already_handled = Response()
redirect = Response.redirect
send_file = Response.send_file

View File

@@ -1,6 +1,10 @@
from microdot_asyncio import Request, _AsyncBytesIO
from microdot_asyncio import Request, Response, _AsyncBytesIO
from microdot_test_client import TestClient as BaseTestClient, \
TestResponse as BaseTestResponse
try:
from microdot_asyncio_websocket import WebSocket
except: # pragma: no cover # noqa: E722
WebSocket = None
class TestResponse(BaseTestResponse):
@@ -47,15 +51,24 @@ class TestClient(BaseTestClient):
assert res.status_code == 200
assert res.text == 'Hello, World!'
"""
async def request(self, method, path, headers=None, body=None):
async def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {}
body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers)
request_bytes = self._render_request(method, path, headers, body)
if sock:
reader = sock[0]
reader.buffer = request_bytes
writer = sock[1]
else:
reader = _AsyncBytesIO(request_bytes)
writer = _AsyncBytesIO(b'')
req = await Request.create(self.app, _AsyncBytesIO(request_bytes),
req = await Request.create(self.app, reader, writer,
('127.0.0.1', 1234))
res = await self.app.dispatch_request(req)
if res == Response.already_handled:
return None
res.complete()
self._update_cookies(res)
@@ -124,3 +137,72 @@ class TestClient(BaseTestClient):
:class:`TestResponse <microdot_test_client.TestResponse>` object.
"""
return await self.request('DELETE', path, headers=headers)
async def websocket(self, path, client, headers=None):
"""Send a websocket connection request to the application.
:param path: The request URL.
:param client: A generator function that yields client messages.
:param headers: A dictionary of headers to send with the request.
"""
gen = client()
class FakeWebSocket:
def __init__(self):
self.started = False
self.closed = False
self.buffer = b''
async def _next(self, data=None):
try:
data = (await gen.asend(data)) if hasattr(gen, 'asend') \
else gen.send(data)
except (StopIteration, StopAsyncIteration):
if not self.closed:
self.closed = True
raise OSError(32, 'Websocket connection closed')
return # pragma: no cover
opcode = WebSocket.TEXT if isinstance(data, str) \
else WebSocket.BINARY
return WebSocket._encode_websocket_frame(opcode, data)
async def read(self, n):
if not self.buffer:
self.started = True
self.buffer = await self._next()
data = self.buffer[:n]
self.buffer = self.buffer[n:]
return data
async def readexactly(self, n): # pragma: no cover
return await self.read(n)
async def readline(self):
line = b''
while True:
line += await self.read(1)
if line[-1] in [b'\n', 10]:
break
return line
async def awrite(self, data):
if self.started:
h = WebSocket._parse_frame_header(data[0:2])
if h[3] < 0:
data = data[2 - h[3]:]
else:
data = data[2:]
if h[1] == WebSocket.TEXT:
data = data.decode()
self.buffer = await self._next(data)
ws_headers = {
'Upgrade': 'websocket',
'Connection': 'Upgrade',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
}
ws_headers.update(headers or {})
sock = FakeWebSocket()
return await self.request('GET', path, headers=ws_headers,
sock=(sock, sock))

View File

@@ -0,0 +1,103 @@
from microdot_asyncio import Response
from microdot_websocket import WebSocket as BaseWebSocket
class WebSocket(BaseWebSocket):
async def handshake(self):
response = self._handshake_response()
await self.request.sock[1].awrite(
b'HTTP/1.1 101 Switching Protocols\r\n')
await self.request.sock[1].awrite(b'Upgrade: websocket\r\n')
await self.request.sock[1].awrite(b'Connection: Upgrade\r\n')
await self.request.sock[1].awrite(
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
async def receive(self):
while True:
opcode, payload = await self._read_frame()
send_opcode, data = self._process_websocket_frame(opcode, payload)
if send_opcode: # pragma: no cover
await self.send(send_opcode, data)
elif data: # pragma: no branch
return data
async def send(self, data, opcode=None):
frame = self._encode_websocket_frame(
opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
data)
await self.request.sock[1].awrite(frame)
async def close(self):
if not self.closed: # pragma: no cover
self.closed = True
await self.send(b'', self.CLOSE)
async def _read_frame(self):
header = await self.request.sock[0].read(2)
if len(header) != 2: # pragma: no cover
raise OSError(32, 'Websocket connection closed')
fin, opcode, has_mask, length = self._parse_frame_header(header)
if length == -2:
length = await self.request.sock[0].read(2)
length = int.from_bytes(length, 'big')
elif length == -8:
length = await self.request.sock[0].read(8)
length = int.from_bytes(length, 'big')
if has_mask: # pragma: no cover
mask = await self.request.sock[0].read(4)
payload = await self.request.sock[0].read(length)
if has_mask: # pragma: no cover
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
return opcode, payload
async def websocket_upgrade(request):
"""Upgrade a request handler to a websocket connection.
This function can be called directly inside a route function to process a
WebSocket upgrade handshake, for example after the user's credentials are
verified. The function returns the websocket object::
@app.route('/echo')
async def echo(request):
if not authenticate_user(request):
abort(401)
ws = await websocket_upgrade(request)
while True:
message = await ws.receive()
await ws.send(message)
"""
ws = WebSocket(request)
await ws.handshake()
@request.after_request
async def after_request(request, response):
return Response.already_handled
return ws
def with_websocket(f):
"""Decorator to make a route a WebSocket endpoint.
This decorator is used to define a route that accepts websocket
connections. The route then receives a websocket object as a second
argument that it can use to send and receive messages::
@app.route('/echo')
@with_websocket
async def echo(request, ws):
while True:
message = await ws.receive()
await ws.send(message)
"""
async def wrapper(request, *args, **kwargs):
ws = await websocket_upgrade(request)
try:
await f(request, ws, *args, **kwargs)
await ws.close() # pragma: no cover
except OSError as exc:
if exc.errno not in [32, 54, 104]: # pragma: no cover
raise
return ''
return wrapper

View File

@@ -1,6 +1,10 @@
from io import BytesIO
import json
from microdot import Request
from microdot import Request, Response
try:
from microdot_websocket import WebSocket
except: # pragma: no cover # noqa: E722
WebSocket = None
class TestResponse:
@@ -82,6 +86,8 @@ class TestClient:
assert res.status_code == 200
assert res.text == 'Hello, World!'
"""
__test__ = False # remove this class from pytest's test collection
def __init__(self, app, cookies=None):
self.app = app
self.cookies = cookies or {}
@@ -147,15 +153,17 @@ class TestClient:
else:
self.cookies[cookie_name] = cookie_options[0]
def request(self, method, path, headers=None, body=None):
def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {}
body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers)
request_bytes = self._render_request(method, path, headers, body)
req = Request.create(self.app, BytesIO(request_bytes),
('127.0.0.1', 1234))
('127.0.0.1', 1234), client_sock=sock)
res = self.app.dispatch_request(req)
if res == Response.already_handled:
return None
res.complete()
self._update_cookies(res)
@@ -224,3 +232,59 @@ class TestClient:
:class:`TestResponse <microdot_test_client.TestResponse>` object.
"""
return self.request('DELETE', path, headers=headers)
def websocket(self, path, client, headers=None):
"""Send a websocket connection request to the application.
:param path: The request URL.
:param client: A generator function that yields client messages.
:param headers: A dictionary of headers to send with the request.
"""
gen = client()
class FakeWebSocket:
def __init__(self):
self.started = False
self.closed = False
self.buffer = b''
def _next(self, data=None):
try:
data = gen.send(data)
except StopIteration:
if self.closed: # pragma: no cover
return
self.closed = True
raise OSError(32, 'Websocket connection closed')
opcode = WebSocket.TEXT if isinstance(data, str) \
else WebSocket.BINARY
return WebSocket._encode_websocket_frame(opcode, data)
def recv(self, n):
self.started = True
if not self.buffer:
self.buffer = self._next()
data = self.buffer[:n]
self.buffer = self.buffer[n:]
return data
def send(self, data):
if self.started:
h = WebSocket._parse_frame_header(data[0:2])
if h[3] < 0:
data = data[2 - h[3]:]
else:
data = data[2:]
if h[1] == WebSocket.TEXT:
data = data.decode()
self.buffer = self._next(data)
ws_headers = {
'Upgrade': 'websocket',
'Connection': 'Upgrade',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
}
ws_headers.update(headers or {})
return self.request('GET', path, headers=ws_headers,
sock=FakeWebSocket())

177
src/microdot_websocket.py Normal file
View File

@@ -0,0 +1,177 @@
import binascii
import hashlib
from microdot import Response
class WebSocket:
CONT = 0
TEXT = 1
BINARY = 2
CLOSE = 8
PING = 9
PONG = 10
def __init__(self, request):
self.request = request
self.closed = False
def handshake(self):
response = self._handshake_response()
self.request.sock.send(b'HTTP/1.1 101 Switching Protocols\r\n')
self.request.sock.send(b'Upgrade: websocket\r\n')
self.request.sock.send(b'Connection: Upgrade\r\n')
self.request.sock.send(
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
def receive(self):
while True:
opcode, payload = self._read_frame()
send_opcode, data = self._process_websocket_frame(opcode, payload)
if send_opcode: # pragma: no cover
self.send(send_opcode, data)
elif data: # pragma: no branch
return data
def send(self, data, opcode=None):
frame = self._encode_websocket_frame(
opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
data)
self.request.sock.send(frame)
def close(self):
if not self.closed: # pragma: no cover
self.closed = True
self.send(b'', self.CLOSE)
def _handshake_response(self):
connection = False
upgrade = False
websocket_key = None
for header, value in self.request.headers.items():
h = header.lower()
if h == 'connection':
connection = True
if 'upgrade' not in value.lower():
return self.request.app.abort(400)
elif h == 'upgrade':
upgrade = True
if not value.lower() == 'websocket':
return self.request.app.abort(400)
elif h == 'sec-websocket-key':
websocket_key = value
if not connection or not upgrade or not websocket_key:
return self.request.app.abort(400)
d = hashlib.sha1(websocket_key.encode())
d.update(b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11')
return binascii.b2a_base64(d.digest())[:-1]
@classmethod
def _parse_frame_header(cls, header):
fin = header[0] & 0x80
opcode = header[0] & 0x0f
if fin == 0 or opcode == cls.CONT: # pragma: no cover
raise OSError(32, 'Continuation frames not supported')
has_mask = header[1] & 0x80
length = header[1] & 0x7f
if length == 126:
length = -2
elif length == 127:
length = -8
return fin, opcode, has_mask, length
def _process_websocket_frame(self, opcode, payload):
if opcode == self.TEXT:
payload = payload.decode()
elif opcode == self.BINARY:
pass
elif opcode == self.CLOSE:
raise OSError(32, 'Websocket connection closed')
elif opcode == self.PING:
return self.PONG, payload
elif opcode == self.PONG: # pragma: no branch
return None, None
return None, payload
@classmethod
def _encode_websocket_frame(cls, opcode, payload):
frame = bytearray()
frame.append(0x80 | opcode)
if opcode == cls.TEXT:
payload = payload.encode()
if len(payload) < 126:
frame.append(len(payload))
elif len(payload) < (1 << 16):
frame.append(126)
frame.extend(len(payload).to_bytes(2, 'big'))
else:
frame.append(127)
frame.extend(len(payload).to_bytes(8, 'big'))
frame.extend(payload)
return frame
def _read_frame(self):
header = self.request.sock.recv(2)
if len(header) != 2: # pragma: no cover
raise OSError(32, 'Websocket connection closed')
fin, opcode, has_mask, length = self._parse_frame_header(header)
if length < 0:
length = self.request.sock.recv(-length)
length = int.from_bytes(length, 'big')
if has_mask: # pragma: no cover
mask = self.request.sock.recv(4)
payload = self.request.sock.recv(length)
if has_mask: # pragma: no cover
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
return opcode, payload
def websocket_upgrade(request):
"""Upgrade a request handler to a websocket connection.
This function can be called directly inside a route function to process a
WebSocket upgrade handshake, for example after the user's credentials are
verified. The function returns the websocket object::
@app.route('/echo')
def echo(request):
if not authenticate_user(request):
abort(401)
ws = websocket_upgrade(request)
while True:
message = ws.receive()
ws.send(message)
"""
ws = WebSocket(request)
ws.handshake()
@request.after_request
def after_request(request, response):
return Response.already_handled
return ws
def with_websocket(f):
"""Decorator to make a route a WebSocket endpoint.
This decorator is used to define a route that accepts websocket
connections. The route then receives a websocket object as a second
argument that it can use to send and receive messages::
@app.route('/echo')
@with_websocket
def echo(request, ws):
while True:
message = ws.receive()
ws.send(message)
"""
def wrapper(request, *args, **kwargs):
ws = websocket_upgrade(request)
try:
f(request, ws, *args, **kwargs)
ws.close() # pragma: no cover
except OSError as exc:
if exc.errno not in [32, 54, 104]: # pragma: no cover
raise
return ''
return wrapper

View File

@@ -0,0 +1,114 @@
import binascii
import hashlib
import select
import websocket as _websocket
from microdot import Response
class WebSocket:
CONT = 0
TEXT = 1
BINARY = 2
CLOSE = 8
PING = 9
PONG = 10
def __init__(self, request):
self.request = request
self.poll = select.poll()
self.poll.register(self.request.sock, select.POLLIN)
self.ws = _websocket.websocket(self.request.sock, True)
self.request.sock.setblocking(False)
def handshake(self):
response = self._handshake_response()
self.request.sock.send(b'HTTP/1.1 101 Switching Protocols\r\n')
self.request.sock.send(b'Upgrade: websocket\r\n')
self.request.sock.send(b'Connection: Upgrade\r\n')
self.request.sock.send(
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
def receive(self):
while True:
self.poll.poll()
data = self.ws.read()
if data:
try:
data = data.decode()
except ValueError:
pass
return data
def send(self, data):
self.ws.write(data)
def close(self):
self.poll.unregister(self.request.sock)
self.ws.close()
def _handshake_response(self):
for header, value in self.request.headers.items():
h = header.lower()
if h == 'connection' and not value.lower().startswith('upgrade'):
return self.request.app.abort(400)
elif h == 'upgrade' and not value.lower() == 'websocket':
return self.request.app.abort(400)
elif h == 'sec-websocket-key':
websocket_key = value
if not websocket_key:
return self.request.app.abort(400)
d = hashlib.sha1(websocket_key.encode())
d.update(b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11')
return binascii.b2a_base64(d.digest())[:-1]
def websocket_upgrade(request):
"""Upgrade a request handler to a websocket connection.
This function can be called directly inside a route function to process a
WebSocket upgrade handshake, for example after the user's credentials are
verified. The function returns the websocket object::
@app.route('/echo')
def echo(request):
if not authenticate_user(request):
abort(401)
ws = websocket_upgrade(request)
while True:
message = ws.receive()
ws.send(message)
"""
ws = WebSocket(request)
ws.handshake()
@request.after_request
def after_request(request, response):
return Response.already_handled
return ws
def with_websocket(f):
"""Decorator to make a route a WebSocket endpoint.
This decorator is used to define a route that accepts websocket
connections. The route then receives a websocket object as a second
argument that it can use to send and receive messages::
@app.route('/echo')
@with_websocket
def echo(request, ws):
while True:
message = ws.receive()
ws.send(message)
"""
def wrapper(request, *args, **kwargs):
ws = websocket_upgrade(request)
try:
f(request, ws, *args, **kwargs)
except OSError as exc:
if exc.errno != 32 and exc.errno != 54:
raise
ws.close()
return ''
return wrapper

View File

@@ -27,7 +27,8 @@ class Microdot(BaseMicrodot):
path,
environ['SERVER_PROTOCOL'],
headers,
stream=environ['wsgi.input'])
stream=environ['wsgi.input'],
sock=environ.get('gunicorn.socket'))
req.environ = environ
res = self.dispatch_request(req)

View File

@@ -3,11 +3,12 @@ from .test_request import TestRequest
from .test_response import TestResponse
from .test_url_pattern import TestURLPattern
from .test_microdot import TestMicrodot
from .test_microdot_websocket import TestMicrodotWebSocket
from .test_request_asyncio import TestRequestAsync
from .test_response_asyncio import TestResponseAsync
from .test_microdot_asyncio import TestMicrodotAsync
from .test_microdot_asyncio_websocket import TestMicrodotAsyncWebSocket
from .test_utemplate import TestUTemplate
from .test_session import TestSession

View File

@@ -19,7 +19,7 @@ def _run(coro):
@unittest.skipIf(sys.implementation.name == 'micropython',
'not supported under MicroPython')
class TestUTemplate(unittest.TestCase):
class TestJinja(unittest.TestCase):
def test_render_template(self):
s = render_template('hello.jinja.txt', name='foo')
self.assertEqual(s, 'Hello, foo!')
@@ -44,7 +44,7 @@ class TestUTemplate(unittest.TestCase):
return render_template('hello.jinja.txt', name='foo')
req = _run(RequestAsync.create(
app, get_async_request_fd('GET', '/'), 'addr'))
app, get_async_request_fd('GET', '/'), 'writer', 'addr'))
res = _run(app.dispatch_request(req))
self.assertEqual(res.status_code, 200)

View File

@@ -526,6 +526,17 @@ class TestMicrodot(unittest.TestCase):
self.assertEqual(res.headers['Content-Type'], 'text/plain')
self.assertEqual(res.text, 'foobar')
def test_already_handled_response(self):
app = Microdot()
@app.route('/')
def index(req):
return Response.already_handled
client = TestClient(app)
res = client.get('/')
self.assertEqual(res, None)
def test_mount(self):
subapp = Microdot()

View File

@@ -570,3 +570,14 @@ class TestMicrodotAsync(unittest.TestCase):
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Content-Type'], 'text/plain')
self.assertEqual(res.text, 'foobar')
def test_already_handled_response(self):
app = Microdot()
@app.route('/')
def index(req):
return Response.already_handled
client = TestClient(app)
res = self._run(client.get('/'))
self.assertEqual(res, None)

View File

@@ -0,0 +1,71 @@
import sys
try:
import uasyncio as asyncio
except ImportError:
import asyncio
import unittest
from microdot_asyncio import Microdot
from microdot_asyncio_websocket import with_websocket
from microdot_asyncio_test_client import TestClient
class TestMicrodotAsyncWebSocket(unittest.TestCase):
def _run(self, coro):
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)
def test_websocket_echo(self):
app = Microdot()
@app.route('/echo')
@with_websocket
async def index(req, ws):
while True:
data = await ws.receive()
await ws.send(data)
results = []
def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
results.append(data)
data = yield b'*' * 300
results.append(data)
data = yield b'+' * 65537
results.append(data)
client = TestClient(app)
res = self._run(client.websocket('/echo', ws))
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
@unittest.skipIf(sys.implementation.name == 'micropython',
'no support for async generators in MicroPython')
def test_websocket_echo_async_client(self):
app = Microdot()
@app.route('/echo')
@with_websocket
async def index(req, ws):
while True:
data = await ws.receive()
await ws.send(data)
results = []
async def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
results.append(data)
data = yield b'*' * 300
results.append(data)
data = yield b'+' * 65537
results.append(data)
client = TestClient(app)
res = self._run(client.websocket('/echo', ws))
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])

View File

@@ -0,0 +1,73 @@
import unittest
from microdot import Microdot
from microdot_websocket import with_websocket, WebSocket
from microdot_test_client import TestClient
class TestMicrodotWebSocket(unittest.TestCase):
def test_websocket_echo(self):
app = Microdot()
@app.route('/echo')
@with_websocket
def index(req, ws):
while True:
data = ws.receive()
ws.send(data)
results = []
def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
results.append(data)
data = yield b'*' * 300
results.append(data)
data = yield b'+' * 65537
results.append(data)
client = TestClient(app)
res = client.websocket('/echo', ws)
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
def test_bad_websocket_request(self):
app = Microdot()
@app.route('/echo')
@with_websocket
def index(req, ws):
return 'hello'
client = TestClient(app)
res = client.get('/echo')
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Connection': 'Upgrade'})
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Connection': 'foo'})
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Upgrade': 'websocket'})
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Upgrade': 'bar'})
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Connection': 'Upgrade',
'Upgrade': 'websocket'})
self.assertEqual(res.status_code, 400)
res = client.get('/echo', headers={'Sec-WebSocket-Key': 'xxx'})
self.assertEqual(res.status_code, 400)
def test_process_websocket_frame(self):
ws = WebSocket(None)
ws.closed = True
self.assertEqual(ws._process_websocket_frame(WebSocket.TEXT, b'foo'),
(None, 'foo'))
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
(None, b'foo'))
self.assertRaises(OSError, ws._process_websocket_frame,
WebSocket.CLOSE, b'')
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
(WebSocket.PONG, b'foo'))
self.assertEqual(ws._process_websocket_frame(WebSocket.PONG, b'foo'),
(None, None))

View File

@@ -16,7 +16,7 @@ def _run(coro):
class TestRequestAsync(unittest.TestCase):
def test_create_request(self):
fd = get_async_request_fd('GET', '/foo')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.app, 'app')
self.assertEqual(req.client_addr, 'addr')
self.assertEqual(req.method, 'GET')
@@ -37,7 +37,7 @@ class TestRequestAsync(unittest.TestCase):
'Content-Type': 'application/json',
'Cookie': 'foo=bar;abc=def',
'Content-Length': '3'}, body='aaa')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.headers, {
'Host': 'example.com:1234',
'Content-Type': 'application/json',
@@ -50,7 +50,7 @@ class TestRequestAsync(unittest.TestCase):
def test_args(self):
fd = get_async_request_fd('GET', '/?foo=bar&abc=def&x=%2f%%')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.query_string, 'foo=bar&abc=def&x=%2f%%')
self.assertEqual(req.args, MultiDict(
{'foo': 'bar', 'abc': 'def', 'x': '/%%'}))
@@ -58,26 +58,26 @@ class TestRequestAsync(unittest.TestCase):
def test_json(self):
fd = get_async_request_fd('GET', '/foo', headers={
'Content-Type': 'application/json'}, body='{"foo":"bar"}')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
json = req.json
self.assertEqual(json, {'foo': 'bar'})
self.assertTrue(req.json is json)
fd = get_async_request_fd('GET', '/foo', headers={
'Content-Type': 'application/json'}, body='[1, "2"]')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.json, [1, '2'])
fd = get_async_request_fd('GET', '/foo', headers={
'Content-Type': 'application/xml'}, body='[1, "2"]')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertIsNone(req.json)
def test_form(self):
fd = get_async_request_fd('GET', '/foo', headers={
'Content-Type': 'application/x-www-form-urlencoded'},
body='foo=bar&abc=def&x=%2f%%')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
form = req.form
self.assertEqual(form, MultiDict(
{'foo': 'bar', 'abc': 'def', 'x': '/%%'}))
@@ -86,7 +86,7 @@ class TestRequestAsync(unittest.TestCase):
fd = get_async_request_fd('GET', '/foo', headers={
'Content-Type': 'application/json'},
body='foo=bar&abc=def&x=%2f%%')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertIsNone(req.form)
def test_large_line(self):
@@ -97,7 +97,7 @@ class TestRequestAsync(unittest.TestCase):
'Content-Type': 'application/x-www-form-urlencoded'},
body='foo=bar&abc=def&x=y')
with self.assertRaises(ValueError):
_run(Request.create('app', fd, 'addr'))
_run(Request.create('app', fd, 'writer', 'addr'))
Request.max_readline = saved_max_readline
@@ -106,7 +106,7 @@ class TestRequestAsync(unittest.TestCase):
'Content-Type': 'application/x-www-form-urlencoded',
'Content-Length': '19'},
body='foo=bar&abc=def&x=y')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.body, b'foo=bar&abc=def&x=y')
data = _run(req.stream.read())
self.assertEqual(data, b'foo=bar&abc=def&x=y')
@@ -121,7 +121,7 @@ class TestRequestAsync(unittest.TestCase):
'Content-Type': 'application/x-www-form-urlencoded',
'Content-Length': '19'},
body='foo=bar&abc=def&x=y')
req = _run(Request.create('app', fd, 'addr'))
req = _run(Request.create('app', fd, 'writer', 'addr'))
self.assertEqual(req.body, b'')
data = _run(req.stream.read())
self.assertEqual(data, b'foo=bar&abc=def&x=y')

View File

@@ -41,7 +41,7 @@ class TestUTemplate(unittest.TestCase):
return render_template('hello.utemplate.txt', name='foo')
req = _run(RequestAsync.create(
app, get_async_request_fd('GET', '/'), 'addr'))
app, get_async_request_fd('GET', '/'), 'writer', 'addr'))
res = _run(app.dispatch_request(req))
self.assertEqual(res.status_code, 200)