Basic, token and login authentication
This commit is contained in:
147
src/microdot/auth.py
Normal file
147
src/microdot/auth.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from microdot import abort, redirect
|
||||
from microdot.microdot import urlencode, invoke_handler
|
||||
|
||||
|
||||
class BaseAuth:
|
||||
def __init__(self):
|
||||
self.auth_callback = None
|
||||
self.error_callback = lambda request: abort(401)
|
||||
|
||||
def __call__(self, f):
|
||||
"""Decorator to protect a route with authentication.
|
||||
|
||||
Microdot will only call the route if the authentication callback
|
||||
returns a valid user object, otherwise it will call the error
|
||||
callback."""
|
||||
async def wrapper(request, *args, **kwargs):
|
||||
auth = self._get_auth(request)
|
||||
if not auth:
|
||||
return await invoke_handler(self.error_callback, request)
|
||||
request.g.current_user = await invoke_handler(
|
||||
self.auth_callback, request, *auth)
|
||||
if not request.g.current_user:
|
||||
return await invoke_handler(self.error_callback, request)
|
||||
return await invoke_handler(f, request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class HTTPAuth(BaseAuth):
|
||||
def authenticate(self, f):
|
||||
"""Decorator to configure the authentication callback.
|
||||
|
||||
Microdot calls the authentication callback to allow the application to
|
||||
check user credentials.
|
||||
"""
|
||||
self.auth_callback = f
|
||||
|
||||
|
||||
class BasicAuth(HTTPAuth):
|
||||
def __init__(self, realm='Please login', charset='UTF-8', scheme='Basic',
|
||||
error_status=401):
|
||||
super().__init__()
|
||||
self.realm = realm
|
||||
self.charset = charset
|
||||
self.scheme = scheme
|
||||
self.error_status = error_status
|
||||
self.error_callback = self.authentication_error
|
||||
|
||||
def _get_auth(self, request):
|
||||
auth = request.headers.get('Authorization')
|
||||
if auth and auth.startswith('Basic '):
|
||||
import binascii
|
||||
try:
|
||||
username, password = binascii.a2b_base64(
|
||||
auth[6:]).decode().split(':', 1)
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
return username, password
|
||||
|
||||
def authentication_error(self, request):
|
||||
return '', self.error_status, {
|
||||
'WWW-Authenticate': '{} realm="{}", charset="{}"'.format(
|
||||
self.scheme, self.realm, self.charset)}
|
||||
|
||||
|
||||
class TokenAuth(HTTPAuth):
|
||||
def __init__(self, header='Authorization', scheme='Bearer'):
|
||||
super().__init__()
|
||||
self.header = header
|
||||
self.scheme = scheme.lower()
|
||||
|
||||
def _get_auth(self, request):
|
||||
auth = request.headers.get(self.header)
|
||||
if auth:
|
||||
if self.header == 'Authorization':
|
||||
try:
|
||||
scheme, token = auth.split(' ', 1)
|
||||
except Exception:
|
||||
return None
|
||||
if scheme.lower() == self.scheme:
|
||||
return (token.strip(),)
|
||||
else:
|
||||
return (auth,)
|
||||
|
||||
def errorhandler(self, f):
|
||||
"""Decorator to configure the error callback.
|
||||
|
||||
Microdot calls the error callback to allow the application to generate
|
||||
a custom error response. The default error response is to call
|
||||
``abort(401)``.
|
||||
"""
|
||||
self.error_callback = f
|
||||
|
||||
|
||||
class LoginAuth(BaseAuth):
|
||||
def __init__(self, login_url='/login'):
|
||||
super().__init__()
|
||||
self.login_url = login_url
|
||||
self.user_callback = None
|
||||
self.user_id_callback = None
|
||||
self.auth_callback = self._authenticate
|
||||
self.error_callback = self._redirect_to_login
|
||||
|
||||
def id_to_user(self, f):
|
||||
"""Decorator to configure the user callback.
|
||||
|
||||
Microdot calls the user callback to load the user object from the
|
||||
user ID stored in the user session.
|
||||
"""
|
||||
self.user_callback = f
|
||||
|
||||
def user_to_id(self, f):
|
||||
"""Decorator to configure the user ID callback.
|
||||
|
||||
Microdot calls the user ID callback to load the user ID from the
|
||||
user session.
|
||||
"""
|
||||
self.user_id_callback = f
|
||||
|
||||
def _get_session(self, request):
|
||||
return request.app._session.get(request)
|
||||
|
||||
def _get_auth(self, request):
|
||||
session = self._get_session(request)
|
||||
if session and 'user_id' in session:
|
||||
return (session['user_id'],)
|
||||
|
||||
async def _authenticate(self, request, user_id):
|
||||
return await invoke_handler(self.user_callback, user_id)
|
||||
|
||||
async def _redirect_to_login(self, request):
|
||||
return '', 302, {'Location': self.login_url + '?next=' + urlencode(
|
||||
request.url)}
|
||||
|
||||
async def login_user(self, request, user, redirect_url='/'):
|
||||
session = self._get_session(request)
|
||||
session['user_id'] = await invoke_handler(self.user_id_callback, user)
|
||||
session.save()
|
||||
next_url = request.args.get('next', redirect_url)
|
||||
if not next_url.startswith('/'):
|
||||
next_url = redirect_url
|
||||
return redirect(next_url)
|
||||
|
||||
async def logout_user(self, request):
|
||||
session = self._get_session(request)
|
||||
session.pop('user_id', None)
|
||||
session.save()
|
||||
@@ -9,3 +9,4 @@ from tests.test_sse import * # noqa: F401, F403
|
||||
from tests.test_cors import * # noqa: F401, F403
|
||||
from tests.test_utemplate import * # noqa: F401, F403
|
||||
from tests.test_session import * # noqa: F401, F403
|
||||
from tests.test_auth import * # noqa: F401, F403
|
||||
|
||||
199
tests/test_auth.py
Normal file
199
tests/test_auth.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import asyncio
|
||||
import binascii
|
||||
import unittest
|
||||
from microdot import Microdot
|
||||
from microdot.auth import BasicAuth, TokenAuth, LoginAuth
|
||||
from microdot.session import Session
|
||||
from microdot.test_client import TestClient
|
||||
|
||||
|
||||
class TestAuth(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if hasattr(asyncio, 'set_event_loop'):
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
cls.loop = asyncio.get_event_loop()
|
||||
|
||||
def _run(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
def test_basic_auth(self):
|
||||
app = Microdot()
|
||||
basic_auth = BasicAuth()
|
||||
|
||||
@basic_auth.authenticate
|
||||
def authenticate(request, username, password):
|
||||
if username == 'foo' and password == 'bar':
|
||||
return {'username': username}
|
||||
|
||||
@app.route('/')
|
||||
@basic_auth
|
||||
def index(request):
|
||||
return request.g.current_user['username']
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Basic ' + binascii.b2a_base64(
|
||||
b'foo:bar').decode()}))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'foo')
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Basic ' + binascii.b2a_base64(
|
||||
b'foo:baz').decode()}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
def test_token_auth(self):
|
||||
app = Microdot()
|
||||
token_auth = TokenAuth()
|
||||
|
||||
@token_auth.authenticate
|
||||
def authenticate(request, token):
|
||||
if token == 'foo':
|
||||
return 'user'
|
||||
|
||||
@app.route('/')
|
||||
@token_auth
|
||||
def index(request):
|
||||
return request.g.current_user
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Basic foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={'Authorization': 'foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Bearer foo'}))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
|
||||
def test_token_auth_custom_header(self):
|
||||
app = Microdot()
|
||||
token_auth = TokenAuth(header='X-Auth-Token')
|
||||
|
||||
@token_auth.authenticate
|
||||
def authenticate(request, token):
|
||||
if token == 'foo':
|
||||
return 'user'
|
||||
|
||||
@app.route('/')
|
||||
@token_auth
|
||||
def index(request):
|
||||
return request.g.current_user
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Basic foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={'Authorization': 'foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'Authorization': 'Bearer foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={
|
||||
'X-Token-Auth': 'Bearer foo'}))
|
||||
self.assertEqual(res.status_code, 401)
|
||||
|
||||
res = self._run(client.get('/', headers={'X-Auth-Token': 'foo'}))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
|
||||
res = self._run(client.get('/', headers={'x-auth-token': 'foo'}))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
|
||||
@token_auth.errorhandler
|
||||
def error_handler(request):
|
||||
return {'status_code': 403}, 403
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
self.assertEqual(res.json, {'status_code': 403})
|
||||
|
||||
def test_login_auth(self):
|
||||
app = Microdot()
|
||||
Session(app, secret_key='secret')
|
||||
login_auth = LoginAuth()
|
||||
|
||||
@login_auth.id_to_user
|
||||
def id_to_user(user_id):
|
||||
return user_id
|
||||
|
||||
@login_auth.user_to_id
|
||||
def user_to_id(user):
|
||||
return user
|
||||
|
||||
@app.get('/')
|
||||
@login_auth
|
||||
def index(request):
|
||||
return request.g.current_user
|
||||
|
||||
@app.post('/login')
|
||||
async def login(request):
|
||||
return await login_auth.login_user(request, 'user')
|
||||
|
||||
@app.post('/logout')
|
||||
async def logout(request):
|
||||
await login_auth.logout_user(request)
|
||||
return 'ok'
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.get('/?foo=bar'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar')
|
||||
|
||||
res = self._run(client.post('/login?next=/%3Ffoo=bar'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/?foo=bar')
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.text, 'user')
|
||||
|
||||
res = self._run(client.post('/logout'))
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
|
||||
def test_login_auth_bad_redirect(self):
|
||||
app = Microdot()
|
||||
Session(app, secret_key='secret')
|
||||
login_auth = LoginAuth()
|
||||
|
||||
@login_auth.id_to_user
|
||||
def id_to_user(user_id):
|
||||
return user_id
|
||||
|
||||
@login_auth.user_to_id
|
||||
def user_to_id(user):
|
||||
return user
|
||||
|
||||
@app.get('/')
|
||||
@login_auth
|
||||
async def index(request):
|
||||
return 'ok'
|
||||
|
||||
@app.post('/login')
|
||||
async def login(request):
|
||||
return await login_auth.login_user(request, 'user')
|
||||
|
||||
client = TestClient(app)
|
||||
res = self._run(client.post('/login?next=http://example.com'))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(res.headers['Location'], '/')
|
||||
@@ -37,7 +37,7 @@ class TestSession(unittest.TestCase):
|
||||
|
||||
@app.post('/set')
|
||||
@with_session
|
||||
async def save_session(req, session):
|
||||
def save_session(req, session):
|
||||
session['name'] = 'joe'
|
||||
session.save()
|
||||
return 'OK'
|
||||
|
||||
Reference in New Issue
Block a user