From b0cddde6ecf01adfab7f039ba583273621b4c86a Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 15 Mar 2024 10:56:50 +0000 Subject: [PATCH] Basic, token and login authentication --- src/microdot/auth.py | 147 +++++++++++++++++++++++++++++++ tests/__init__.py | 1 + tests/test_auth.py | 199 ++++++++++++++++++++++++++++++++++++++++++ tests/test_session.py | 2 +- 4 files changed, 348 insertions(+), 1 deletion(-) create mode 100644 src/microdot/auth.py create mode 100644 tests/test_auth.py diff --git a/src/microdot/auth.py b/src/microdot/auth.py new file mode 100644 index 0000000..2d8f14a --- /dev/null +++ b/src/microdot/auth.py @@ -0,0 +1,147 @@ +from microdot import abort, redirect +from microdot.microdot import urlencode, invoke_handler + + +class BaseAuth: + def __init__(self): + self.auth_callback = None + self.error_callback = lambda request: abort(401) + + def __call__(self, f): + """Decorator to protect a route with authentication. + + Microdot will only call the route if the authentication callback + returns a valid user object, otherwise it will call the error + callback.""" + async def wrapper(request, *args, **kwargs): + auth = self._get_auth(request) + if not auth: + return await invoke_handler(self.error_callback, request) + request.g.current_user = await invoke_handler( + self.auth_callback, request, *auth) + if not request.g.current_user: + return await invoke_handler(self.error_callback, request) + return await invoke_handler(f, request, *args, **kwargs) + + return wrapper + + +class HTTPAuth(BaseAuth): + def authenticate(self, f): + """Decorator to configure the authentication callback. + + Microdot calls the authentication callback to allow the application to + check user credentials. + """ + self.auth_callback = f + + +class BasicAuth(HTTPAuth): + def __init__(self, realm='Please login', charset='UTF-8', scheme='Basic', + error_status=401): + super().__init__() + self.realm = realm + self.charset = charset + self.scheme = scheme + self.error_status = error_status + self.error_callback = self.authentication_error + + def _get_auth(self, request): + auth = request.headers.get('Authorization') + if auth and auth.startswith('Basic '): + import binascii + try: + username, password = binascii.a2b_base64( + auth[6:]).decode().split(':', 1) + except Exception: # pragma: no cover + return None + return username, password + + def authentication_error(self, request): + return '', self.error_status, { + 'WWW-Authenticate': '{} realm="{}", charset="{}"'.format( + self.scheme, self.realm, self.charset)} + + +class TokenAuth(HTTPAuth): + def __init__(self, header='Authorization', scheme='Bearer'): + super().__init__() + self.header = header + self.scheme = scheme.lower() + + def _get_auth(self, request): + auth = request.headers.get(self.header) + if auth: + if self.header == 'Authorization': + try: + scheme, token = auth.split(' ', 1) + except Exception: + return None + if scheme.lower() == self.scheme: + return (token.strip(),) + else: + return (auth,) + + def errorhandler(self, f): + """Decorator to configure the error callback. + + Microdot calls the error callback to allow the application to generate + a custom error response. The default error response is to call + ``abort(401)``. + """ + self.error_callback = f + + +class LoginAuth(BaseAuth): + def __init__(self, login_url='/login'): + super().__init__() + self.login_url = login_url + self.user_callback = None + self.user_id_callback = None + self.auth_callback = self._authenticate + self.error_callback = self._redirect_to_login + + def id_to_user(self, f): + """Decorator to configure the user callback. + + Microdot calls the user callback to load the user object from the + user ID stored in the user session. + """ + self.user_callback = f + + def user_to_id(self, f): + """Decorator to configure the user ID callback. + + Microdot calls the user ID callback to load the user ID from the + user session. + """ + self.user_id_callback = f + + def _get_session(self, request): + return request.app._session.get(request) + + def _get_auth(self, request): + session = self._get_session(request) + if session and 'user_id' in session: + return (session['user_id'],) + + async def _authenticate(self, request, user_id): + return await invoke_handler(self.user_callback, user_id) + + async def _redirect_to_login(self, request): + return '', 302, {'Location': self.login_url + '?next=' + urlencode( + request.url)} + + async def login_user(self, request, user, redirect_url='/'): + session = self._get_session(request) + session['user_id'] = await invoke_handler(self.user_id_callback, user) + session.save() + next_url = request.args.get('next', redirect_url) + if not next_url.startswith('/'): + next_url = redirect_url + return redirect(next_url) + + async def logout_user(self, request): + session = self._get_session(request) + session.pop('user_id', None) + session.save() diff --git a/tests/__init__.py b/tests/__init__.py index 4f40481..3afcab2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,3 +9,4 @@ from tests.test_sse import * # noqa: F401, F403 from tests.test_cors import * # noqa: F401, F403 from tests.test_utemplate import * # noqa: F401, F403 from tests.test_session import * # noqa: F401, F403 +from tests.test_auth import * # noqa: F401, F403 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..1e33dcd --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,199 @@ +import asyncio +import binascii +import unittest +from microdot import Microdot +from microdot.auth import BasicAuth, TokenAuth, LoginAuth +from microdot.session import Session +from microdot.test_client import TestClient + + +class TestAuth(unittest.TestCase): + @classmethod + def setUpClass(cls): + if hasattr(asyncio, 'set_event_loop'): + asyncio.set_event_loop(asyncio.new_event_loop()) + cls.loop = asyncio.get_event_loop() + + def _run(self, coro): + return self.loop.run_until_complete(coro) + + def test_basic_auth(self): + app = Microdot() + basic_auth = BasicAuth() + + @basic_auth.authenticate + def authenticate(request, username, password): + if username == 'foo' and password == 'bar': + return {'username': username} + + @app.route('/') + @basic_auth + def index(request): + return request.g.current_user['username'] + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:bar').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'foo') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:baz').decode()})) + self.assertEqual(res.status_code, 401) + + def test_token_auth(self): + app = Microdot() + token_auth = TokenAuth() + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth + def index(request): + return request.g.current_user + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'Authorization': 'foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + def test_token_auth_custom_header(self): + app = Microdot() + token_auth = TokenAuth(header='X-Auth-Token') + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth + def index(request): + return request.g.current_user + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'Authorization': 'foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'X-Token-Auth': 'Bearer foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'X-Auth-Token': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + res = self._run(client.get('/', headers={'x-auth-token': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + @token_auth.errorhandler + def error_handler(request): + return {'status_code': 403}, 403 + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 403) + self.assertEqual(res.json, {'status_code': 403}) + + def test_login_auth(self): + app = Microdot() + Session(app, secret_key='secret') + login_auth = LoginAuth() + + @login_auth.id_to_user + def id_to_user(user_id): + return user_id + + @login_auth.user_to_id + def user_to_id(user): + return user + + @app.get('/') + @login_auth + def index(request): + return request.g.current_user + + @app.post('/login') + async def login(request): + return await login_auth.login_user(request, 'user') + + @app.post('/logout') + async def logout(request): + await login_auth.logout_user(request) + return 'ok' + + client = TestClient(app) + res = self._run(client.get('/?foo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar') + + res = self._run(client.post('/login?next=/%3Ffoo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/?foo=bar') + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + res = self._run(client.post('/logout')) + self.assertEqual(res.status_code, 200) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 302) + + def test_login_auth_bad_redirect(self): + app = Microdot() + Session(app, secret_key='secret') + login_auth = LoginAuth() + + @login_auth.id_to_user + def id_to_user(user_id): + return user_id + + @login_auth.user_to_id + def user_to_id(user): + return user + + @app.get('/') + @login_auth + async def index(request): + return 'ok' + + @app.post('/login') + async def login(request): + return await login_auth.login_user(request, 'user') + + client = TestClient(app) + res = self._run(client.post('/login?next=http://example.com')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') diff --git a/tests/test_session.py b/tests/test_session.py index aedb7b2..326c63a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -37,7 +37,7 @@ class TestSession(unittest.TestCase): @app.post('/set') @with_session - async def save_session(req, session): + def save_session(req, session): session['name'] = 'joe' session.save() return 'OK'