10 Commits

Author SHA1 Message Date
Miguel Grinberg
7ee1c7eef9 Authentication support 2022-09-24 19:54:26 +01:00
Miguel Grinberg
01947b101e Cache user session 2022-09-24 19:40:28 +01:00
Miguel Grinberg
1547e861ee request.url attribute with the complete URL of the request 2022-09-24 19:33:46 +01:00
Miguel Grinberg
672512e086 urlencode() function 2022-09-24 19:33:10 +01:00
Miguel Grinberg
a8515c97b0 Small performance improvement for NoCaseDict 2022-09-24 15:37:52 +01:00
Miguel Grinberg
8ebe81c09b File upload example 2022-09-22 17:52:48 +01:00
Miguel Grinberg
4f263c63ab Minor documentation styling fixes 2022-09-21 23:38:51 +01:00
Miguel Grinberg
b0fd6c4323 Use a case insensitive dict for headers 2022-09-21 23:29:01 +01:00
Miguel Grinberg
cbefb6bf3a Do not log HTTPException occurrences 2022-09-19 23:50:04 +01:00
Miguel Grinberg
c81a2649c5 Version 1.1.2.dev0 2022-09-18 11:28:48 +01:00
22 changed files with 706 additions and 76 deletions

View File

@@ -1,3 +1,3 @@
.py .class, .py .method, .py .property { .py.class, .py.function, .py.method, .py.property {
margin-top: 20px; margin-top: 20px;
} }

View File

@@ -13,6 +13,9 @@ API Reference
.. autoclass:: microdot.Response .. autoclass:: microdot.Response
:members: :members:
.. autoclass:: microdot.NoCaseDict
:members:
.. autoclass:: microdot.MultiDict .. autoclass:: microdot.MultiDict
:members: :members:

View File

@@ -0,0 +1,27 @@
from microdot import Microdot
from microdot_auth import BasicAuth
app = Microdot()
basic_auth = BasicAuth()
USERS = {
'susan': 'hello',
'david': 'bye',
}
@basic_auth.callback
def verify_password(request, username, password):
if username in USERS and USERS[username] == password:
request.g.user = username
return True
@app.route('/')
@basic_auth
def index(request):
return f'Hello, {request.g.user}!'
if __name__ == '__main__':
app.run(debug=True)

View File

@@ -0,0 +1,60 @@
from microdot import Microdot, redirect
from microdot_session import set_session_secret_key
from microdot_login import LoginAuth
app = Microdot()
set_session_secret_key('top-secret')
login_auth = LoginAuth()
USERS = {
'susan': 'hello',
'david': 'bye',
}
@login_auth.callback
def check_user(request, user_id):
request.g.user = user_id
return True
@app.route('/')
@login_auth
def index(request):
return f'''
<h1>Login Auth Example</h1>
<p>Hello, {request.g.user}!</p>
<form method="POST" action="/logout">
<button type="submit">Logout</button>
</form>
''', {'Content-Type': 'text/html'}
@app.route('/login', methods=['GET', 'POST'])
def login(request):
if request.method == 'GET':
return '''
<h1>Login Auth Example</h1>
<form method="POST">
<input name="username" placeholder="username">
<input name="password" type="password" placeholder="password">
<button type="submit">Login</button>
</form>
''', {'Content-Type': 'text/html'}
username = request.form['username']
password = request.form['password']
if USERS.get(username) == password:
login_auth.login_user(request, username)
return login_auth.redirect_to_next(request)
else:
return redirect('/login')
@app.post('/logout')
def logout(request):
login_auth.logout_user(request)
return redirect('/')
if __name__ == '__main__':
app.run(debug=True)

View File

@@ -0,0 +1,27 @@
from microdot import Microdot
from microdot_auth import TokenAuth
app = Microdot()
token_auth = TokenAuth()
TOKENS = {
'hello': 'susan',
'bye': 'david',
}
@token_auth.callback
def verify_token(request, token):
if token in TOKENS:
request.g.user = TOKENS[token]
return True
@app.route('/')
@token_auth
def index(request):
return f'Hello, {request.g.user}!'
if __name__ == '__main__':
app.run(debug=True)

