172 lines
6.8 KiB
Python
172 lines
6.8 KiB
Python
import asyncio
|
|
import unittest
|
|
from microdot import Microdot
|
|
from microdot.test_client import TestClient
|
|
from microdot.cors import CORS
|
|
|
|
|
|
class TestCORS(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.loop = asyncio.new_event_loop()
|
|
|
|
def _run(self, coro):
|
|
return self.loop.run_until_complete(coro)
|
|
|
|
def test_origin(self):
|
|
app = Microdot()
|
|
cors = CORS(allowed_origins=['https://example.com'],
|
|
allow_credentials=True)
|
|
cors.initialize(app)
|
|
|
|
@app.get('/')
|
|
def index(req):
|
|
return 'foo'
|
|
|
|
client = TestClient(app)
|
|
res = self._run(client.get('/'))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
|
|
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
|
|
self.assertFalse('Vary' in res.headers)
|
|
|
|
res = self._run(client.get(
|
|
'/', headers={'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertEqual(res.headers['Access-Control-Allow-Credentials'],
|
|
'true')
|
|
self.assertEqual(res.headers['Vary'], 'Origin')
|
|
|
|
cors.allow_credentials = False
|
|
|
|
res = self._run(client.get(
|
|
'/foo', headers={'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 404)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
|
|
self.assertEqual(res.headers['Vary'], 'Origin')
|
|
|
|
res = self._run(client.get('/', headers={'Origin': 'https://bad.com'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
|
|
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
|
|
self.assertFalse('Vary' in res.headers)
|
|
|
|
def test_all_origins(self):
|
|
app = Microdot()
|
|
CORS(app, allowed_origins='*', expose_headers=['X-Test', 'X-Test2'])
|
|
|
|
@app.get('/')
|
|
def index(req):
|
|
return 'foo'
|
|
|
|
@app.get('/foo')
|
|
def foo(req):
|
|
return 'foo', {'Vary': 'X-Foo, X-Bar'}
|
|
|
|
client = TestClient(app)
|
|
res = self._run(client.get('/'))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'], '*')
|
|
self.assertFalse('Vary' in res.headers)
|
|
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
|
|
'X-Test, X-Test2')
|
|
|
|
res = self._run(client.get(
|
|
'/', headers={'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertEqual(res.headers['Vary'], 'Origin')
|
|
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
|
|
'X-Test, X-Test2')
|
|
|
|
res = self._run(client.get(
|
|
'/bad', headers={'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 404)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertEqual(res.headers['Vary'], 'Origin')
|
|
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
|
|
'X-Test, X-Test2')
|
|
|
|
res = self._run(client.get(
|
|
'/foo', headers={'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Vary'], 'X-Foo, X-Bar, Origin')
|
|
|
|
def test_cors_preflight(self):
|
|
app = Microdot()
|
|
CORS(app, allowed_origins='*')
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
|
|
def index(req):
|
|
return 'foo'
|
|
|
|
client = TestClient(app)
|
|
res = self._run(client.request('OPTIONS', '/', headers={
|
|
'Origin': 'https://example.com',
|
|
'Access-Control-Request-Method': 'POST',
|
|
'Access-Control-Request-Headers': 'X-Test, X-Test2'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertFalse('Access-Control-Max-Age' in res.headers)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Methods'], 'POST')
|
|
self.assertEqual(res.headers['Access-Control-Allow-Headers'],
|
|
'X-Test, X-Test2')
|
|
|
|
res = self._run(client.request('OPTIONS', '/', headers={
|
|
'Origin': 'https://example.com'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertFalse('Access-Control-Max-Age' in res.headers)
|
|
self.assertFalse('Access-Control-Allow-Methods' in res.headers)
|
|
self.assertFalse('Access-Control-Allow-Headers' in res.headers)
|
|
|
|
def test_cors_preflight_with_options(self):
|
|
app = Microdot()
|
|
CORS(app, allowed_origins='*', max_age=3600, allowed_methods=['POST'],
|
|
allowed_headers=['X-Test'])
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
|
|
def index(req):
|
|
return 'foo'
|
|
|
|
client = TestClient(app)
|
|
res = self._run(client.request('OPTIONS', '/', headers={
|
|
'Origin': 'https://example.com',
|
|
'Access-Control-Request-Method': 'POST',
|
|
'Access-Control-Request-Headers': 'X-Test, X-Test2'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
|
|
'https://example.com')
|
|
self.assertEqual(res.headers['Access-Control-Max-Age'], '3600')
|
|
self.assertEqual(res.headers['Access-Control-Allow-Methods'], 'POST')
|
|
self.assertEqual(res.headers['Access-Control-Allow-Headers'], 'X-Test')
|
|
|
|
res = self._run(client.request('OPTIONS', '/', headers={
|
|
'Origin': 'https://example.com',
|
|
'Access-Control-Request-Method': 'GET'}))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertFalse('Access-Control-Allow-Methods' in res.headers)
|
|
self.assertFalse('Access-Control-Allow-Headers' in res.headers)
|
|
|
|
def test_cors_disabled(self):
|
|
app = Microdot()
|
|
CORS(app, allowed_origins='*', handle_cors=False)
|
|
|
|
@app.get('/')
|
|
def index(req):
|
|
return 'foo'
|
|
|
|
client = TestClient(app)
|
|
res = self._run(client.get('/'))
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
|
|
self.assertFalse('Vary' in res.headers)
|