diff --git a/src/microdot/sse.py b/src/microdot/sse.py index 6376ee0..a27b1ca 100644 --- a/src/microdot/sse.py +++ b/src/microdot/sse.py @@ -17,7 +17,8 @@ class SSE: self.event = asyncio.Event() self.queue = [] - async def send(self, data, event=None, event_id=None): + async def send(self, data, event=None, event_id=None, retry=None, + comment=False): """Send an event to the client. :param data: the data to send. It can be given as a string, bytes, dict @@ -27,6 +28,12 @@ class SSE: given, it must be a string. :param event_id: an optional event id, to send along with the data. If given, it must be a string. + :param retry: an optional reconnection time (in seconds) that the + client should use when the connection is lost. + :param comment: when set to ``True``, the data is sent as a comment + line, and all other parameters are ignored. This is + useful as a heartbeat mechanism that keeps the + connection alive. """ if isinstance(data, (dict, list)): data = json.dumps(data) @@ -34,11 +41,17 @@ class SSE: data = data.encode() elif not isinstance(data, bytes): data = str(data).encode() - data = b'data: ' + data + b'\n\n' - if event_id: - data = b'id: ' + event_id.encode() + b'\n' + data - if event: - data = b'event: ' + event.encode() + b'\n' + data + if comment: + data = b': ' + data + b'\n\n' + else: + data = b'data: ' + data + b'\n\n' + if event_id: + data = b'id: ' + event_id.encode() + b'\n' + data + if event: + data = b'event: ' + event.encode() + b'\n' + data + if retry: + data = b'retry: ' + str(int(retry * 1000)).encode() + b'\n' + \ + data self.queue.append(data) self.event.set() diff --git a/src/microdot/test_client.py b/src/microdot/test_client.py index 909bd55..a622591 100644 --- a/src/microdot/test_client.py +++ b/src/microdot/test_client.py @@ -78,6 +78,7 @@ class TestResponse: data = None event = None event_id = None + retry = None for line in sse_event.split(b'\n'): if line.startswith(b'data:'): data = line[5:].strip() @@ -85,6 +86,8 @@ class TestResponse: event = line[6:].strip().decode() elif line.startswith(b'id:'): event_id = line[3:].strip().decode() + elif line.startswith(b'retry:'): + retry = int(line[7:].strip()) / 1000 if data: data_json = None try: @@ -92,8 +95,9 @@ class TestResponse: except ValueError: pass self.events.append({ - "data": data, "data_json": data_json, - "event": event, "event_id": event_id}) + 'data': data, 'data_json': data_json, + 'event': event, 'event_id': event_id, + 'retry': retry}) @classmethod async def create(cls, res): diff --git a/tests/test_sse.py b/tests/test_sse.py index 0586b72..04d3de1 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -25,7 +25,9 @@ class TestWebSocket(unittest.TestCase): await sse.send('bar', event='test') await sse.send('bar', event='test', event_id='id42') await sse.send('bar', event_id='id42') + await sse.send('bar', retry=2.5) await sse.send({'foo': 'bar'}) + await sse.send('ping', comment=True) await sse.send([42, 'foo', 'bar']) await sse.send(ValueError('foo')) await sse.send(b'foo') @@ -38,35 +40,40 @@ class TestWebSocket(unittest.TestCase): 'event: test\ndata: bar\n\n' 'event: test\nid: id42\ndata: bar\n\n' 'id: id42\ndata: bar\n\n' + 'retry: 2500\ndata: bar\n\n' 'data: {"foo": "bar"}\n\n' + ': ping\n\n' 'data: [42, "foo", "bar"]\n\n' 'data: foo\n\n' 'data: foo\n\n')) - self.assertEqual(len(response.events), 8) + self.assertEqual(len(response.events), 9) self.assertEqual(response.events[0], { 'data': b'foo', 'data_json': None, 'event': None, - 'event_id': None}) + 'event_id': None, "retry": None}) self.assertEqual(response.events[1], { 'data': b'bar', 'data_json': None, 'event': 'test', - 'event_id': None}) + 'event_id': None, "retry": None}) self.assertEqual(response.events[2], { 'data': b'bar', 'data_json': None, 'event': 'test', - 'event_id': 'id42'}) + 'event_id': 'id42', "retry": None}) self.assertEqual(response.events[3], { 'data': b'bar', 'data_json': None, 'event': None, - 'event_id': 'id42'}) + 'event_id': 'id42', "retry": None}) self.assertEqual(response.events[4], { - 'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'}, - 'event': None, 'event_id': None}) + 'data': b'bar', 'data_json': None, 'event': None, 'event_id': None, + 'retry': 2.5}) self.assertEqual(response.events[5], { - 'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'], - 'event': None, 'event_id': None}) + 'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'}, + 'event': None, 'event_id': None, "retry": None}) self.assertEqual(response.events[6], { - 'data': b'foo', 'data_json': None, 'event': None, - 'event_id': None}) + 'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'], + 'event': None, 'event_id': None, "retry": None}) self.assertEqual(response.events[7], { 'data': b'foo', 'data_json': None, 'event': None, - 'event_id': None}) + 'event_id': None, "retry": None}) + self.assertEqual(response.events[8], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None, "retry": None}) def test_sse_exception(self): app = Microdot()