View File

@@ -0,0 +1 @@
This directory contains file upload examples.

View File

@@ -0,0 +1 @@
Uploaded files are saved to this directory.

View File

@@ -0,0 +1,34 @@
<!doctype html>
<html>
<head>
<title>Microdot Upload Example</title>
</head>
<body>
<h1>Microdot Upload Example</h1>
<form id="form">
<input type="file" id="file" name="file" />
<input type="submit" value="Upload" />
</form>
<script>
async function upload(ev) {
ev.preventDefault();
const file = document.getElementById('file').files[0];
if (!file) {
return;
}
await fetch('/upload', {
method: 'POST',
body: file,
headers: {
'Content-Type': 'application/octet-stream',
'Content-Disposition': `attachment; filename="${file.name}"`,
},
}).then(res => {
console.log('Upload accepted');
window.location.href = '/';
});
}
document.getElementById('form').addEventListener('submit', upload);
</script>
</body>
</html>

View File

@@ -0,0 +1,33 @@
from microdot import Microdot, send_file
app = Microdot()
@app.get('/')
def index(request):
return send_file('index.html')
@app.post('/upload')
def upload(request):
# obtain the filename and size from request headers
filename = request.headers['Content-Disposition'].split(
'filename=')[1].strip('"')
size = int(request.headers['Content-Length'])
# sanitize the filename
filename = filename.replace('/', '_')
# write the file to the files directory in 1K chunks
with open('files/' + filename, 'wb') as f:
while size > 0:
chunk = request.stream.read(min(size, 1024))
f.write(chunk)
size -= len(chunk)
print('Successfully saved file: ' + filename)
return ''
if __name__ == '__main__':
app.run()

View File

@@ -1,6 +1,6 @@
[metadata] [metadata]
name = microdot name = microdot
version = 1.1.1 version = 1.1.2.dev0
author = Miguel Grinberg author = Miguel Grinberg
author_email = miguel.grinberg@gmail.com author_email = miguel.grinberg@gmail.com
description = The impossibly small web framework for MicroPython description = The impossibly small web framework for MicroPython
@@ -28,6 +28,8 @@ py_modules =
microdot_utemplate microdot_utemplate
microdot_jinja microdot_jinja
microdot_session microdot_session
microdot_auth
microdot_login
microdot_websocket microdot_websocket
microdot_websocket_alt microdot_websocket_alt
microdot_asyncio_websocket microdot_asyncio_websocket

View File

