diff --git a/examples/auth/login_auth.py b/examples/auth/login_auth.py index 97d730e..4a7a994 100644 --- a/examples/auth/login_auth.py +++ b/examples/auth/login_auth.py @@ -1,7 +1,7 @@ from hashlib import sha1 from microdot import Microdot, redirect from microdot.session import Session -from microdot.auth import LoginAuth +from microdot.auth import Login def create_hash(password): @@ -15,7 +15,7 @@ USERS = { app = Microdot() Session(app, secret_key='top-secret!') -auth = LoginAuth() +auth = Login() @auth.id_to_user @@ -28,29 +28,34 @@ async def get_user_id(user): return user -@app.route('/') -@auth -async def index(request): - return f''' -
Hello, {request.g.current_user}!
- - ''', {'Content-Type': 'text/html'} - - @app.route('/login', methods=['GET', 'POST']) async def login(request): if request.method == 'GET': return ''' -+ Click here to access the fresh login page. +
+ + + + ''', {'Content-Type': 'text/html'} + + @app.get('/fresh') @auth.fresh async def fresh(request): - return ''' -This page requires a fresh login session.
+Go back to the main page.
+ + ''', {'Content-Type': 'text/html'} diff --git a/src/microdot/auth.py b/src/microdot/auth.py index ce00824..3097bc6 100644 --- a/src/microdot/auth.py +++ b/src/microdot/auth.py @@ -93,7 +93,7 @@ class TokenAuth(HTTPAuth): self.error_callback = f -class LoginAuth(BaseAuth): +class Login(BaseAuth): def __init__(self, login_url='/login'): super().__init__() self.login_url = login_url @@ -132,6 +132,7 @@ class LoginAuth(BaseAuth): async def _set_remember_cookie(request, response): response.set_cookie('_remember', remember_payload, max_age=days * 24 * 60 * 60) + print(response.headers) return response def _get_auth(self, request): @@ -142,7 +143,7 @@ class LoginAuth(BaseAuth): remember_payload = request.app._session.decode( request.cookies['_remember']) user_id = remember_payload.get('user_id') - if user_id: + if user_id: # pragma: no branch self._update_remember_cookie( request, remember_payload.get('_days', 30), user_id) session['_user_id'] = user_id diff --git a/tests/test_auth.py b/tests/test_auth.py index 1e33dcd..bf3f018 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,7 +2,7 @@ import asyncio import binascii import unittest from microdot import Microdot -from microdot.auth import BasicAuth, TokenAuth, LoginAuth +from microdot.auth import BasicAuth, TokenAuth, Login from microdot.session import Session from microdot.test_client import TestClient @@ -128,24 +128,25 @@ class TestAuth(unittest.TestCase): def test_login_auth(self): app = Microdot() Session(app, secret_key='secret') - login_auth = LoginAuth() + login_auth = Login() @login_auth.id_to_user def id_to_user(user_id): - return user_id + return {'id': int(user_id), 'name': f'user{user_id}'} @login_auth.user_to_id def user_to_id(user): - return user + return str(user['id']) @app.get('/') @login_auth def index(request): - return request.g.current_user + return request.g.current_user['name'] @app.post('/login') async def login(request): - return await login_auth.login_user(request, 'user') + return await login_auth.login_user( + request, {'id': 123, 'name': 'user123'}) @app.post('/logout') async def logout(request): @@ -160,10 +161,12 @@ class TestAuth(unittest.TestCase): res = self._run(client.post('/login?next=/%3Ffoo=bar')) self.assertEqual(res.status_code, 302) self.assertEqual(res.headers['Location'], '/?foo=bar') + self.assertEqual(len(res.headers['Set-Cookie']), 1) + self.assertIn('session', client.cookies) res = self._run(client.get('/')) self.assertEqual(res.status_code, 200) - self.assertEqual(res.text, 'user') + self.assertEqual(res.text, 'user123') res = self._run(client.post('/logout')) self.assertEqual(res.status_code, 200) @@ -174,7 +177,7 @@ class TestAuth(unittest.TestCase): def test_login_auth_bad_redirect(self): app = Microdot() Session(app, secret_key='secret') - login_auth = LoginAuth() + login_auth = Login() @login_auth.id_to_user def id_to_user(user_id): @@ -197,3 +200,65 @@ class TestAuth(unittest.TestCase): res = self._run(client.post('/login?next=http://example.com')) self.assertEqual(res.status_code, 302) self.assertEqual(res.headers['Location'], '/') + + def test_login_remember(self): + app = Microdot() + Session(app, secret_key='secret') + login_auth = Login() + + @login_auth.id_to_user + def id_to_user(user_id): + return user_id + + @login_auth.user_to_id + def user_to_id(user): + return user + + @app.get('/') + @login_auth + def index(request): + return request.g.current_user + + @app.post('/login') + async def login(request): + return await login_auth.login_user(request, 'user', remember=True) + + @app.post('/logout') + async def logout(request): + await login_auth.logout_user(request) + return 'ok' + + @app.get('/fresh') + @login_auth.fresh + async def fresh(request): + return f'fresh {request.g.current_user}' + + client = TestClient(app) + res = self._run(client.post('/login?next=/%3Ffoo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/?foo=bar') + self.assertEqual(len(res.headers['Set-Cookie']), 2) + self.assertIn('session', client.cookies) + self.assertIn('_remember', client.cookies) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + res = self._run(client.get('/fresh')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'fresh user') + + del client.cookies['session'] + print(client.cookies) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + res = self._run(client.get('/fresh')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/fresh') + + res = self._run(client.post('/logout')) + self.assertEqual(res.status_code, 200) + self.assertFalse('_remember' in client.cookies) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 302)