Support for SSE responses in the test client
This commit is contained in:
@@ -63,6 +63,8 @@ def sse_response(request, event_function, *args, **kwargs):
|
||||
async def sse_task_wrapper():
|
||||
try:
|
||||
await event_function(request, sse, *args, **kwargs)
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
pass
|
||||
except Exception as exc:
|
||||
# the SSE task raised an exception so we need to pass it to the
|
||||
# main route so that it is re-raised there
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from microdot.microdot import Request, Response, AsyncBytesIO
|
||||
|
||||
try:
|
||||
@@ -32,6 +33,11 @@ class TestResponse:
|
||||
#: The body of the JSON response, decoded to a dictionary or list. Set
|
||||
#: ``Note`` if the response does not have a JSON payload.
|
||||
self.json = None
|
||||
#: The body of the SSE response, decoded to a list of events, each
|
||||
#: given as a dictionary with a ``data`` key and optionally also
|
||||
#: ``event`` and ``id`` keys. Set to ``None`` if the response does not
|
||||
#: have an SSE payload.
|
||||
self.events = None
|
||||
|
||||
def _initialize_response(self, res):
|
||||
self.status_code = res.status_code
|
||||
@@ -41,10 +47,13 @@ class TestResponse:
|
||||
async def _initialize_body(self, res):
|
||||
self.body = b''
|
||||
iter = res.body_iter()
|
||||
async for body in iter: # pragma: no branch
|
||||
if isinstance(body, str):
|
||||
body = body.encode()
|
||||
self.body += body
|
||||
try:
|
||||
async for body in iter: # pragma: no branch
|
||||
if isinstance(body, str):
|
||||
body = body.encode()
|
||||
self.body += body
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
pass
|
||||
if hasattr(iter, 'aclose'): # pragma: no branch
|
||||
await iter.aclose()
|
||||
|
||||
@@ -60,6 +69,32 @@ class TestResponse:
|
||||
if content_type.split(';')[0] == 'application/json':
|
||||
self.json = json.loads(self.text)
|
||||
|
||||
def _process_sse_body(self):
|
||||
if 'Content-Type' in self.headers: # pragma: no branch
|
||||
content_type = self.headers['Content-Type']
|
||||
if content_type.split(';')[0] == 'text/event-stream':
|
||||
self.events = []
|
||||
for sse_event in self.body.split(b'\n\n'):
|
||||
data = None
|
||||
event = None
|
||||
event_id = None
|
||||
for line in sse_event.split(b'\n'):
|
||||
if line.startswith(b'data:'):
|
||||
data = line[5:].strip()
|
||||
elif line.startswith(b'event:'):
|
||||
event = line[6:].strip().decode()
|
||||
elif line.startswith(b'id:'):
|
||||
event_id = line[3:].strip().decode()
|
||||
if data:
|
||||
data_json = None
|
||||
try:
|
||||
data_json = json.loads(data)
|
||||
except ValueError:
|
||||
pass
|
||||
self.events.append({
|
||||
"data": data, "data_json": data_json,
|
||||
"event": event, "event_id": event_id})
|
||||
|
||||
@classmethod
|
||||
async def create(cls, res):
|
||||
test_res = cls()
|
||||
@@ -68,6 +103,7 @@ class TestResponse:
|
||||
await test_res._initialize_body(res)
|
||||
test_res._process_text_body()
|
||||
test_res._process_json_body()
|
||||
test_res._process_sse_body()
|
||||
return test_res
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,31 @@ class TestWebSocket(unittest.TestCase):
|
||||
'data: [42, "foo", "bar"]\n\n'
|
||||
'data: foo\n\n'
|
||||
'data: foo\n\n'))
|
||||
self.assertEqual(len(response.events), 8)
|
||||
self.assertEqual(response.events[0], {
|
||||
'data': b'foo', 'data_json': None, 'event': None,
|
||||
'event_id': None})
|
||||
self.assertEqual(response.events[1], {
|
||||
'data': b'bar', 'data_json': None, 'event': 'test',
|
||||
'event_id': None})
|
||||
self.assertEqual(response.events[2], {
|
||||
'data': b'bar', 'data_json': None, 'event': 'test',
|
||||
'event_id': 'id42'})
|
||||
self.assertEqual(response.events[3], {
|
||||
'data': b'bar', 'data_json': None, 'event': None,
|
||||
'event_id': 'id42'})
|
||||
self.assertEqual(response.events[4], {
|
||||
'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'},
|
||||
'event': None, 'event_id': None})
|
||||
self.assertEqual(response.events[5], {
|
||||
'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'],
|
||||
'event': None, 'event_id': None})
|
||||
self.assertEqual(response.events[6], {
|
||||
'data': b'foo', 'data_json': None, 'event': None,
|
||||
'event_id': None})
|
||||
self.assertEqual(response.events[7], {
|
||||
'data': b'foo', 'data_json': None, 'event': None,
|
||||
'event_id': None})
|
||||
|
||||
def test_sse_exception(self):
|
||||
app = Microdot()
|
||||
|
||||
Reference in New Issue
Block a user