g, before_request and after_request

This commit is contained in:
Miguel Grinberg
2019-04-27 18:15:09 +01:00
parent 76ab1fa6d7
commit 8aa50f171d
2 changed files with 120 additions and 54 deletions

View File

@@ -39,9 +39,14 @@ def urldecode(string):
class Request():
class G:
pass
def __init__(self, client_sock, client_addr):
self.client_sock = client_sock
self.client_addr = client_addr
self.url_args = None
self.g = Request.G()
if not hasattr(client_sock, 'readline'): # pragma: no cover
self.client_stream = client_sock.makefile("rwb")
@@ -263,6 +268,8 @@ class URLPattern():
class Microdot():
def __init__(self):
self.url_map = []
self.before_request_handlers = []
self.after_request_handlers = []
self.error_handlers = {}
def route(self, url_pattern, methods=None):
@@ -272,6 +279,14 @@ class Microdot():
return f
return decorated
def before_request(self, f):
self.before_request_handlers.append(f)
return f
def after_request(self, f):
self.after_request_handlers.append(f)
return f
def errorhandler(self, status_code_or_exception_class):
def decorated(f):
self.error_handlers[status_code_or_exception_class] = f
@@ -292,42 +307,53 @@ class Microdot():
while True:
req = Request(*s.accept())
f = None
args = None
for route_methods, route_pattern, route_handler in self.url_map:
if req.method in route_methods:
args = route_pattern.match(req.path)
if args is not None:
req.url_args = route_pattern.match(req.path)
if req.url_args is not None:
f = route_handler
break
try:
res = None
if f:
resp = f(req, **args)
for handler in self.before_request_handlers:
res = handler(req)
if res:
break
if res is None:
res = f(req, **req.url_args)
if isinstance(res, tuple):
res = Response(*res)
elif not isinstance(res, Response):
res = Response(res)
for handler in self.after_request_handlers:
res = handler(req, res) or res
elif 404 in self.error_handlers:
resp = self.error_handlers[404](req)
res = self.error_handlers[404](req)
else:
resp = 'Not found', 404
res = 'Not found', 404
except Exception as exc:
print_exception(exc)
resp = None
res = None
if exc.__class__ in self.error_handlers:
try:
resp = self.error_handlers[exc.__class__](req, exc)
res = self.error_handlers[exc.__class__](req, exc)
except Exception as exc2: # pragma: no cover
print_exception(exc2)
if resp is None:
if res is None:
if 500 in self.error_handlers:
resp = self.error_handlers[500](req)
res = self.error_handlers[500](req)
else:
resp = 'Internal server error', 500
if isinstance(resp, tuple):
resp = Response(*resp)
elif not isinstance(resp, Response):
resp = Response(resp)
res = 'Internal server error', 500
if isinstance(res, tuple):
res = Response(*res)
elif not isinstance(res, Response):
res = Response(res)
if debug: # pragma: no cover
print('{method} {path} {status_code}'.format(
method=req.method, path=req.path,
status_code=resp.status_code))
resp.write(req.client_stream)
status_code=res.status_code))
res.write(req.client_stream)
req.close()

View File

@@ -1,6 +1,6 @@
import sys
import unittest
from microdot import Microdot
from microdot import Microdot, Response
from tests import mock_socket
@@ -24,11 +24,10 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 200 OK\r\n'
b'Content-Length: 3\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'foo')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 200 OK\r\n'))
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nfoo'))
def test_post_request(self):
app = Microdot()
@@ -39,16 +38,62 @@ class TestMicrodot(unittest.TestCase):
@app.route('/', methods=['POST'])
def index_post(req):
return 'bar'
return Response('bar')
mock_socket.clear_requests()
fd = mock_socket.add_request('POST', '/')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 200 OK\r\n'
b'Content-Length: 3\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'bar')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 200 OK\r\n'))
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nbar'))
def test_before_after_request(self):
app = Microdot()
@app.before_request
def before_request(req):
if req.path == '/bar':
return 'bar', 202
req.g.message = 'baz'
@app.after_request
def after_request_one(req, res):
res.headers['X-One'] = '1'
@app.after_request
def after_request_two(req, res):
print('two')
res.set_cookie('foo', 'bar')
return res
@app.route('/bar')
def bar(req):
return 'foo'
@app.route('/baz')
def baz(req):
return req.g.message
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/bar')
self.assertRaises(IndexError, app.run)
self.assertTrue(fd.response.startswith(b'HTTP/1.0 202 N/A\r\n'))
self.assertIn(b'X-One: 1\r\n', fd.response)
self.assertIn(b'Set-Cookie: foo=bar\r\n', fd.response)
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nbar'))
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/baz')
self.assertRaises(IndexError, app.run)
self.assertTrue(fd.response.startswith(b'HTTP/1.0 200 OK\r\n'))
self.assertIn(b'X-One: 1\r\n', fd.response)
self.assertIn(b'Set-Cookie: foo=bar\r\n', fd.response)
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nbaz'))
def test_404(self):
app = Microdot()
@@ -60,11 +105,10 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/foo')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 404 N/A\r\n'
b'Content-Length: 9\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'Not found')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 404 N/A\r\n'))
self.assertIn(b'Content-Length: 9\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nNot found'))
def test_404_handler(self):
app = Microdot()
@@ -80,11 +124,10 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/foo')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 200 OK\r\n'
b'Content-Length: 3\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'404')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 200 OK\r\n'))
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\n404'))
def test_500(self):
app = Microdot()
@@ -96,11 +139,10 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 500 N/A\r\n'
b'Content-Length: 21\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'Internal server error')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 500 N/A\r\n'))
self.assertIn(b'Content-Length: 21\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\nInternal server error'))
def test_500_handler(self):
app = Microdot()
@@ -116,11 +158,10 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 501 N/A\r\n'
b'Content-Length: 3\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'501')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 501 N/A\r\n'))
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\n501'))
def test_exception_handler(self):
app = Microdot()
@@ -136,8 +177,7 @@ class TestMicrodot(unittest.TestCase):
mock_socket.clear_requests()
fd = mock_socket.add_request('GET', '/')
self.assertRaises(IndexError, app.run)
self.assertEqual(fd.response, b'HTTP/1.0 501 N/A\r\n'
b'Content-Length: 3\r\n'
b'Content-Type: text/plain\r\n'
b'\r\n'
b'501')
self.assertTrue(fd.response.startswith(b'HTTP/1.0 501 N/A\r\n'))
self.assertIn(b'Content-Length: 3\r\n', fd.response)
self.assertIn(b'Content-Type: text/plain\r\n', fd.response)
self.assertTrue(fd.response.endswith(b'\r\n501'))