@@ -91,6 +91,59 @@ def urldecode_bytes(s):
return b''.join(result).decode() return b''.join(result).decode()
def urlencode(s):
return s.replace(' ', '+').replace('%', '%25').replace('?', '%3F').replace(
'#', '%23').replace('&', '%26').replace('+', '%2B')
class NoCaseDict(dict):
"""A subclass of dictionary that holds case-insensitive keys.
:param initial_dict: an initial dictionary of key/value pairs to
initialize this object with.
Example::
>>> d = NoCaseDict()
>>> d['Content-Type'] = 'text/html'
>>> print(d['Content-Type'])
text/html
>>> print(d['content-type'])
text/html
>>> print(d['CONTENT-TYPE'])
text/html
>>> del d['cOnTeNt-TyPe']
>>> print(d)
{}
"""
def __init__(self, initial_dict=None):
super().__init__(initial_dict or {})
self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k}
def __setitem__(self, key, value):
kl = key.lower()
key = self.keymap.get(kl, key)
if kl != key:
self.keymap[kl] = key
super().__setitem__(key, value)
def __getitem__(self, key):
kl = key.lower()
return super().__getitem__(self.keymap.get(kl, kl))
def __delitem__(self, key):
kl = key.lower()
super().__delitem__(self.keymap.get(kl, kl))
def __contains__(self, key):
kl = key.lower()
return self.keymap.get(kl, kl) in self.keys()
def get(self, key, default=None):
kl = key.lower()
return super().get(self.keymap.get(kl, kl), default)
class MultiDict(dict): class MultiDict(dict):
"""A subclass of dictionary that can hold multiple values for the same """A subclass of dictionary that can hold multiple values for the same
key. It is used to hold key/value pairs decoded from query strings and key. It is used to hold key/value pairs decoded from query strings and
@@ -224,6 +277,8 @@ class Request():
self.client_addr = client_addr self.client_addr = client_addr
#: The HTTP method of the request. #: The HTTP method of the request.
self.method = method self.method = method
#: The request URL, including the path and query string.
self.url = url
#: The path portion of the URL. #: The path portion of the URL.
self.path = url self.path = url
#: The query string portion of the URL. #: The query string portion of the URL.
@@ -248,14 +303,12 @@ class Request():
self.path, self.query_string = self.path.split('?', 1) self.path, self.query_string = self.path.split('?', 1)
self.args = self._parse_urlencoded(self.query_string) self.args = self._parse_urlencoded(self.query_string)
for header, value in self.headers.items(): if 'Content-Length' in self.headers:
header = header.lower() self.content_length = int(self.headers['Content-Length'])
if header == 'content-length': if 'Content-Type' in self.headers:
self.content_length = int(value) self.content_type = self.headers['Content-Type']
elif header == 'content-type': if 'Cookie' in self.headers:
self.content_type = value for cookie in self.headers['Cookie'].split(';'):
elif header == 'cookie':
for cookie in value.split(';'):
name, value = cookie.strip().split('=', 1) name, value = cookie.strip().split('=', 1)
self.cookies[name] = value self.cookies[name] = value
@@ -289,7 +342,7 @@ class Request():
http_version = http_version.split('/', 1)[1] http_version = http_version.split('/', 1)[1]
# headers # headers
headers = {} headers = NoCaseDict()
while True: while True:
line = Request._safe_readline(client_stream).strip().decode() line = Request._safe_readline(client_stream).strip().decode()
if line == '': if line == '':
@@ -437,7 +490,7 @@ class Response():
body = '' body = ''
status_code = 204 status_code = 204
self.status_code = status_code self.status_code = status_code
self.headers = headers.copy() if headers else {} self.headers = NoCaseDict(headers or {})
self.reason = reason self.reason = reason
if isinstance(body, (dict, list)): if isinstance(body, (dict, list)):
self.body = json.dumps(body).encode() self.body = json.dumps(body).encode()
@@ -1044,7 +1097,6 @@ class Microdot():
else: else:
res = 'Not found', f res = 'Not found', f
except HTTPException as exc: except HTTPException as exc:
print_exception(exc)
if exc.status_code in self.error_handlers: if exc.status_code in self.error_handlers:
res = self.error_handlers[exc.status_code](req) res = self.error_handlers[exc.status_code](req)
else: else:

View File

@@ -4,6 +4,7 @@ import signal
from microdot_asyncio import * # noqa: F401, F403 from microdot_asyncio import * # noqa: F401, F403
from microdot_asyncio import Microdot as BaseMicrodot from microdot_asyncio import Microdot as BaseMicrodot
from microdot_asyncio import Request from microdot_asyncio import Request
from microdot import NoCaseDict
class _BodyStream: # pragma: no cover class _BodyStream: # pragma: no cover
@@ -55,7 +56,7 @@ class Microdot(BaseMicrodot):
path = scope['path'] path = scope['path']
if 'query_string' in scope and scope['query_string']: if 'query_string' in scope and scope['query_string']:
path += '?' + scope['query_string'].decode() path += '?' + scope['query_string'].decode()
headers = {} headers = NoCaseDict()
content_length = 0 content_length = 0
for key, value in scope.get('headers', []): for key, value in scope.get('headers', []):
headers[key] = value headers[key] = value

View File

