diff --git a/docs/extensions.rst b/docs/extensions.rst index c8bfffd..69904f0 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -414,10 +414,25 @@ decorator:: While running an authenticated request, the user object returned by the authenticaction function is accessible as ``request.g.current_user``. +If an endpoint is intended to work with or without authentication, then it can +be protected with the ``auth.optional`` decorator:: + + @app.route('/') + @auth.optional + async def index(request): + if g.current_user: + return f'Hello, {request.g.current_user}!' + else: + return 'Hello, anonymous user!' + +As shown in the example, a route can check ``g.current_user`` to determine if +the user is authenticated or not. + Token Authentication ^^^^^^^^^^^^^^^^^^^^ -To set up token authentication, create an instance of :class:`TokenAuth `:: +To set up token authentication, create an instance of +:class:`TokenAuth `:: from microdot.auth import TokenAuth @@ -437,7 +452,17 @@ protect your routes:: @auth async def index(request): return f'Hello, {request.g.current_user}!' - + +Optional authentication can also be used with tokens:: + + @app.route('/') + @auth.optional + async def index(request): + if g.current_user: + return f'Hello, {request.g.current_user}!' + else: + return 'Hello, anonymous user!' + User Logins ~~~~~~~~~~~ diff --git a/src/microdot/auth.py b/src/microdot/auth.py index 1fcf687..a6536c2 100644 --- a/src/microdot/auth.py +++ b/src/microdot/auth.py @@ -36,6 +36,24 @@ class BaseAuth: return wrapper + def optional(self, f): + """Decorator to protect a route with optional authentication. + + This decorator makes authentication for the decorated route optional, + meaning that the route is allowed to run with or with + authentication given in the request. + """ + async def wrapper(request, *args, **kwargs): + auth = self._get_auth(request) + if not auth: + request.g.current_user = None + else: + request.g.current_user = await invoke_handler( + self.auth_callback, request, *auth) + return await invoke_handler(f, request, *args, **kwargs) + + return wrapper + class BasicAuth(BaseAuth): """Basic Authentication. diff --git a/tests/test_auth.py b/tests/test_auth.py index bd64365..b8b397f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -45,6 +45,38 @@ class TestAuth(unittest.TestCase): b'foo:baz').decode()})) self.assertEqual(res.status_code, 401) + def test_basic_optional_auth(self): + app = Microdot() + basic_auth = BasicAuth() + + @basic_auth.authenticate + def authenticate(request, username, password): + if username == 'foo' and password == 'bar': + return {'username': username} + + @app.route('/') + @basic_auth.optional + def index(request): + return request.g.current_user['username'] \ + if request.g.current_user else '' + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:bar').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'foo') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:baz').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + def test_token_auth(self): app = Microdot() token_auth = TokenAuth() @@ -67,7 +99,7 @@ class TestAuth(unittest.TestCase): 'Authorization': 'Basic foo'})) self.assertEqual(res.status_code, 401) - res = self._run(client.get('/', headers={'Authorization': 'foo'})) + res = self._run(client.get('/', headers={'Authorization': 'invalid'})) self.assertEqual(res.status_code, 401) res = self._run(client.get('/', headers={ @@ -75,6 +107,39 @@ class TestAuth(unittest.TestCase): self.assertEqual(res.status_code, 200) self.assertEqual(res.text, 'user') + def test_token_optional_auth(self): + app = Microdot() + token_auth = TokenAuth() + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth.optional + def index(request): + return request.g.current_user or '' + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={'Authorization': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + def test_token_auth_custom_header(self): app = Microdot() token_auth = TokenAuth(header='X-Auth-Token')