Make WebSocket internals consistent between TLS and non-TLS (Fixes #61)

This commit is contained in:
Miguel Grinberg
2022-09-18 11:17:57 +01:00
parent f540e04ffe
commit 5693b812ce
3 changed files with 34 additions and 11 deletions

View File

@@ -0,0 +1,23 @@
import ssl
from microdot_asyncio import Microdot, send_file
from microdot_asyncio_websocket import with_websocket
app = Microdot()
@app.route('/')
def index(request):
return send_file('index.html')
@app.route('/echo')
@with_websocket
async def echo(request, ws):
while True:
data = await ws.receive()
await ws.send(data)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain('cert.pem', 'key.pem')
app.run(port=4443, debug=True, ssl=sslctx)

View File

@@ -260,7 +260,7 @@ class TestClient:
else WebSocket.BINARY else WebSocket.BINARY
return WebSocket._encode_websocket_frame(opcode, data) return WebSocket._encode_websocket_frame(opcode, data)
def read(self, n): def recv(self, n):
self.started = True self.started = True
if not self.buffer: if not self.buffer:
self.buffer = self._next() self.buffer = self._next()
@@ -268,7 +268,7 @@ class TestClient:
self.buffer = self.buffer[n:] self.buffer = self.buffer[n:]
return data return data
def write(self, data): def send(self, data):
if self.started: if self.started:
h = WebSocket._parse_frame_header(data[0:2]) h = WebSocket._parse_frame_header(data[0:2])
if h[3] < 0: if h[3] < 0:

View File

@@ -17,10 +17,10 @@ class WebSocket:
def handshake(self): def handshake(self):
response = self._handshake_response() response = self._handshake_response()
self.request.sock.write(b'HTTP/1.1 101 Switching Protocols\r\n') self.request.sock.send(b'HTTP/1.1 101 Switching Protocols\r\n')
self.request.sock.write(b'Upgrade: websocket\r\n') self.request.sock.send(b'Upgrade: websocket\r\n')
self.request.sock.write(b'Connection: Upgrade\r\n') self.request.sock.send(b'Connection: Upgrade\r\n')
self.request.sock.write( self.request.sock.send(
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n') b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
def receive(self): def receive(self):
@@ -36,7 +36,7 @@ class WebSocket:
frame = self._encode_websocket_frame( frame = self._encode_websocket_frame(
opcode or (self.TEXT if isinstance(data, str) else self.BINARY), opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
data) data)
self.request.sock.write(frame) self.request.sock.send(frame)
def close(self): def close(self):
if not self.closed: # pragma: no cover if not self.closed: # pragma: no cover
@@ -110,16 +110,16 @@ class WebSocket:
return frame return frame
def _read_frame(self): def _read_frame(self):
header = self.request.sock.read(2) header = self.request.sock.recv(2)
if len(header) != 2: # pragma: no cover if len(header) != 2: # pragma: no cover
raise OSError(32, 'Websocket connection closed') raise OSError(32, 'Websocket connection closed')
fin, opcode, has_mask, length = self._parse_frame_header(header) fin, opcode, has_mask, length = self._parse_frame_header(header)
if length < 0: if length < 0:
length = self.request.sock.read(-length) length = self.request.sock.recv(-length)
length = int.from_bytes(length, 'big') length = int.from_bytes(length, 'big')
if has_mask: # pragma: no cover if has_mask: # pragma: no cover
mask = self.request.sock.read(4) mask = self.request.sock.recv(4)
payload = self.request.sock.read(length) payload = self.request.sock.recv(length)
if has_mask: # pragma: no cover if has_mask: # pragma: no cover
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload)) payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
return opcode, payload return opcode, payload