CSRF protection (#335)
This commit is contained in:
@@ -12,3 +12,4 @@ from tests.test_utemplate import * # noqa: F401, F403
|
||||
from tests.test_session import * # noqa: F401, F403
|
||||
from tests.test_auth import * # noqa: F401, F403
|
||||
from tests.test_login import * # noqa: F401, F403
|
||||
from tests.test_csrf import * # noqa: F401, F403
|
||||
|
||||
298
tests/test_csrf.py
Normal file
298
tests/test_csrf.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from microdot import Microdot
|
||||
from microdot.cors import CORS
|
||||
from microdot.csrf import CSRF
|
||||
from microdot.test_client import TestClient
|
||||
|
||||
|
||||
class TestCSRF(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if hasattr(asyncio, 'set_event_loop'):
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
cls.loop = asyncio.get_event_loop()
|
||||
|
||||
def _run(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
def test_protect_all_true(self):
|
||||
app = Microdot()
|
||||
csrf = CSRF(app)
|
||||
|
||||
@app.get('/')
|
||||
def index(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit')
|
||||
def submit(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit-exempt')
|
||||
@csrf.exempt
|
||||
def submit_exempt(request):
|
||||
return 204
|
||||
|
||||
@app.get('/get-protected')
|
||||
@csrf.protect
|
||||
def get_protected(request):
|
||||
return 204
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'https://evil.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-origin'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit-exempt'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit-exempt', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get('/get-protected'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/get-protected', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
|
||||
def test_protect_all_false(self):
|
||||
app = Microdot()
|
||||
csrf = CSRF(protect_all=False)
|
||||
csrf.initialize(app)
|
||||
|
||||
@app.get('/')
|
||||
def index(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit')
|
||||
@csrf.protect
|
||||
def submit(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit-exempt')
|
||||
def submit_exempt(request):
|
||||
return 204
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'https://evil.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-origin'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit-exempt'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit-exempt', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
def test_allow_subdomains(self):
|
||||
app = Microdot()
|
||||
csrf = CSRF(allow_subdomains=True)
|
||||
csrf.initialize(app)
|
||||
|
||||
@app.get('/')
|
||||
def index(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit')
|
||||
def submit(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit-exempt')
|
||||
@csrf.exempt
|
||||
def submit_exempt(request):
|
||||
return 204
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'https://evil.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Sec-Fetch-Site': 'same-origin'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit-exempt'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit-exempt', headers={'Sec-Fetch-Site': 'cross-site'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
def test_allowed_origins(self):
|
||||
app = Microdot()
|
||||
cors = CORS(allowed_origins=['http://foo.com', 'https://bar.com:8888'])
|
||||
csrf = CSRF()
|
||||
csrf.initialize(app, cors)
|
||||
|
||||
@app.get('/')
|
||||
def index(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit')
|
||||
def submit(request):
|
||||
return 204
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'foo.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'http://foo.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'https://baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'http://x.baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'https://bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'http://bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'https://x.y.bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'http://baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
|
||||
def test_allowed_origins_with_subdomains(self):
|
||||
app = Microdot()
|
||||
cors = CORS(allowed_origins=['http://foo.com', 'https://bar.com:8888'])
|
||||
csrf = CSRF(allow_subdomains=True)
|
||||
csrf.initialize(app, cors)
|
||||
|
||||
@app.get('/')
|
||||
def index(request):
|
||||
return 204
|
||||
|
||||
@app.post('/submit')
|
||||
def submit(request):
|
||||
return 204
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
res = self._run(client.get('/'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'foo.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'http://foo.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'https://baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.get(
|
||||
'/', headers={'Origin': 'http://x.baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
|
||||
res = self._run(client.post('/submit'))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'https://bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'http://bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'https://x.y.bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 204)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'http://x.y.bar.com:8888'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
res = self._run(client.post(
|
||||
'/submit', headers={'Origin': 'http://baz.com'}
|
||||
))
|
||||
self.assertEqual(res.status_code, 403)
|
||||
Reference in New Issue
Block a user