@@ -17,9 +17,10 @@ except ImportError:
import io import io
from microdot import Microdot as BaseMicrodot from microdot import Microdot as BaseMicrodot
from microdot import print_exception from microdot import NoCaseDict
from microdot import Request as BaseRequest from microdot import Request as BaseRequest
from microdot import Response as BaseResponse from microdot import Response as BaseResponse
from microdot import print_exception
from microdot import HTTPException from microdot import HTTPException
from microdot import MUTED_SOCKET_ERRORS from microdot import MUTED_SOCKET_ERRORS
@@ -74,7 +75,7 @@ class Request(BaseRequest):
http_version = http_version.split('/', 1)[1] http_version = http_version.split('/', 1)[1]
# headers # headers
headers = {} headers = NoCaseDict()
content_length = 0 content_length = 0
while True: while True:
line = (await Request._safe_readline( line = (await Request._safe_readline(
@@ -386,7 +387,6 @@ class Microdot(BaseMicrodot):
else: else:
res = 'Not found', f res = 'Not found', f
except HTTPException as exc: except HTTPException as exc:
print_exception(exc)
if exc.status_code in self.error_handlers: if exc.status_code in self.error_handlers:
res = self.error_handlers[exc.status_code](req) res = self.error_handlers[exc.status_code](req)
else: else:

65
src/microdot_auth.py Normal file
View File

@@ -0,0 +1,65 @@
from microdot import abort
class BaseAuth:
def __init__(self, header='Authorization', scheme=None):
self.auth_callback = None
self.error_callback = self.auth_failed
self.header = header
self.scheme = scheme.lower()
def callback(self, f):
"""Decorator to configure the authentication callback.
Microdot calls the authentication callback to allow the application to
check user credentials.
"""
self.auth_callback = f
def errorhandler(self, f):
"""Decorator to configure the error callback.
Microdot calls the error callback to allow the application to generate
a custom error response. The default error response is to call
``abort(401)``.
"""
self.error_callback = f
def auth_failed(self):
abort(401)
def __call__(self, func):
def wrapper(request, *args, **kwargs):
auth = request.headers.get(self.header)
if not auth:
return self.error_callback()
if self.header == 'Authorization':
if ' ' not in auth:
return self.error_callback()
scheme, auth = auth.split(' ', 1)
if scheme.lower() != self.scheme:
return self.error_callback()
if not self.auth_callback(request, *self._get_auth_args(auth)):
return self.error_callback()
return func(request, *args, **kwargs)
return wrapper
class BasicAuth(BaseAuth):
def __init__(self):
super().__init__(scheme='Basic')
def _get_auth_args(self, auth):
import binascii
username, password = binascii.a2b_base64(auth).decode('utf-8').split(
':', 1)
return (username, password)
class TokenAuth(BaseAuth):
def __init__(self, header='Authorization', scheme='Bearer'):
super().__init__(header=header, scheme=scheme)
def _get_auth_args(self, token):
return (token,)

46
src/microdot_login.py Normal file
View File

@@ -0,0 +1,46 @@
from microdot import redirect, urlencode
from microdot_session import get_session, update_session
class LoginAuth:
def __init__(self, login_url='/login'):
super().__init__()
self.login_url = login_url
self.user_callback = self._accept_user
def callback(self, f):
self.user_callback = f
def login_user(self, request, user_id):
session = get_session(request)
session['user_id'] = user_id
update_session(request, session)
return session
def logout_user(self, request):
session = get_session(request)
session.pop('user_id', None)
update_session(request, session)
return session
def redirect_to_next(self, request, default_url='/'):
next_url = request.args.get('next', default_url)
if not next_url.startswith('/'):
next_url = default_url
return redirect(next_url)
def __call__(self, func):
def wrapper(request, *args, **kwargs):
session = get_session(request)
if 'user_id' not in session:
return redirect(self.login_url + '?next=' + urlencode(
request.url))
if not self.user_callback(request, session['user_id']):
return redirect(self.login_url + '?next=' + urlencode(
request.url))
return func(request, *args, **kwargs)
return wrapper
def _accept_user(self, request, user_id):
return True

View File

@@ -23,15 +23,19 @@ def get_session(request):
global secret_key global secret_key
if not secret_key: if not secret_key:
raise ValueError('The session secret key is not configured') raise ValueError('The session secret key is not configured')
if hasattr(request.g, '_session'):
return request.g._session
session = request.cookies.get('session') session = request.cookies.get('session')
if session is None: if session is None:
return {} request.g._session = {}
return request.g._session
try: try:
session = jwt.decode(session, secret_key, algorithms=['HS256']) session = jwt.decode(session, secret_key, algorithms=['HS256'])
except jwt.exceptions.PyJWTError: # pragma: no cover except jwt.exceptions.PyJWTError: # pragma: no cover
raise request.g._session = {}
return {} else:
return session request.g._session = session
return request.g._session
def update_session(request, session): def update_session(request, session):

View File

@@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
import json import json
from microdot import Request, Response from microdot import Request, Response, NoCaseDict
try: try:
from microdot_websocket import WebSocket from microdot_websocket import WebSocket
except: # pragma: no cover # noqa: E722 except: # pragma: no cover # noqa: E722
@@ -46,11 +46,10 @@ class TestResponse:
pass pass
def _process_json_body(self): def _process_json_body(self):
for name, value in self.headers.items(): # pragma: no branch if 'Content-Type' in self.headers: # pragma: no branch
if name.lower() == 'content-type': content_type = self.headers['Content-Type']
if value.lower().split(';')[0] == 'application/json': if content_type.split(';')[0] == 'application/json':
self.json = json.loads(self.text) self.json = json.loads(self.text)
break
@classmethod @classmethod
def create(cls, res): def create(cls, res):
@@ -97,13 +96,11 @@ class TestClient:
body = b'' body = b''
elif isinstance(body, (dict, list)): elif isinstance(body, (dict, list)):
body = json.dumps(body).encode() body = json.dumps(body).encode()
if 'Content-Type' not in headers and \ if 'Content-Type' not in headers: # pragma: no cover
'content-type' not in headers: # pragma: no cover
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'
elif isinstance(body, str): elif isinstance(body, str):
body = body.encode() body = body.encode()
if body and 'Content-Length' not in headers and \ if body and 'Content-Length' not in headers:
'content-length' not in headers:
headers['Content-Length'] = str(len(body)) headers['Content-Length'] = str(len(body))
if 'Host' not in headers: # pragma: no branch if 'Host' not in headers: # pragma: no branch
headers['Host'] = 'example.com:1234' headers['Host'] = 'example.com:1234'
@@ -132,9 +129,8 @@ class TestClient:
return request_bytes return request_bytes
def _update_cookies(self, res): def _update_cookies(self, res):
for name, value in res.headers.items(): cookies = res.headers.get('Set-Cookie', [])
if name.lower() == 'set-cookie': for cookie in cookies:
for cookie in value:
cookie_name, cookie_value = cookie.split('=', 1) cookie_name, cookie_value = cookie.split('=', 1)
cookie_options = cookie_value.split(';') cookie_options = cookie_value.split(';')
delete = False delete = False
@@ -154,7 +150,7 @@ class TestClient:
self.cookies[cookie_name] = cookie_options[0] self.cookies[cookie_name] = cookie_options[0]
def request(self, method, path, headers=None, body=None, sock=None): def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {} headers = NoCaseDict(headers or {})
body, headers = self._process_body(body, headers) body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers) cookies, headers = self._process_cookies(headers)
request_bytes = self._render_request(method, path, headers, body) request_bytes = self._render_request(method, path, headers, body)

View File

@@ -1,8 +1,7 @@
import os import os
import signal import signal
from microdot import * # noqa: F401, F403 from microdot import * # noqa: F401, F403
from microdot import Microdot as BaseMicrodot from microdot import Microdot as BaseMicrodot, Request, NoCaseDict
from microdot import Request
class Microdot(BaseMicrodot): class Microdot(BaseMicrodot):
@@ -15,7 +14,7 @@ class Microdot(BaseMicrodot):
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '') path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
if 'QUERY_STRING' in environ and environ['QUERY_STRING']: if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
path += '?' + environ['QUERY_STRING'] path += '?' + environ['QUERY_STRING']
headers = {} headers = NoCaseDict()
for k, v in environ.items(): for k, v in environ.items():
if k.startswith('HTTP_'): if k.startswith('HTTP_'):
h = '-'.join([p.title() for p in k[5:].split('_')]) h = '-'.join([p.title() for p in k[5:].split('_')])

113
tests/test_microdot_auth.py Normal file
View File

@@ -0,0 +1,113 @@
import binascii
import unittest
from microdot import Microdot
from microdot_auth import BasicAuth, TokenAuth
from microdot_test_client import TestClient
class TestAuth(unittest.TestCase):
def test_basic_auth(self):
app = Microdot()
basic_auth = BasicAuth()
@basic_auth.callback
def authenticate(request, username, password):
if username == 'foo' and password == 'bar':
request.g.user = {'username': username}
return True
@app.route('/')
@basic_auth
def index(request):
return request.g.user['username']
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={
'Authorization': 'Basic ' + binascii.b2a_base64(
b'foo:bar').decode()})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'foo')
res = client.get('/', headers={
'Authorization': 'Basic ' + binascii.b2a_base64(
b'foo:baz').decode()})
self.assertEqual(res.status_code, 401)
def test_token_auth(self):
app = Microdot()
token_auth = TokenAuth()
@token_auth.callback
def authenticate(request, token):
if token == 'foo':
request.g.user = 'user'
return True
@app.route('/')
@token_auth
def index(request):
return request.g.user
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'Basic foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'Bearer foo'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')
def test_token_auth_custom_header(self):
app = Microdot()
token_auth = TokenAuth(header='X-Auth-Token')
@token_auth.callback
def authenticate(request, token):
if token == 'foo':
request.g.user = 'user'
return True
@app.route('/')
@token_auth
def index(request):
return request.g.user
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'Basic foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'Authorization': 'Bearer foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'X-Token-Auth': 'Bearer foo'})
self.assertEqual(res.status_code, 401)
res = client.get('/', headers={'X-Auth-Token': 'foo'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')
res = client.get('/', headers={'x-auth-token': 'foo'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')
@token_auth.errorhandler
def error_handler():
return {'status_code': 403}, 403
res = client.get('/')
self.assertEqual(res.status_code, 403)
self.assertEqual(res.json, {'status_code': 403})

View File

@@ -0,0 +1,134 @@
import unittest
from microdot import Microdot
from microdot_login import LoginAuth
from microdot_session import set_session_secret_key, with_session
from microdot_test_client import TestClient
set_session_secret_key('top-secret!')
class TestLogin(unittest.TestCase):
def test_login_auth(self):
app = Microdot()
login_auth = LoginAuth()
@app.get('/')
@login_auth
def index(request):
return 'ok'
@app.post('/login')
def login(request):
login_auth.login_user(request, 'user')
return login_auth.redirect_to_next(request)
@app.post('/logout')
def logout(request):
login_auth.logout_user(request)
return 'ok'
client = TestClient(app)
res = client.get('/?foo=bar')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar')
res = client.post('/login?next=/%3Ffoo=bar')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/?foo=bar')
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'ok')
res = client.post('/logout')
self.assertEqual(res.status_code, 200)
res = client.get('/')
self.assertEqual(res.status_code, 302)
def test_login_auth_with_session(self):
app = Microdot()
login_auth = LoginAuth(login_url='/foo')
@app.get('/')
@login_auth
@with_session
def index(request, session):
return session['user_id']
@app.post('/foo')
def login(request):
login_auth.login_user(request, 'user')
return login_auth.redirect_to_next(request)
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/foo?next=/')
res = client.post('/foo')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/')
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')
def test_login_auth_user_callback(self):
app = Microdot()
login_auth = LoginAuth()
@login_auth.callback
def check_user(request, user_id):
request.g.user_id = user_id
return user_id == 'user'
@app.get('/')
@login_auth
def index(request):
return request.g.user_id
@app.post('/good-login')
def good_login(request):
login_auth.login_user(request, 'user')
return login_auth.redirect_to_next(request)
@app.post('/bad-login')
def bad_login(request):
login_auth.login_user(request, 'foo')
return login_auth.redirect_to_next(request)
client = TestClient(app)
res = client.post('/good-login')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/')
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')
res = client.post('/bad-login')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/')
res = client.get('/')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/login?next=/')
def test_login_auth_bad_redirect(self):
app = Microdot()
login_auth = LoginAuth()
@app.get('/')
@login_auth
def index(request):
return 'ok'
@app.post('/login')
def login(request):
login_auth.login_user(request, 'user')
return login_auth.redirect_to_next(request)
client = TestClient(app)
res = client.post('/login?next=http://example.com')
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/')

View File

@@ -1,31 +1,60 @@
import unittest import unittest
from microdot import MultiDict from microdot import MultiDict, NoCaseDict
class TestMultiDict(unittest.TestCase): class TestMultiDict(unittest.TestCase):
def test_multidict(self): def test_multidict(self):
d = MultiDict() d = MultiDict()
assert dict(d) == {} self.assertEqual(dict(d), {})
assert d.get('zero') is None self.assertIsNone(d.get('zero'))
assert d.get('zero', default=0) == 0 self.assertEqual(d.get('zero', default=0), 0)
assert d.getlist('zero') == [] self.assertEqual(d.getlist('zero'), [])
assert d.getlist('zero', type=int) == [] self.assertEqual(d.getlist('zero', type=int), [])
d['one'] = 1 d['one'] = 1
assert d['one'] == 1 self.assertEqual(d['one'], 1)
assert d.get('one') == 1 self.assertEqual(d.get('one'), 1)
assert d.get('one', default=2) == 1 self.assertEqual(d.get('one', default=2), 1)
assert d.get('one', type=int) == 1 self.assertEqual(d.get('one', type=int), 1)
assert d.get('one', type=str) == '1' self.assertEqual(d.get('one', type=str), '1')
d['two'] = 1 d['two'] = 1
d['two'] = 2 d['two'] = 2
assert d['two'] == 1 self.assertEqual(d['two'], 1)
assert d.get('two') == 1 self.assertEqual(d.get('two'), 1)
assert d.get('two', default=2) == 1 self.assertEqual(d.get('two', default=2), 1)
assert d.get('two', type=int) == 1 self.assertEqual(d.get('two', type=int), 1)
assert d.get('two', type=str) == '1' self.assertEqual(d.get('two', type=str), '1')
assert d.getlist('two') == [1, 2] self.assertEqual(d.getlist('two'), [1, 2])
assert d.getlist('two', type=int) == [1, 2] self.assertEqual(d.getlist('two', type=int), [1, 2])
assert d.getlist('two', type=str) == ['1', '2'] self.assertEqual(d.getlist('two', type=str), ['1', '2'])
def test_case_insensitive_dict(self):
d = NoCaseDict()
d['One'] = 1
d['one'] = 2
d['ONE'] = 3
d['One'] = 4
d['two'] = 5
self.assertEqual(d['one'], 4)
self.assertEqual(d['One'], 4)
self.assertEqual(d['ONE'], 4)
self.assertEqual(d['onE'], 4)
self.assertEqual(d['two'], 5)
self.assertEqual(d['tWO'], 5)
self.assertEqual(d.get('one'), 4)
self.assertEqual(d.get('One'), 4)
self.assertEqual(d.get('ONE'), 4)
self.assertEqual(d.get('onE'), 4)
self.assertEqual(d.get('two'), 5)
self.assertEqual(d.get('tWO'), 5)
self.assertIn(('One', 4), list(d.items()))
self.assertIn(('two', 5), list(d.items()))
self.assertIn(4, list(d.values()))
self.assertIn(5, list(d.values()))
del d['oNE']
self.assertEqual(list(d.items()), [('two', 5)])
self.assertEqual(list(d.values()), [5])

View File

@@ -19,6 +19,9 @@ class TestSession(unittest.TestCase):
@self.app.get('/') @self.app.get('/')
def index(req): def index(req):
session = get_session(req) session = get_session(req)
session2 = get_session(req)
session2['foo'] = 'bar'
self.assertEqual(session['foo'], 'bar')
return str(session.get('name')) return str(session.get('name'))
@self.app.get('/with') @self.app.get('/with')