Cross-Origin Resource Sharing (CORS) support (Fixes #45)

This commit is contained in:
Miguel Grinberg
2023-03-23 00:02:20 +00:00
parent ea6766cea9
commit 67798f7dbf
6 changed files with 325 additions and 0 deletions

View File

@@ -52,6 +52,12 @@ API Reference
.. automodule:: microdot_session
:members:
``microdot_cors`` module
------------------------
.. automodule:: microdot_cors
:members:
``microdot_websocket`` module
------------------------------

View File

@@ -208,6 +208,42 @@ Example::
delete_session(req)
return redirect('/')
Cross-Origin Resource Sharing (CORS)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:align: left
* - Compatibility
- | CPython & MicroPython
* - Required Microdot source files
- | `microdot.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot.py>`_
| `microdot_cors.py <https://github.com/miguelgrinberg/microdot/tree/main/src/microdot_cors.py>`_
* - Required external dependencies
- | None
* - Examples
- | `cors.py <https://github.com/miguelgrinberg/microdot/blob/main/examples/cors/cors.py>`_
The CORS extension provides support for `Cross-Origin Resource Sharing
(CORS) <https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS>`_. CORS is a
mechanism that allows web applications running on different origins to access
resources from each other. For example, a web application running on
``https://example.com`` can access resources from ``https://api.example.com``.
To enable CORS support, create an instance of the
:class:`CORS <microdot_cors.CORS>` class and configure the desired options.
Example::
from microdot import Microdot
from microdot_cors import CORS
app = Microdot()
cors = CORS(app, allowed_origins=['https://example.com'],
allow_credentials=True)
WebSocket Support
~~~~~~~~~~~~~~~~~

1
examples/cors/README.md Normal file
View File

@@ -0,0 +1 @@
This directory contains Cross-Origin Resource Sharing (CORS) examples.

14
examples/cors/app.py Normal file
View File

@@ -0,0 +1,14 @@
from microdot import Microdot
from microdot_cors import CORS
app = Microdot()
CORS(app, allowed_origins=['https://example.org'], allow_credentials=True)
@app.route('/')
def index(request):
return 'Hello World!'
if __name__ == '__main__':
app.run()

110
src/microdot_cors.py Normal file
View File

@@ -0,0 +1,110 @@
class CORS:
"""Add CORS headers to HTTP responses.
:param app: The application to add CORS headers to.
:param allowed_origins: A list of origins that are allowed to make
cross-site requests. If set to '*', all origins are
allowed.
:param allow_credentials: If set to True, the
``Access-Control-Allow-Credentials`` header will
be set to ``true`` to indicate to the browser
that it can expose cookies and authentication
headers.
:param allowed_methods: A list of methods that are allowed to be used when
making cross-site requests. If not set, all methods
are allowed.
:param expose_headers: A list of headers that the browser is allowed to
exposed.
:param allowed_headers: A list of headers that are allowed to be used when
making cross-site requests. If not set, all headers
are allowed.
:param max_age: The maximum amount of time in seconds that the browser
should cache the results of a preflight request.
:param handle_cors: If set to False, CORS headers will not be added to
responses. This can be useful if you want to add CORS
headers manually.
"""
def __init__(self, app=None, allowed_origins=None, allow_credentials=False,
allowed_methods=None, expose_headers=None,
allowed_headers=None, max_age=None, handle_cors=True):
self.allowed_origins = allowed_origins
self.allow_credentials = allow_credentials
self.allowed_methods = allowed_methods
self.expose_headers = expose_headers
self.allowed_headers = None if allowed_headers is None \
else [h.lower() for h in allowed_headers]
self.max_age = max_age
if app is not None:
self.initialize(app, handle_cors=handle_cors)
def initialize(self, app, handle_cors=True):
"""Initialize the CORS object for the given application.
:param app: The application to add CORS headers to.
:param handle_cors: If set to False, CORS headers will not be added to
responses. This can be useful if you want to add
CORS headers manually.
"""
self.default_options_handler = app.options_handler
if handle_cors:
app.options_handler = self.options_handler
app.after_request(self.after_request)
app.after_error_request(self.after_request)
def options_handler(self, request):
headers = self.default_options_handler(request)
headers.update(self.get_cors_headers(request))
return headers
def get_cors_headers(self, request):
"""Return a dictionary of CORS headers to add to a given request.
:param request: The request to add CORS headers to.
"""
cors_headers = {}
origin = request.headers.get('Origin')
if self.allowed_origins == '*':
cors_headers['Access-Control-Allow-Origin'] = origin or '*'
if origin:
cors_headers['Vary'] = 'Origin'
elif origin in (self.allowed_origins or []):
cors_headers['Access-Control-Allow-Origin'] = origin
cors_headers['Vary'] = 'Origin'
if self.allow_credentials and \
'Access-Control-Allow-Origin' in cors_headers:
cors_headers['Access-Control-Allow-Credentials'] = 'true'
if self.expose_headers:
cors_headers['Access-Control-Expose-Headers'] = \
', '.join(self.expose_headers)
if request.method == 'OPTIONS':
# handle preflight request
if self.max_age:
cors_headers['Access-Control-Max-Age'] = str(self.max_age)
method = request.headers.get('Access-Control-Request-Method')
if method:
method = method.upper()
if self.allowed_methods is None or \
method in self.allowed_methods:
cors_headers['Access-Control-Allow-Methods'] = method
headers = request.headers.get('Access-Control-Request-Headers')
if headers:
if self.allowed_headers is None:
cors_headers['Access-Control-Allow-Headers'] = headers
else:
headers = [h.strip() for h in headers.split(',')]
headers = [h for h in headers
if h.lower() in self.allowed_headers]
cors_headers['Access-Control-Allow-Headers'] = \
', '.join(headers)
return cors_headers
def after_request(self, request, response):
saved_vary = response.headers.get('Vary')
response.headers.update(self.get_cors_headers(request))
if saved_vary and saved_vary != response.headers.get('Vary'):
response.headers['Vary'] = (
saved_vary + ', ' + response.headers['Vary'])

158
tests/test_cors.py Normal file
View File

@@ -0,0 +1,158 @@
import unittest
from microdot import Microdot
from microdot_test_client import TestClient
from microdot_cors import CORS
class TestCORS(unittest.TestCase):
def test_origin(self):
app = Microdot()
cors = CORS(allowed_origins=['https://example.com'],
allow_credentials=True)
cors.initialize(app)
@app.get('/')
def index(req):
return 'foo'
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
self.assertFalse('Vary' in res.headers)
res = client.get('/', headers={'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertEqual(res.headers['Access-Control-Allow-Credentials'],
'true')
self.assertEqual(res.headers['Vary'], 'Origin')
cors.allow_credentials = False
res = client.get('/foo', headers={'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 404)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
self.assertEqual(res.headers['Vary'], 'Origin')
res = client.get('/', headers={'Origin': 'https://bad.com'})
self.assertEqual(res.status_code, 200)
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
self.assertFalse('Access-Control-Allow-Credentials' in res.headers)
self.assertFalse('Vary' in res.headers)
def test_all_origins(self):
app = Microdot()
CORS(app, allowed_origins='*', expose_headers=['X-Test', 'X-Test2'])
@app.get('/')
def index(req):
return 'foo'
@app.get('/foo')
def foo(req):
return 'foo', {'Vary': 'X-Foo, X-Bar'}
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'], '*')
self.assertFalse('Vary' in res.headers)
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
'X-Test, X-Test2')
res = client.get('/', headers={'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertEqual(res.headers['Vary'], 'Origin')
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
'X-Test, X-Test2')
res = client.get('/bad', headers={'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 404)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertEqual(res.headers['Vary'], 'Origin')
self.assertEqual(res.headers['Access-Control-Expose-Headers'],
'X-Test, X-Test2')
res = client.get('/foo', headers={'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Vary'], 'X-Foo, X-Bar, Origin')
def test_cors_preflight(self):
app = Microdot()
CORS(app, allowed_origins='*')
@app.route('/', methods=['GET', 'POST'])
def index(req):
return 'foo'
client = TestClient(app)
res = client.request('OPTIONS', '/', headers={
'Origin': 'https://example.com',
'Access-Control-Request-Method': 'POST',
'Access-Control-Request-Headers': 'X-Test, X-Test2'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertFalse('Access-Control-Max-Age' in res.headers)
self.assertEqual(res.headers['Access-Control-Allow-Methods'], 'POST')
self.assertEqual(res.headers['Access-Control-Allow-Headers'],
'X-Test, X-Test2')
res = client.request('OPTIONS', '/', headers={
'Origin': 'https://example.com'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertFalse('Access-Control-Max-Age' in res.headers)
self.assertFalse('Access-Control-Allow-Methods' in res.headers)
self.assertFalse('Access-Control-Allow-Headers' in res.headers)
def test_cors_preflight_with_options(self):
app = Microdot()
CORS(app, allowed_origins='*', max_age=3600, allowed_methods=['POST'],
allowed_headers=['X-Test'])
@app.route('/', methods=['GET', 'POST'])
def index(req):
return 'foo'
client = TestClient(app)
res = client.request('OPTIONS', '/', headers={
'Origin': 'https://example.com',
'Access-Control-Request-Method': 'POST',
'Access-Control-Request-Headers': 'X-Test, X-Test2'})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.headers['Access-Control-Allow-Origin'],
'https://example.com')
self.assertEqual(res.headers['Access-Control-Max-Age'], '3600')
self.assertEqual(res.headers['Access-Control-Allow-Methods'], 'POST')
self.assertEqual(res.headers['Access-Control-Allow-Headers'], 'X-Test')
res = client.request('OPTIONS', '/', headers={
'Origin': 'https://example.com',
'Access-Control-Request-Method': 'GET'})
self.assertEqual(res.status_code, 200)
self.assertFalse('Access-Control-Allow-Methods' in res.headers)
self.assertFalse('Access-Control-Allow-Headers' in res.headers)
def test_cors_disabled(self):
app = Microdot()
CORS(app, allowed_origins='*', handle_cors=False)
@app.get('/')
def index(req):
return 'foo'
client = TestClient(app)
res = client.get('/')
self.assertEqual(res.status_code, 200)
self.assertFalse('Access-Control-Allow-Origin' in res.headers)
self.assertFalse('Vary' in res.headers)