remember tests
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from hashlib import sha1
|
||||
from microdot import Microdot, redirect
|
||||
from microdot.session import Session
|
||||
from microdot.auth import LoginAuth
|
||||
from microdot.auth import Login
|
||||
|
||||
|
||||
def create_hash(password):
|
||||
@@ -15,7 +15,7 @@ USERS = {
|
||||
|
||||
app = Microdot()
|
||||
Session(app, secret_key='top-secret!')
|
||||
auth = LoginAuth()
|
||||
auth = Login()
|
||||
|
||||
|
||||
@auth.id_to_user
|
||||
@@ -28,29 +28,34 @@ async def get_user_id(user):
|
||||
return user
|
||||
|
||||
|
||||
@app.route('/')
|
||||
@auth
|
||||
async def index(request):
|
||||
return f'''
|
||||
<h1>Login Auth Example</h1>
|
||||
<p>Hello, {request.g.current_user}!</p>
|
||||
<form method="POST" action="/logout">
|
||||
<button type="submit">Logout</button>
|
||||
</form>
|
||||
''', {'Content-Type': 'text/html'}
|
||||
|
||||
|
||||
@app.route('/login', methods=['GET', 'POST'])
|
||||
async def login(request):
|
||||
if request.method == 'GET':
|
||||
return '''
|
||||
<h1>Login Auth Example</h1>
|
||||
<form method="POST">
|
||||
<input name="username" placeholder="username" autofocus>
|
||||
<input name="password" type="password" placeholder="password">
|
||||
<br><input name="remember_me" type="checkbox"> Remember me
|
||||
<br><button type="submit">Login</button>
|
||||
</form>
|
||||
<!doctype html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Please Login</h1>
|
||||
<form method="POST">
|
||||
<p>
|
||||
Username<br>
|
||||
<input name="username" autofocus>
|
||||
</p>
|
||||
<p>
|
||||
Password:<br>
|
||||
<input name="password" type="password">
|
||||
<br>
|
||||
</p>
|
||||
<p>
|
||||
<input name="remember_me" type="checkbox"> Remember me
|
||||
<br>
|
||||
</p>
|
||||
<p>
|
||||
<button type="submit">Login</button>
|
||||
</p>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
''', {'Content-Type': 'text/html'}
|
||||
username = request.form['username']
|
||||
password = request.form['password']
|
||||
@@ -61,11 +66,37 @@ async def login(request):
|
||||
return redirect('/login')
|
||||
|
||||
|
||||
@app.route('/')
|
||||
@auth
|
||||
async def index(request):
|
||||
return f'''
|
||||
<!doctype html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Hello, {request.g.current_user}!</h1>
|
||||
<p>
|
||||
<a href="/fresh">Click here</a> to access the fresh login page.
|
||||
</p>
|
||||
<form method="POST" action="/logout">
|
||||
<button type="submit">Logout</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
''', {'Content-Type': 'text/html'}
|
||||
|
||||
|
||||
@app.get('/fresh')
|
||||
@auth.fresh
|
||||
async def fresh(request):
|
||||
return '''
|
||||
<h1>Fresh Login only</h1>
|
||||
return f'''
|
||||
<!doctype html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Hello, {request.g.current_user}!</h1>
|
||||
<p>This page requires a fresh login session.</p>
|
||||
<p><a href="/">Go back</a> to the main page.</p>
|
||||
</body>
|
||||
</html>
|
||||
''', {'Content-Type': 'text/html'}
|
||||
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class TokenAuth(HTTPAuth):
|
||||
self.error_callback = f
|
||||
|
||||
|
||||
class LoginAuth(BaseAuth):
|
||||
class Login(BaseAuth):
|
||||
def __init__(self, login_url='/login'):
|
||||
super().__init__()
|
||||
self.login_url = login_url
|
||||
@@ -132,6 +132,7 @@ class LoginAuth(BaseAuth):
|
||||
async def _set_remember_cookie(request, response):
|
||||
response.set_cookie('_remember', remember_payload,
|
||||
max_age=days * 24 * 60 * 60)
|
||||
print(response.headers)
|
||||
return response
|
||||
|
||||
def _get_auth(self, request):
|
||||
@@ -142,7 +143,7 @@ class LoginAuth(BaseAuth):
|
||||
remember_payload = request.app._session.decode(
|
||||
request.cookies['_remember'])
|
||||
user_id = remember_payload.get('user_id')
|
||||
if user_id:
|
||||
if user_id: # pragma: no branch
|
||||
self._update_remember_cookie(
|
||||
request, remember_payload.get('_days', 30), user_id)
|
||||
session['_user_id'] = user_id
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import binascii
|
||||
import unittest
|
||||
from microdot import Microdot
|
||||
from microdot.auth import BasicAuth, TokenAuth, LoginAuth
|
||||
from microdot.auth import BasicAuth, TokenAuth, Login
|
||||
from microdot.session import Session
|
||||
from microdot.test_client import TestClient
|
||||
|
||||
@@ -128,24 +128,25 @@ class TestAuth(unittest.TestCase):
|
||||
def test_login_auth(self):
|
||||
app = Microdot()
|
||||
Session(app, secret_key='secret')
|
||||
login_auth = LoginAuth()
|
||||
login_auth = Login()
|
||||
|
||||
@login_auth.id_to_user
|
||||
def id_to_user(user_id):
|
||||
return user_id
|
||||
return {'id': int(user_id), 'name': f'user{user_id}'}
|
||||
|
||||
@login_auth.user_to_id
|
||||
def user_to_id(user):
|
||||
return user
|
||||
return str(user['id'])
|
||||
|
||||
@app.get('/')
|
||||
@login_auth
|
||||
def index(request):
|
||||
return request.g.current_user
|
||||
return request.g.current_user['name']
|
||||
|
||||
@app.post('/login')
|
||||
async def login(request):
|
||||
return await login_auth.login_user(request, 'user')
|
||||
return await login_auth.login_user(
|
||||
request, {'id': 123, 'name': 'user123'})
|
||||
|
||||
@app.post('/logout')
|
||||
async def logout(request):
|
||||
@@ -160,10 +161,12 @@ class TestAuth(unittest.TestCase):
|
||||
res = self._run(client.post('/login?next=/%3Ffoo=bar'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/?foo=bar')
|
||||
self.assertEqual(len(res.headers['Set-Cookie']), 1)
|
||||
self.assertIn('session', client.cookies)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
self.assertEqual(res.text, 'user123')
|
||||
|
||||
res = self._run(client.post('/logout'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
@@ -174,7 +177,7 @@ class TestAuth(unittest.TestCase):
|
||||
def test_login_auth_bad_redirect(self):
|
||||
app = Microdot()
|
||||
Session(app, secret_key='secret')
|
||||
login_auth = LoginAuth()
|
||||
login_auth = Login()
|
||||
|
||||
@login_auth.id_to_user
|
||||
def id_to_user(user_id):
|
||||
@@ -197,3 +200,65 @@ class TestAuth(unittest.TestCase):
|
||||
res = self._run(client.post('/login?next=http://example.com'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/')
|
||||
|
||||
def test_login_remember(self):
|
||||
app = Microdot()
|
||||
Session(app, secret_key='secret')
|
||||
login_auth = Login()
|
||||
|
||||
@login_auth.id_to_user
|
||||
def id_to_user(user_id):
|
||||
return user_id
|
||||
|
||||
@login_auth.user_to_id
|
||||
def user_to_id(user):
|
||||
return user
|
||||
|
||||
@app.get('/')
|
||||
@login_auth
|
||||
def index(request):
|
||||
return request.g.current_user
|
||||
|
||||
@app.post('/login')
|
||||
async def login(request):
|
||||
return await login_auth.login_user(request, 'user', remember=True)
|
||||
|
||||
@app.post('/logout')
|
||||
async def logout(request):
|
||||
await login_auth.logout_user(request)
|
||||
return 'ok'
|
||||
|
||||
@app.get('/fresh')
|
||||
@login_auth.fresh
|
||||
async def fresh(request):
|
||||
return f'fresh {request.g.current_user}'
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.post('/login?next=/%3Ffoo=bar'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/?foo=bar')
|
||||
self.assertEqual(len(res.headers['Set-Cookie']), 2)
|
||||
self.assertIn('session', client.cookies)
|
||||
self.assertIn('_remember', client.cookies)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
res = self._run(client.get('/fresh'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'fresh user')
|
||||
|
||||
del client.cookies['session']
|
||||
print(client.cookies)
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
res = self._run(client.get('/fresh'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/login?next=/fresh')
|
||||
|
||||
res = self._run(client.post('/logout'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertFalse('_remember' in client.cookies)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
|
||||
Reference in New Issue
Block a user