From f5d3d931edfbacedebf5fdf938ef77c5ee910380 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 18 May 2025 18:26:38 +0100 Subject: [PATCH] Support for SSE responses in the test client --- src/microdot/sse.py | 2 ++ src/microdot/test_client.py | 44 +++++++++++++++++++++++++++++++++---- tests/test_sse.py | 25 +++++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/src/microdot/sse.py b/src/microdot/sse.py index 33a88b4..6376ee0 100644 --- a/src/microdot/sse.py +++ b/src/microdot/sse.py @@ -63,6 +63,8 @@ def sse_response(request, event_function, *args, **kwargs): async def sse_task_wrapper(): try: await event_function(request, sse, *args, **kwargs) + except asyncio.CancelledError: # pragma: no cover + pass except Exception as exc: # the SSE task raised an exception so we need to pass it to the # main route so that it is re-raised there diff --git a/src/microdot/test_client.py b/src/microdot/test_client.py index 0f2d0fc..909bd55 100644 --- a/src/microdot/test_client.py +++ b/src/microdot/test_client.py @@ -1,3 +1,4 @@ +import asyncio from microdot.microdot import Request, Response, AsyncBytesIO try: @@ -32,6 +33,11 @@ class TestResponse: #: The body of the JSON response, decoded to a dictionary or list. Set #: ``Note`` if the response does not have a JSON payload. self.json = None + #: The body of the SSE response, decoded to a list of events, each + #: given as a dictionary with a ``data`` key and optionally also + #: ``event`` and ``id`` keys. Set to ``None`` if the response does not + #: have an SSE payload. + self.events = None def _initialize_response(self, res): self.status_code = res.status_code @@ -41,10 +47,13 @@ class TestResponse: async def _initialize_body(self, res): self.body = b'' iter = res.body_iter() - async for body in iter: # pragma: no branch - if isinstance(body, str): - body = body.encode() - self.body += body + try: + async for body in iter: # pragma: no branch + if isinstance(body, str): + body = body.encode() + self.body += body + except asyncio.CancelledError: # pragma: no cover + pass if hasattr(iter, 'aclose'): # pragma: no branch await iter.aclose() @@ -60,6 +69,32 @@ class TestResponse: if content_type.split(';')[0] == 'application/json': self.json = json.loads(self.text) + def _process_sse_body(self): + if 'Content-Type' in self.headers: # pragma: no branch + content_type = self.headers['Content-Type'] + if content_type.split(';')[0] == 'text/event-stream': + self.events = [] + for sse_event in self.body.split(b'\n\n'): + data = None + event = None + event_id = None + for line in sse_event.split(b'\n'): + if line.startswith(b'data:'): + data = line[5:].strip() + elif line.startswith(b'event:'): + event = line[6:].strip().decode() + elif line.startswith(b'id:'): + event_id = line[3:].strip().decode() + if data: + data_json = None + try: + data_json = json.loads(data) + except ValueError: + pass + self.events.append({ + "data": data, "data_json": data_json, + "event": event, "event_id": event_id}) + @classmethod async def create(cls, res): test_res = cls() @@ -68,6 +103,7 @@ class TestResponse: await test_res._initialize_body(res) test_res._process_text_body() test_res._process_json_body() + test_res._process_sse_body() return test_res diff --git a/tests/test_sse.py b/tests/test_sse.py index cf2f8db..0586b72 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -42,6 +42,31 @@ class TestWebSocket(unittest.TestCase): 'data: [42, "foo", "bar"]\n\n' 'data: foo\n\n' 'data: foo\n\n')) + self.assertEqual(len(response.events), 8) + self.assertEqual(response.events[0], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) + self.assertEqual(response.events[1], { + 'data': b'bar', 'data_json': None, 'event': 'test', + 'event_id': None}) + self.assertEqual(response.events[2], { + 'data': b'bar', 'data_json': None, 'event': 'test', + 'event_id': 'id42'}) + self.assertEqual(response.events[3], { + 'data': b'bar', 'data_json': None, 'event': None, + 'event_id': 'id42'}) + self.assertEqual(response.events[4], { + 'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'}, + 'event': None, 'event_id': None}) + self.assertEqual(response.events[5], { + 'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'], + 'event': None, 'event_id': None}) + self.assertEqual(response.events[6], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) + self.assertEqual(response.events[7], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) def test_sse_exception(self): app = Microdot()