Support large downloads in send_file (fixes #3)

This commit is contained in:
Miguel Grinberg
2020-02-19 00:01:51 +00:00
parent 1aacb3cf46
commit 3e29af5775
3 changed files with 33 additions and 16 deletions

View File

@@ -59,7 +59,17 @@ class Response(BaseResponse):
# body
if self.body:
await stream.awrite(self.body)
if hasattr(self.body, 'read'):
while True:
buf = self.body.read(self.send_file_buffer_size)
if len(buf):
await stream.awrite(buf)
if len(buf) < self.send_file_buffer_size:
break
if hasattr(self.body, 'close'):
self.body.close()
else:
await stream.awrite(self.body)
class Microdot(BaseMicrodot):

View File

@@ -154,6 +154,7 @@ class Response():
'png': 'image/png',
'txt': 'text/plain',
}
send_file_buffer_size = 1024
def __init__(self, body='', status_code=200, headers=None):
self.status_code = status_code
@@ -163,10 +164,9 @@ class Response():
self.headers['Content-Type'] = 'application/json'
elif isinstance(body, str):
self.body = body.encode()
elif isinstance(body, bytes):
self.body = body
else:
self.body = str(body).encode()
# this applies to bytes or file-like objects
self.body = body
def set_cookie(self, cookie, value, path=None, domain=None, expires=None,
max_age=None, secure=False, http_only=False):
@@ -190,7 +190,8 @@ class Response():
self.headers['Set-Cookie'] = [http_cookie]
def complete(self):
if 'Content-Length' not in self.headers:
if isinstance(self.body, bytes) and \
'Content-Length' not in self.headers:
self.headers['Content-Length'] = str(len(self.body))
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = 'text/plain'
@@ -213,7 +214,17 @@ class Response():
# body
if self.body:
stream.write(self.body)
if hasattr(self.body, 'read'):
while True:
buf = self.body.read(self.send_file_buffer_size)
if len(buf):
stream.write(buf)
if len(buf) < self.send_file_buffer_size:
break
if hasattr(self.body, 'close'):
self.body.close()
else:
stream.write(self.body)
@classmethod
def redirect(cls, location, status_code=302):
@@ -227,11 +238,9 @@ class Response():
content_type = Response.types_map[ext]
else:
content_type = 'application/octet-stream'
with open(filename) as f:
body = f.read()
return cls(body=body, status_code=status_code,
headers={'Content-Type': content_type,
'Content-Length': str(len(body))})
f = open(filename, 'rb')
return cls(body=f, status_code=status_code,
headers={'Content-Type': content_type})
class URLPattern():

View File

@@ -92,7 +92,7 @@ class TestResponse(unittest.TestCase):
res = Response(123)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers, {})
self.assertEqual(res.body, b'123')
self.assertEqual(res.body, 123)
def test_create_with_status_code(self):
res = Response('not found', 404)
@@ -161,11 +161,9 @@ class TestResponse(unittest.TestCase):
res = Response.send_file('tests/files/' + file)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Content-Type'], content_type)
self.assertEqual(res.headers['Content-Length'], '4')
self.assertEqual(res.body, b'foo\n')
self.assertEqual(res.body.read(), b'foo\n')
res = Response.send_file('tests/files/test.txt',
content_type='text/html')
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Content-Type'], 'text/html')
self.assertEqual(res.headers['Content-Length'], '4')
self.assertEqual(res.body, b'foo\n')
self.assertEqual(res.body.read(), b'foo\n')