Support large downloads in send_file (fixes #3)
This commit is contained in:
@@ -59,7 +59,17 @@ class Response(BaseResponse):
|
||||
|
||||
# body
|
||||
if self.body:
|
||||
await stream.awrite(self.body)
|
||||
if hasattr(self.body, 'read'):
|
||||
while True:
|
||||
buf = self.body.read(self.send_file_buffer_size)
|
||||
if len(buf):
|
||||
await stream.awrite(buf)
|
||||
if len(buf) < self.send_file_buffer_size:
|
||||
break
|
||||
if hasattr(self.body, 'close'):
|
||||
self.body.close()
|
||||
else:
|
||||
await stream.awrite(self.body)
|
||||
|
||||
|
||||
class Microdot(BaseMicrodot):
|
||||
|
||||
@@ -154,6 +154,7 @@ class Response():
|
||||
'png': 'image/png',
|
||||
'txt': 'text/plain',
|
||||
}
|
||||
send_file_buffer_size = 1024
|
||||
|
||||
def __init__(self, body='', status_code=200, headers=None):
|
||||
self.status_code = status_code
|
||||
@@ -163,10 +164,9 @@ class Response():
|
||||
self.headers['Content-Type'] = 'application/json'
|
||||
elif isinstance(body, str):
|
||||
self.body = body.encode()
|
||||
elif isinstance(body, bytes):
|
||||
self.body = body
|
||||
else:
|
||||
self.body = str(body).encode()
|
||||
# this applies to bytes or file-like objects
|
||||
self.body = body
|
||||
|
||||
def set_cookie(self, cookie, value, path=None, domain=None, expires=None,
|
||||
max_age=None, secure=False, http_only=False):
|
||||
@@ -190,7 +190,8 @@ class Response():
|
||||
self.headers['Set-Cookie'] = [http_cookie]
|
||||
|
||||
def complete(self):
|
||||
if 'Content-Length' not in self.headers:
|
||||
if isinstance(self.body, bytes) and \
|
||||
'Content-Length' not in self.headers:
|
||||
self.headers['Content-Length'] = str(len(self.body))
|
||||
if 'Content-Type' not in self.headers:
|
||||
self.headers['Content-Type'] = 'text/plain'
|
||||
@@ -213,7 +214,17 @@ class Response():
|
||||
|
||||
# body
|
||||
if self.body:
|
||||
stream.write(self.body)
|
||||
if hasattr(self.body, 'read'):
|
||||
while True:
|
||||
buf = self.body.read(self.send_file_buffer_size)
|
||||
if len(buf):
|
||||
stream.write(buf)
|
||||
if len(buf) < self.send_file_buffer_size:
|
||||
break
|
||||
if hasattr(self.body, 'close'):
|
||||
self.body.close()
|
||||
else:
|
||||
stream.write(self.body)
|
||||
|
||||
@classmethod
|
||||
def redirect(cls, location, status_code=302):
|
||||
@@ -227,11 +238,9 @@ class Response():
|
||||
content_type = Response.types_map[ext]
|
||||
else:
|
||||
content_type = 'application/octet-stream'
|
||||
with open(filename) as f:
|
||||
body = f.read()
|
||||
return cls(body=body, status_code=status_code,
|
||||
headers={'Content-Type': content_type,
|
||||
'Content-Length': str(len(body))})
|
||||
f = open(filename, 'rb')
|
||||
return cls(body=f, status_code=status_code,
|
||||
headers={'Content-Type': content_type})
|
||||
|
||||
|
||||
class URLPattern():
|
||||
|
||||
@@ -92,7 +92,7 @@ class TestResponse(unittest.TestCase):
|
||||
res = Response(123)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.headers, {})
|
||||
self.assertEqual(res.body, b'123')
|
||||
self.assertEqual(res.body, 123)
|
||||
|
||||
def test_create_with_status_code(self):
|
||||
res = Response('not found', 404)
|
||||
@@ -161,11 +161,9 @@ class TestResponse(unittest.TestCase):
|
||||
res = Response.send_file('tests/files/' + file)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.headers['Content-Type'], content_type)
|
||||
self.assertEqual(res.headers['Content-Length'], '4')
|
||||
self.assertEqual(res.body, b'foo\n')
|
||||
self.assertEqual(res.body.read(), b'foo\n')
|
||||
res = Response.send_file('tests/files/test.txt',
|
||||
content_type='text/html')
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.headers['Content-Type'], 'text/html')
|
||||
self.assertEqual(res.headers['Content-Length'], '4')
|
||||
self.assertEqual(res.body, b'foo\n')
|
||||
self.assertEqual(res.body.read(), b'foo\n')
|
||||
|
||||
Reference in New Issue
Block a user