diff --git a/examples/auth/basic_auth.py b/examples/auth/basic_auth.py
new file mode 100644
index 0000000..2ed2333
--- /dev/null
+++ b/examples/auth/basic_auth.py
@@ -0,0 +1,27 @@
+from microdot import Microdot
+from microdot_auth import BasicAuth
+
+app = Microdot()
+basic_auth = BasicAuth()
+
+USERS = {
+ 'susan': 'hello',
+ 'david': 'bye',
+}
+
+
+@basic_auth.callback
+def verify_password(request, username, password):
+ if username in USERS and USERS[username] == password:
+ request.g.user = username
+ return True
+
+
+@app.route('/')
+@basic_auth
+def index(request):
+ return f'Hello, {request.g.user}!'
+
+
+if __name__ == '__main__':
+ app.run(debug=True)
diff --git a/examples/auth/login_auth.py b/examples/auth/login_auth.py
new file mode 100644
index 0000000..bc4ced0
--- /dev/null
+++ b/examples/auth/login_auth.py
@@ -0,0 +1,60 @@
+from microdot import Microdot, redirect
+from microdot_session import set_session_secret_key
+from microdot_login import LoginAuth
+
+app = Microdot()
+set_session_secret_key('top-secret')
+login_auth = LoginAuth()
+
+USERS = {
+ 'susan': 'hello',
+ 'david': 'bye',
+}
+
+
+@login_auth.callback
+def check_user(request, user_id):
+ request.g.user = user_id
+ return True
+
+
+@app.route('/')
+@login_auth
+def index(request):
+ return f'''
+
Login Auth Example
+ Hello, {request.g.user}!
+
+ ''', {'Content-Type': 'text/html'}
+
+
+@app.route('/login', methods=['GET', 'POST'])
+def login(request):
+ if request.method == 'GET':
+ return '''
+ Login Auth Example
+
+ ''', {'Content-Type': 'text/html'}
+ username = request.form['username']
+ password = request.form['password']
+ if USERS.get(username) == password:
+ login_auth.login_user(request, username)
+ return login_auth.redirect_to_next(request)
+ else:
+ return redirect('/login')
+
+
+@app.post('/logout')
+def logout(request):
+ login_auth.logout_user(request)
+ return redirect('/')
+
+
+if __name__ == '__main__':
+ app.run(debug=True)
diff --git a/examples/auth/token_auth.py b/examples/auth/token_auth.py
new file mode 100644
index 0000000..0856dd4
--- /dev/null
+++ b/examples/auth/token_auth.py
@@ -0,0 +1,27 @@
+from microdot import Microdot
+from microdot_auth import TokenAuth
+
+app = Microdot()
+token_auth = TokenAuth()
+
+TOKENS = {
+ 'hello': 'susan',
+ 'bye': 'david',
+}
+
+
+@token_auth.callback
+def verify_token(request, token):
+ if token in TOKENS:
+ request.g.user = TOKENS[token]
+ return True
+
+
+@app.route('/')
+@token_auth
+def index(request):
+ return f'Hello, {request.g.user}!'
+
+
+if __name__ == '__main__':
+ app.run(debug=True)
diff --git a/setup.cfg b/setup.cfg
index 1d53ff1..422df0e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,6 +28,8 @@ py_modules =
microdot_utemplate
microdot_jinja
microdot_session
+ microdot_auth
+ microdot_login
microdot_websocket
microdot_websocket_alt
microdot_asyncio_websocket
diff --git a/src/microdot_auth.py b/src/microdot_auth.py
new file mode 100644
index 0000000..8e307cd
--- /dev/null
+++ b/src/microdot_auth.py
@@ -0,0 +1,65 @@
+from microdot import abort
+
+
+class BaseAuth:
+ def __init__(self, header='Authorization', scheme=None):
+ self.auth_callback = None
+ self.error_callback = self.auth_failed
+ self.header = header
+ self.scheme = scheme.lower()
+
+ def callback(self, f):
+ """Decorator to configure the authentication callback.
+
+ Microdot calls the authentication callback to allow the application to
+ check user credentials.
+ """
+ self.auth_callback = f
+
+ def errorhandler(self, f):
+ """Decorator to configure the error callback.
+
+ Microdot calls the error callback to allow the application to generate
+ a custom error response. The default error response is to call
+ ``abort(401)``.
+ """
+ self.error_callback = f
+
+ def auth_failed(self):
+ abort(401)
+
+ def __call__(self, func):
+ def wrapper(request, *args, **kwargs):
+ auth = request.headers.get(self.header)
+ if not auth:
+ return self.error_callback()
+ if self.header == 'Authorization':
+ if ' ' not in auth:
+ return self.error_callback()
+ scheme, auth = auth.split(' ', 1)
+ if scheme.lower() != self.scheme:
+ return self.error_callback()
+ if not self.auth_callback(request, *self._get_auth_args(auth)):
+ return self.error_callback()
+ return func(request, *args, **kwargs)
+
+ return wrapper
+
+
+class BasicAuth(BaseAuth):
+ def __init__(self):
+ super().__init__(scheme='Basic')
+
+ def _get_auth_args(self, auth):
+ import binascii
+ username, password = binascii.a2b_base64(auth).decode('utf-8').split(
+ ':', 1)
+ return (username, password)
+
+
+class TokenAuth(BaseAuth):
+ def __init__(self, header='Authorization', scheme='Bearer'):
+ super().__init__(header=header, scheme=scheme)
+
+ def _get_auth_args(self, token):
+ return (token,)
diff --git a/src/microdot_login.py b/src/microdot_login.py
new file mode 100644
index 0000000..a5546a8
--- /dev/null
+++ b/src/microdot_login.py
@@ -0,0 +1,46 @@
+from microdot import redirect, urlencode
+from microdot_session import get_session, update_session
+
+
+class LoginAuth:
+ def __init__(self, login_url='/login'):
+ super().__init__()
+ self.login_url = login_url
+ self.user_callback = self._accept_user
+
+ def callback(self, f):
+ self.user_callback = f
+
+ def login_user(self, request, user_id):
+ session = get_session(request)
+ session['user_id'] = user_id
+ update_session(request, session)
+ return session
+
+ def logout_user(self, request):
+ session = get_session(request)
+ session.pop('user_id', None)
+ update_session(request, session)
+ return session
+
+ def redirect_to_next(self, request, default_url='/'):
+ next_url = request.args.get('next', default_url)
+ if not next_url.startswith('/'):
+ next_url = default_url
+ return redirect(next_url)
+
+ def __call__(self, func):
+ def wrapper(request, *args, **kwargs):
+ session = get_session(request)
+ if 'user_id' not in session:
+ return redirect(self.login_url + '?next=' + urlencode(
+ request.url))
+ if not self.user_callback(request, session['user_id']):
+ return redirect(self.login_url + '?next=' + urlencode(
+ request.url))
+ return func(request, *args, **kwargs)
+
+ return wrapper
+
+ def _accept_user(self, request, user_id):
+ return True
diff --git a/tests/test_microdot_auth.py b/tests/test_microdot_auth.py
new file mode 100644
index 0000000..c90062a
--- /dev/null
+++ b/tests/test_microdot_auth.py
@@ -0,0 +1,113 @@
+import binascii
+import unittest
+from microdot import Microdot
+from microdot_auth import BasicAuth, TokenAuth
+from microdot_test_client import TestClient
+
+
+class TestAuth(unittest.TestCase):
+ def test_basic_auth(self):
+ app = Microdot()
+ basic_auth = BasicAuth()
+
+ @basic_auth.callback
+ def authenticate(request, username, password):
+ if username == 'foo' and password == 'bar':
+ request.g.user = {'username': username}
+ return True
+
+ @app.route('/')
+ @basic_auth
+ def index(request):
+ return request.g.user['username']
+
+ client = TestClient(app)
+ res = client.get('/')
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={
+ 'Authorization': 'Basic ' + binascii.b2a_base64(
+ b'foo:bar').decode()})
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'foo')
+
+ res = client.get('/', headers={
+ 'Authorization': 'Basic ' + binascii.b2a_base64(
+ b'foo:baz').decode()})
+ self.assertEqual(res.status_code, 401)
+
+ def test_token_auth(self):
+ app = Microdot()
+ token_auth = TokenAuth()
+
+ @token_auth.callback
+ def authenticate(request, token):
+ if token == 'foo':
+ request.g.user = 'user'
+ return True
+
+ @app.route('/')
+ @token_auth
+ def index(request):
+ return request.g.user
+
+ client = TestClient(app)
+ res = client.get('/')
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'Basic foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'Bearer foo'})
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'user')
+
+ def test_token_auth_custom_header(self):
+ app = Microdot()
+ token_auth = TokenAuth(header='X-Auth-Token')
+
+ @token_auth.callback
+ def authenticate(request, token):
+ if token == 'foo':
+ request.g.user = 'user'
+ return True
+
+ @app.route('/')
+ @token_auth
+ def index(request):
+ return request.g.user
+
+ client = TestClient(app)
+ res = client.get('/')
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'Basic foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'Authorization': 'Bearer foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'X-Token-Auth': 'Bearer foo'})
+ self.assertEqual(res.status_code, 401)
+
+ res = client.get('/', headers={'X-Auth-Token': 'foo'})
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'user')
+
+ res = client.get('/', headers={'x-auth-token': 'foo'})
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'user')
+
+ @token_auth.errorhandler
+ def error_handler():
+ return {'status_code': 403}, 403
+
+ res = client.get('/')
+ self.assertEqual(res.status_code, 403)
+ self.assertEqual(res.json, {'status_code': 403})
diff --git a/tests/test_microdot_login.py b/tests/test_microdot_login.py
new file mode 100644
index 0000000..dcb9913
--- /dev/null
+++ b/tests/test_microdot_login.py
@@ -0,0 +1,134 @@
+import unittest
+from microdot import Microdot
+from microdot_login import LoginAuth
+from microdot_session import set_session_secret_key, with_session
+from microdot_test_client import TestClient
+
+set_session_secret_key('top-secret!')
+
+
+class TestLogin(unittest.TestCase):
+ def test_login_auth(self):
+ app = Microdot()
+ login_auth = LoginAuth()
+
+ @app.get('/')
+ @login_auth
+ def index(request):
+ return 'ok'
+
+ @app.post('/login')
+ def login(request):
+ login_auth.login_user(request, 'user')
+ return login_auth.redirect_to_next(request)
+
+ @app.post('/logout')
+ def logout(request):
+ login_auth.logout_user(request)
+ return 'ok'
+
+ client = TestClient(app)
+ res = client.get('/?foo=bar')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar')
+
+ res = client.post('/login?next=/%3Ffoo=bar')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/?foo=bar')
+
+ res = client.get('/')
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'ok')
+
+ res = client.post('/logout')
+ self.assertEqual(res.status_code, 200)
+
+ res = client.get('/')
+ self.assertEqual(res.status_code, 302)
+
+ def test_login_auth_with_session(self):
+ app = Microdot()
+ login_auth = LoginAuth(login_url='/foo')
+
+ @app.get('/')
+ @login_auth
+ @with_session
+ def index(request, session):
+ return session['user_id']
+
+ @app.post('/foo')
+ def login(request):
+ login_auth.login_user(request, 'user')
+ return login_auth.redirect_to_next(request)
+
+ client = TestClient(app)
+ res = client.get('/')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/foo?next=/')
+
+ res = client.post('/foo')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/')
+
+ res = client.get('/')
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'user')
+
+ def test_login_auth_user_callback(self):
+ app = Microdot()
+ login_auth = LoginAuth()
+
+ @login_auth.callback
+ def check_user(request, user_id):
+ request.g.user_id = user_id
+ return user_id == 'user'
+
+ @app.get('/')
+ @login_auth
+ def index(request):
+ return request.g.user_id
+
+ @app.post('/good-login')
+ def good_login(request):
+ login_auth.login_user(request, 'user')
+ return login_auth.redirect_to_next(request)
+
+ @app.post('/bad-login')
+ def bad_login(request):
+ login_auth.login_user(request, 'foo')
+ return login_auth.redirect_to_next(request)
+
+ client = TestClient(app)
+ res = client.post('/good-login')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/')
+ res = client.get('/')
+ self.assertEqual(res.status_code, 200)
+ self.assertEqual(res.text, 'user')
+
+ res = client.post('/bad-login')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/')
+ res = client.get('/')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/login?next=/')
+
+ def test_login_auth_bad_redirect(self):
+ app = Microdot()
+ login_auth = LoginAuth()
+
+ @app.get('/')
+ @login_auth
+ def index(request):
+ return 'ok'
+
+ @app.post('/login')
+ def login(request):
+ login_auth.login_user(request, 'user')
+ return login_auth.redirect_to_next(request)
+
+ client = TestClient(app)
+ res = client.post('/login?next=http://example.com')
+ self.assertEqual(res.status_code, 302)
+ self.assertEqual(res.headers['Location'], '/')
+