diff --git a/src/microdot.py b/src/microdot.py index e1384d7..ce11cf7 100644 --- a/src/microdot.py +++ b/src/microdot.py @@ -198,6 +198,15 @@ class Request(): #: Request.max_content_length = 1 * 1024 * 1024 # 1MB requests allowed max_content_length = 16 * 1024 + #: Specify the maximum length allowed for a line in the request. Requests + #: with longer lines will not be correctly interpreted. Applications can + #: change this maximum as necessary. + #: + #: Example:: + #: + #: Request.max_readline = 16 * 1024 # 16KB lines allowed + max_readline = 2 * 1024 + class G: pass @@ -244,7 +253,7 @@ class Request(): This method returns a newly created ``Request`` object. """ # request line - line = client_stream.readline().strip().decode() + line = Request._safe_readline(client_stream).strip().decode() if not line: return None method, url, http_version = line.split() @@ -254,7 +263,7 @@ class Request(): headers = {} content_length = 0 while True: - line = client_stream.readline().strip().decode() + line = Request._safe_readline(client_stream).strip().decode() if line == '': break header, value = line.split(':', 1) @@ -298,6 +307,14 @@ class Request(): self._form = self._parse_urlencoded(self.body.decode()) return self._form + @staticmethod + def _safe_readline(stream): + line = stream.readline(Request.max_readline + 1) + print(line, Request.max_readline) + if len(line) > Request.max_readline: + raise ValueError('line too long') + return line + class Response(): """An HTTP response class. diff --git a/src/microdot_asyncio.py b/src/microdot_asyncio.py index 04d3bdd..71035d4 100644 --- a/src/microdot_asyncio.py +++ b/src/microdot_asyncio.py @@ -34,7 +34,7 @@ class Request(BaseRequest): object. """ # request line - line = (await client_stream.readline()).strip().decode() + line = (await Request._safe_readline(client_stream)).strip().decode() if not line: # pragma: no cover return None method, url, http_version = line.split() @@ -44,7 +44,8 @@ class Request(BaseRequest): headers = {} content_length = 0 while True: - line = (await client_stream.readline()).strip().decode() + line = (await Request._safe_readline( + client_stream)).strip().decode() if line == '': break header, value = line.split(':', 1) @@ -60,6 +61,13 @@ class Request(BaseRequest): return Request(app, client_addr, method, url, http_version, headers, body) + @staticmethod + async def _safe_readline(stream): + line = (await stream.readline()) + if len(line) > Request.max_readline: + raise ValueError('line too long') + return line + class Response(BaseResponse): """An HTTP response class. diff --git a/tests/microdot/test_request.py b/tests/microdot/test_request.py index 62fb06e..2aef941 100644 --- a/tests/microdot/test_request.py +++ b/tests/microdot/test_request.py @@ -79,6 +79,18 @@ class TestRequest(unittest.TestCase): req = Request.create('app', fd, 'addr') self.assertIsNone(req.form) + def test_large_line(self): + saved_max_readline = Request.max_readline + Request.max_readline = 16 + + fd = get_request_fd('GET', '/foo', headers={ + 'Content-Type': 'application/x-www-form-urlencoded'}, + body='foo=bar&abc=def&x=y') + with self.assertRaises(ValueError): + Request.create('app', fd, 'addr') + + Request.max_readline = saved_max_readline + def test_large_payload(self): saved_max_content_length = Request.max_content_length Request.max_content_length = 16 @@ -87,6 +99,6 @@ class TestRequest(unittest.TestCase): 'Content-Type': 'application/x-www-form-urlencoded'}, body='foo=bar&abc=def&x=y') req = Request.create('app', fd, 'addr') - assert req.body == b'' + self.assertEqual(req.body, b'') Request.max_content_length = saved_max_content_length diff --git a/tests/microdot_asyncio/test_request_asyncio.py b/tests/microdot_asyncio/test_request_asyncio.py index 882d1ce..b4a63b6 100644 --- a/tests/microdot_asyncio/test_request_asyncio.py +++ b/tests/microdot_asyncio/test_request_asyncio.py @@ -89,6 +89,18 @@ class TestRequestAsync(unittest.TestCase): req = _run(Request.create('app', fd, 'addr')) self.assertIsNone(req.form) + def test_large_line(self): + saved_max_readline = Request.max_readline + Request.max_readline = 16 + + fd = get_async_request_fd('GET', '/foo', headers={ + 'Content-Type': 'application/x-www-form-urlencoded'}, + body='foo=bar&abc=def&x=y') + with self.assertRaises(ValueError): + _run(Request.create('app', fd, 'addr')) + + Request.max_readline = saved_max_readline + def test_large_payload(self): saved_max_content_length = Request.max_content_length Request.max_content_length = 16 @@ -97,6 +109,6 @@ class TestRequestAsync(unittest.TestCase): 'Content-Type': 'application/x-www-form-urlencoded'}, body='foo=bar&abc=def&x=y') req = _run(Request.create('app', fd, 'addr')) - assert req.body == b'' + self.assertEqual(req.body, b'') Request.max_content_length = saved_max_content_length