Make WebSocket internals consistent between TLS and non-TLS (Fixes #61)
This commit is contained in:
23
examples/tls/echo_async_tls.py
Normal file
23
examples/tls/echo_async_tls.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import ssl
|
||||
from microdot_asyncio import Microdot, send_file
|
||||
from microdot_asyncio_websocket import with_websocket
|
||||
|
||||
app = Microdot()
|
||||
|
||||
|
||||
@app.route('/')
|
||||
def index(request):
|
||||
return send_file('index.html')
|
||||
|
||||
|
||||
@app.route('/echo')
|
||||
@with_websocket
|
||||
async def echo(request, ws):
|
||||
while True:
|
||||
data = await ws.receive()
|
||||
await ws.send(data)
|
||||
|
||||
|
||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
sslctx.load_cert_chain('cert.pem', 'key.pem')
|
||||
app.run(port=4443, debug=True, ssl=sslctx)
|
||||
@@ -260,7 +260,7 @@ class TestClient:
|
||||
else WebSocket.BINARY
|
||||
return WebSocket._encode_websocket_frame(opcode, data)
|
||||
|
||||
def read(self, n):
|
||||
def recv(self, n):
|
||||
self.started = True
|
||||
if not self.buffer:
|
||||
self.buffer = self._next()
|
||||
@@ -268,7 +268,7 @@ class TestClient:
|
||||
self.buffer = self.buffer[n:]
|
||||
return data
|
||||
|
||||
def write(self, data):
|
||||
def send(self, data):
|
||||
if self.started:
|
||||
h = WebSocket._parse_frame_header(data[0:2])
|
||||
if h[3] < 0:
|
||||
|
||||
@@ -17,10 +17,10 @@ class WebSocket:
|
||||
|
||||
def handshake(self):
|
||||
response = self._handshake_response()
|
||||
self.request.sock.write(b'HTTP/1.1 101 Switching Protocols\r\n')
|
||||
self.request.sock.write(b'Upgrade: websocket\r\n')
|
||||
self.request.sock.write(b'Connection: Upgrade\r\n')
|
||||
self.request.sock.write(
|
||||
self.request.sock.send(b'HTTP/1.1 101 Switching Protocols\r\n')
|
||||
self.request.sock.send(b'Upgrade: websocket\r\n')
|
||||
self.request.sock.send(b'Connection: Upgrade\r\n')
|
||||
self.request.sock.send(
|
||||
b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')
|
||||
|
||||
def receive(self):
|
||||
@@ -36,7 +36,7 @@ class WebSocket:
|
||||
frame = self._encode_websocket_frame(
|
||||
opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
|
||||
data)
|
||||
self.request.sock.write(frame)
|
||||
self.request.sock.send(frame)
|
||||
|
||||
def close(self):
|
||||
if not self.closed: # pragma: no cover
|
||||
@@ -110,16 +110,16 @@ class WebSocket:
|
||||
return frame
|
||||
|
||||
def _read_frame(self):
|
||||
header = self.request.sock.read(2)
|
||||
header = self.request.sock.recv(2)
|
||||
if len(header) != 2: # pragma: no cover
|
||||
raise OSError(32, 'Websocket connection closed')
|
||||
fin, opcode, has_mask, length = self._parse_frame_header(header)
|
||||
if length < 0:
|
||||
length = self.request.sock.read(-length)
|
||||
length = self.request.sock.recv(-length)
|
||||
length = int.from_bytes(length, 'big')
|
||||
if has_mask: # pragma: no cover
|
||||
mask = self.request.sock.read(4)
|
||||
payload = self.request.sock.read(length)
|
||||
mask = self.request.sock.recv(4)
|
||||
payload = self.request.sock.recv(length)
|
||||
if has_mask: # pragma: no cover
|
||||
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
|
||||
return opcode, payload
|
||||
|
||||
Reference in New Issue
Block a user