Make WebSocket internals consistent between TLS and non-TLS (Fixes #61)
This commit is contained in:
23
examples/tls/echo_async_tls.py
Normal file
23
examples/tls/echo_async_tls.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import ssl
|
||||||
|
from microdot_asyncio import Microdot, send_file
|
||||||
|
from microdot_asyncio_websocket import with_websocket
|
||||||
|
|
||||||
|
app = Microdot()
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def index(request):
|
||||||
|
return send_file('index.html')
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/echo')
|
||||||
|
@with_websocket
|
||||||
|
async def echo(request, ws):
|
||||||
|
while True:
|
||||||
|
data = await ws.receive()
|
||||||
|
await ws.send(data)
|
||||||
|
|
||||||
|
|
||||||
|
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
sslctx.load_cert_chain('cert.pem', 'key.pem')
|
||||||
|
app.run(port=4443, debug=True, ssl=sslctx)
|
||||||
@@ -260,7 +260,7 @@ class TestClient:
|
|||||||
else WebSocket.BINARY
|
else WebSocket.BINARY
|
||||||
return WebSocket._encode_websocket_frame(opcode, data)
|
return WebSocket._encode_websocket_frame(opcode, data)
|
||||||
|
|
||||||
def read(self, n):
|
def recv(self, n):
|
||||||
self.started = True
|
self.started = True
|
||||||
if not self.buffer:
|
if not self.buffer:
|
||||||
self.buffer = self._next()
|
self.buffer = self._next()
|
||||||
@@ -268,7 +268,7 @@ class TestClient:
|
|||||||
self.buffer = self.buffer[n:]
|
self.buffer = self.buffer[n:]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def write(self, data):
|
def send(self, data):
|
||||||
if self.started:
|
if self.started:
|
||||||
h = WebSocket._parse_frame_header(data[0:2])
|
h = WebSocket._parse_frame_header(data[0:2])
|
||||||
if h[3] < 0:
|
if h[3] < 0:
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ class WebSocket:
|
|||||||
|
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
response = self._handshake_response()
|
response = self._handshake_response()
|
||||||
self.request.sock.write(b'HTTP/1.1 101 Switching Protocols\r\n')
|
self.request.sock.send(b'HTTP/1.1 101 Switching Protocols\r\n')
|
||||||
self.request.sock.write(b'Upgrade: websocket\r\n')
|
self.request.sock.send(b'Upgrade: websocket\r\n')
|
||||||
self.request.sock.write(b'Connection: Upgrade\r\n')
|
self.request.sock.send(b'Connection: Upgrade\r\n')
|
||||||
self.request.sock.write(
|
self.request.sock.send(
|
||||||
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
|
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
|
||||||
|
|
||||||
def receive(self):
|
def receive(self):
|
||||||
@@ -36,7 +36,7 @@ class WebSocket:
|
|||||||
frame = self._encode_websocket_frame(
|
frame = self._encode_websocket_frame(
|
||||||
opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
|
opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
|
||||||
data)
|
data)
|
||||||
self.request.sock.write(frame)
|
self.request.sock.send(frame)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if not self.closed: # pragma: no cover
|
if not self.closed: # pragma: no cover
|
||||||
@@ -110,16 +110,16 @@ class WebSocket:
|
|||||||
return frame
|
return frame
|
||||||
|
|
||||||
def _read_frame(self):
|
def _read_frame(self):
|
||||||
header = self.request.sock.read(2)
|
header = self.request.sock.recv(2)
|
||||||
if len(header) != 2: # pragma: no cover
|
if len(header) != 2: # pragma: no cover
|
||||||
raise OSError(32, 'Websocket connection closed')
|
raise OSError(32, 'Websocket connection closed')
|
||||||
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
||||||
if length < 0:
|
if length < 0:
|
||||||
length = self.request.sock.read(-length)
|
length = self.request.sock.recv(-length)
|
||||||
length = int.from_bytes(length, 'big')
|
length = int.from_bytes(length, 'big')
|
||||||
if has_mask: # pragma: no cover
|
if has_mask: # pragma: no cover
|
||||||
mask = self.request.sock.read(4)
|
mask = self.request.sock.recv(4)
|
||||||
payload = self.request.sock.read(length)
|
payload = self.request.sock.recv(length)
|
||||||
if has_mask: # pragma: no cover
|
if has_mask: # pragma: no cover
|
||||||
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
|
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
|
||||||
return opcode, payload
|
return opcode, payload
|
||||||
|
|||||||
Reference in New Issue
Block a user