diff --git a/examples/auth/basic_auth.py b/examples/auth/basic_auth.py new file mode 100644 index 0000000..2ed2333 --- /dev/null +++ b/examples/auth/basic_auth.py @@ -0,0 +1,27 @@ +from microdot import Microdot +from microdot_auth import BasicAuth + +app = Microdot() +basic_auth = BasicAuth() + +USERS = { + 'susan': 'hello', + 'david': 'bye', +} + + +@basic_auth.callback +def verify_password(request, username, password): + if username in USERS and USERS[username] == password: + request.g.user = username + return True + + +@app.route('/') +@basic_auth +def index(request): + return f'Hello, {request.g.user}!' + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/auth/login_auth.py b/examples/auth/login_auth.py new file mode 100644 index 0000000..bc4ced0 --- /dev/null +++ b/examples/auth/login_auth.py @@ -0,0 +1,60 @@ +from microdot import Microdot, redirect +from microdot_session import set_session_secret_key +from microdot_login import LoginAuth + +app = Microdot() +set_session_secret_key('top-secret') +login_auth = LoginAuth() + +USERS = { + 'susan': 'hello', + 'david': 'bye', +} + + +@login_auth.callback +def check_user(request, user_id): + request.g.user = user_id + return True + + +@app.route('/') +@login_auth +def index(request): + return f''' +

Login Auth Example

+

Hello, {request.g.user}!

+
+ +
+ ''', {'Content-Type': 'text/html'} + + +@app.route('/login', methods=['GET', 'POST']) +def login(request): + if request.method == 'GET': + return ''' +

Login Auth Example

+
+ + + +
+ ''', {'Content-Type': 'text/html'} + username = request.form['username'] + password = request.form['password'] + if USERS.get(username) == password: + login_auth.login_user(request, username) + return login_auth.redirect_to_next(request) + else: + return redirect('/login') + + +@app.post('/logout') +def logout(request): + login_auth.logout_user(request) + return redirect('/') + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/auth/token_auth.py b/examples/auth/token_auth.py new file mode 100644 index 0000000..0856dd4 --- /dev/null +++ b/examples/auth/token_auth.py @@ -0,0 +1,27 @@ +from microdot import Microdot +from microdot_auth import TokenAuth + +app = Microdot() +token_auth = TokenAuth() + +TOKENS = { + 'hello': 'susan', + 'bye': 'david', +} + + +@token_auth.callback +def verify_token(request, token): + if token in TOKENS: + request.g.user = TOKENS[token] + return True + + +@app.route('/') +@token_auth +def index(request): + return f'Hello, {request.g.user}!' + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/setup.cfg b/setup.cfg index 1d53ff1..422df0e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,8 @@ py_modules = microdot_utemplate microdot_jinja microdot_session + microdot_auth + microdot_login microdot_websocket microdot_websocket_alt microdot_asyncio_websocket diff --git a/src/microdot_auth.py b/src/microdot_auth.py new file mode 100644 index 0000000..8e307cd --- /dev/null +++ b/src/microdot_auth.py @@ -0,0 +1,65 @@ +from microdot import abort + + +class BaseAuth: + def __init__(self, header='Authorization', scheme=None): + self.auth_callback = None + self.error_callback = self.auth_failed + self.header = header + self.scheme = scheme.lower() + + def callback(self, f): + """Decorator to configure the authentication callback. + + Microdot calls the authentication callback to allow the application to + check user credentials. + """ + self.auth_callback = f + + def errorhandler(self, f): + """Decorator to configure the error callback. + + Microdot calls the error callback to allow the application to generate + a custom error response. The default error response is to call + ``abort(401)``. + """ + self.error_callback = f + + def auth_failed(self): + abort(401) + + def __call__(self, func): + def wrapper(request, *args, **kwargs): + auth = request.headers.get(self.header) + if not auth: + return self.error_callback() + if self.header == 'Authorization': + if ' ' not in auth: + return self.error_callback() + scheme, auth = auth.split(' ', 1) + if scheme.lower() != self.scheme: + return self.error_callback() + if not self.auth_callback(request, *self._get_auth_args(auth)): + return self.error_callback() + return func(request, *args, **kwargs) + + return wrapper + + +class BasicAuth(BaseAuth): + def __init__(self): + super().__init__(scheme='Basic') + + def _get_auth_args(self, auth): + import binascii + username, password = binascii.a2b_base64(auth).decode('utf-8').split( + ':', 1) + return (username, password) + + +class TokenAuth(BaseAuth): + def __init__(self, header='Authorization', scheme='Bearer'): + super().__init__(header=header, scheme=scheme) + + def _get_auth_args(self, token): + return (token,) diff --git a/src/microdot_login.py b/src/microdot_login.py new file mode 100644 index 0000000..a5546a8 --- /dev/null +++ b/src/microdot_login.py @@ -0,0 +1,46 @@ +from microdot import redirect, urlencode +from microdot_session import get_session, update_session + + +class LoginAuth: + def __init__(self, login_url='/login'): + super().__init__() + self.login_url = login_url + self.user_callback = self._accept_user + + def callback(self, f): + self.user_callback = f + + def login_user(self, request, user_id): + session = get_session(request) + session['user_id'] = user_id + update_session(request, session) + return session + + def logout_user(self, request): + session = get_session(request) + session.pop('user_id', None) + update_session(request, session) + return session + + def redirect_to_next(self, request, default_url='/'): + next_url = request.args.get('next', default_url) + if not next_url.startswith('/'): + next_url = default_url + return redirect(next_url) + + def __call__(self, func): + def wrapper(request, *args, **kwargs): + session = get_session(request) + if 'user_id' not in session: + return redirect(self.login_url + '?next=' + urlencode( + request.url)) + if not self.user_callback(request, session['user_id']): + return redirect(self.login_url + '?next=' + urlencode( + request.url)) + return func(request, *args, **kwargs) + + return wrapper + + def _accept_user(self, request, user_id): + return True diff --git a/tests/test_microdot_auth.py b/tests/test_microdot_auth.py new file mode 100644 index 0000000..c90062a --- /dev/null +++ b/tests/test_microdot_auth.py @@ -0,0 +1,113 @@ +import binascii +import unittest +from microdot import Microdot +from microdot_auth import BasicAuth, TokenAuth +from microdot_test_client import TestClient + + +class TestAuth(unittest.TestCase): + def test_basic_auth(self): + app = Microdot() + basic_auth = BasicAuth() + + @basic_auth.callback + def authenticate(request, username, password): + if username == 'foo' and password == 'bar': + request.g.user = {'username': username} + return True + + @app.route('/') + @basic_auth + def index(request): + return request.g.user['username'] + + client = TestClient(app) + res = client.get('/') + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:bar').decode()}) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'foo') + + res = client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:baz').decode()}) + self.assertEqual(res.status_code, 401) + + def test_token_auth(self): + app = Microdot() + token_auth = TokenAuth() + + @token_auth.callback + def authenticate(request, token): + if token == 'foo': + request.g.user = 'user' + return True + + @app.route('/') + @token_auth + def index(request): + return request.g.user + + client = TestClient(app) + res = client.get('/') + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'Basic foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'Bearer foo'}) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + def test_token_auth_custom_header(self): + app = Microdot() + token_auth = TokenAuth(header='X-Auth-Token') + + @token_auth.callback + def authenticate(request, token): + if token == 'foo': + request.g.user = 'user' + return True + + @app.route('/') + @token_auth + def index(request): + return request.g.user + + client = TestClient(app) + res = client.get('/') + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'Basic foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'Authorization': 'Bearer foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'X-Token-Auth': 'Bearer foo'}) + self.assertEqual(res.status_code, 401) + + res = client.get('/', headers={'X-Auth-Token': 'foo'}) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + res = client.get('/', headers={'x-auth-token': 'foo'}) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + @token_auth.errorhandler + def error_handler(): + return {'status_code': 403}, 403 + + res = client.get('/') + self.assertEqual(res.status_code, 403) + self.assertEqual(res.json, {'status_code': 403}) diff --git a/tests/test_microdot_login.py b/tests/test_microdot_login.py new file mode 100644 index 0000000..dcb9913 --- /dev/null +++ b/tests/test_microdot_login.py @@ -0,0 +1,134 @@ +import unittest +from microdot import Microdot +from microdot_login import LoginAuth +from microdot_session import set_session_secret_key, with_session +from microdot_test_client import TestClient + +set_session_secret_key('top-secret!') + + +class TestLogin(unittest.TestCase): + def test_login_auth(self): + app = Microdot() + login_auth = LoginAuth() + + @app.get('/') + @login_auth + def index(request): + return 'ok' + + @app.post('/login') + def login(request): + login_auth.login_user(request, 'user') + return login_auth.redirect_to_next(request) + + @app.post('/logout') + def logout(request): + login_auth.logout_user(request) + return 'ok' + + client = TestClient(app) + res = client.get('/?foo=bar') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar') + + res = client.post('/login?next=/%3Ffoo=bar') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/?foo=bar') + + res = client.get('/') + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'ok') + + res = client.post('/logout') + self.assertEqual(res.status_code, 200) + + res = client.get('/') + self.assertEqual(res.status_code, 302) + + def test_login_auth_with_session(self): + app = Microdot() + login_auth = LoginAuth(login_url='/foo') + + @app.get('/') + @login_auth + @with_session + def index(request, session): + return session['user_id'] + + @app.post('/foo') + def login(request): + login_auth.login_user(request, 'user') + return login_auth.redirect_to_next(request) + + client = TestClient(app) + res = client.get('/') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/foo?next=/') + + res = client.post('/foo') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') + + res = client.get('/') + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + def test_login_auth_user_callback(self): + app = Microdot() + login_auth = LoginAuth() + + @login_auth.callback + def check_user(request, user_id): + request.g.user_id = user_id + return user_id == 'user' + + @app.get('/') + @login_auth + def index(request): + return request.g.user_id + + @app.post('/good-login') + def good_login(request): + login_auth.login_user(request, 'user') + return login_auth.redirect_to_next(request) + + @app.post('/bad-login') + def bad_login(request): + login_auth.login_user(request, 'foo') + return login_auth.redirect_to_next(request) + + client = TestClient(app) + res = client.post('/good-login') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') + res = client.get('/') + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + res = client.post('/bad-login') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') + res = client.get('/') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/') + + def test_login_auth_bad_redirect(self): + app = Microdot() + login_auth = LoginAuth() + + @app.get('/') + @login_auth + def index(request): + return 'ok' + + @app.post('/login') + def login(request): + login_auth.login_user(request, 'user') + return login_auth.redirect_to_next(request) + + client = TestClient(app) + res = client.post('/login?next=http://example.com') + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') +