diff --git a/src/microdot/csrf.py b/src/microdot/csrf.py index dcd2ece..dcee5da 100644 --- a/src/microdot/csrf.py +++ b/src/microdot/csrf.py @@ -56,6 +56,7 @@ class CSRF: ) or request.route in self.protected_routes: allow = False sfs = request.headers.get('Sec-Fetch-Site') + origin = request.headers.get('Origin') if sfs: # if the Sec-Fetch-Site header was given, ensure it is not # cross-site @@ -63,14 +64,11 @@ class CSRF: allow = True elif sfs == 'same-site' and self.allow_subdomains: allow = True - elif self.cors and self.cors.allowed_origins != '*': - # if there is no Sec-Fetch-Site header but we have a list - # of allowed origins, then we can validate the origin - origin = request.headers.get('Origin') - if origin is None: - # origin wasn't given so this isn't a browser - allow = True - elif not self.allow_subdomains: + if not allow and origin and self.cors and \ + self.cors.allowed_origins != '*': + # if we have a list of allowed origins, then we can + # validate the origin + if not self.allow_subdomains: allow = origin in self.cors.allowed_origins else: origin_scheme, origin_host = origin.split('://', 1) @@ -83,7 +81,7 @@ class CSRF: ): allow = True break - else: + if not allow and not sfs and not origin: allow = True # no headers to check if not allow: diff --git a/tests/test_csrf.py b/tests/test_csrf.py index d3f6238..1faa403 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -203,10 +203,6 @@ class TestCSRF(unittest.TestCase): res = self._run(client.get('/')) self.assertEqual(res.status_code, 204) - res = self._run(client.get( - '/', headers={'Origin': 'foo.com'} - )) - self.assertEqual(res.status_code, 204) res = self._run(client.get( '/', headers={'Origin': 'http://foo.com'} )) @@ -230,6 +226,26 @@ class TestCSRF(unittest.TestCase): '/submit', headers={'Origin': 'http://bar.com:8888'} )) self.assertEqual(res.status_code, 403) + res = self._run(client.post( + '/submit', headers={ + 'Sec-Fetch-Site': 'cross-site', + 'Origin': 'https://bar.com:8888', + }, + )) + self.assertEqual(res.status_code, 204) + res = self._run(client.post( + '/submit', headers={ + 'Sec-Fetch-Site': 'cross-site', + 'Origin': 'https://bar.com:8889', + }, + )) + self.assertEqual(res.status_code, 403) + res = self._run(client.post( + '/submit', headers={ + 'Sec-Fetch-Site': 'cross-site', + }, + )) + self.assertEqual(res.status_code, 403) res = self._run(client.post( '/submit', headers={'Origin': 'https://x.y.bar.com:8888'} )) @@ -257,10 +273,6 @@ class TestCSRF(unittest.TestCase): res = self._run(client.get('/')) self.assertEqual(res.status_code, 204) - res = self._run(client.get( - '/', headers={'Origin': 'foo.com'} - )) - self.assertEqual(res.status_code, 204) res = self._run(client.get( '/', headers={'Origin': 'http://foo.com'} ))