Support for SSE responses in the test client

This commit is contained in:
Miguel Grinberg
2025-05-18 18:26:38 +01:00
parent 654a85f46b
commit f5d3d931ed
3 changed files with 67 additions and 4 deletions

View File

@@ -63,6 +63,8 @@ def sse_response(request, event_function, *args, **kwargs):
async def sse_task_wrapper():
try:
await event_function(request, sse, *args, **kwargs)
except asyncio.CancelledError: # pragma: no cover
pass
except Exception as exc:
# the SSE task raised an exception so we need to pass it to the
# main route so that it is re-raised there

View File

@@ -1,3 +1,4 @@
import asyncio
from microdot.microdot import Request, Response, AsyncBytesIO
try:
@@ -32,6 +33,11 @@ class TestResponse:
#: The body of the JSON response, decoded to a dictionary or list. Set
#: ``Note`` if the response does not have a JSON payload.
self.json = None
#: The body of the SSE response, decoded to a list of events, each
#: given as a dictionary with a ``data`` key and optionally also
#: ``event`` and ``id`` keys. Set to ``None`` if the response does not
#: have an SSE payload.
self.events = None
def _initialize_response(self, res):
self.status_code = res.status_code
@@ -41,10 +47,13 @@ class TestResponse:
async def _initialize_body(self, res):
self.body = b''
iter = res.body_iter()
async for body in iter: # pragma: no branch
if isinstance(body, str):
body = body.encode()
self.body += body
try:
async for body in iter: # pragma: no branch
if isinstance(body, str):
body = body.encode()
self.body += body
except asyncio.CancelledError: # pragma: no cover
pass
if hasattr(iter, 'aclose'): # pragma: no branch
await iter.aclose()
@@ -60,6 +69,32 @@ class TestResponse:
if content_type.split(';')[0] == 'application/json':
self.json = json.loads(self.text)
def _process_sse_body(self):
if 'Content-Type' in self.headers: # pragma: no branch
content_type = self.headers['Content-Type']
if content_type.split(';')[0] == 'text/event-stream':
self.events = []
for sse_event in self.body.split(b'\n\n'):
data = None
event = None
event_id = None
for line in sse_event.split(b'\n'):
if line.startswith(b'data:'):
data = line[5:].strip()
elif line.startswith(b'event:'):
event = line[6:].strip().decode()
elif line.startswith(b'id:'):
event_id = line[3:].strip().decode()
if data:
data_json = None
try:
data_json = json.loads(data)
except ValueError:
pass
self.events.append({
"data": data, "data_json": data_json,
"event": event, "event_id": event_id})
@classmethod
async def create(cls, res):
test_res = cls()
@@ -68,6 +103,7 @@ class TestResponse:
await test_res._initialize_body(res)
test_res._process_text_body()
test_res._process_json_body()
test_res._process_sse_body()
return test_res

View File

@@ -42,6 +42,31 @@ class TestWebSocket(unittest.TestCase):
'data: [42, "foo", "bar"]\n\n'
'data: foo\n\n'
'data: foo\n\n'))
self.assertEqual(len(response.events), 8)
self.assertEqual(response.events[0], {
'data': b'foo', 'data_json': None, 'event': None,
'event_id': None})
self.assertEqual(response.events[1], {
'data': b'bar', 'data_json': None, 'event': 'test',
'event_id': None})
self.assertEqual(response.events[2], {
'data': b'bar', 'data_json': None, 'event': 'test',
'event_id': 'id42'})
self.assertEqual(response.events[3], {
'data': b'bar', 'data_json': None, 'event': None,
'event_id': 'id42'})
self.assertEqual(response.events[4], {
'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'},
'event': None, 'event_id': None})
self.assertEqual(response.events[5], {
'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'],
'event': None, 'event_id': None})
self.assertEqual(response.events[6], {
'data': b'foo', 'data_json': None, 'event': None,
'event_id': None})
self.assertEqual(response.events[7], {
'data': b'foo', 'data_json': None, 'event': None,
'event_id': None})
def test_sse_exception(self):
app = Microdot()