CSRF: accept cross-site request if origin is in the CORS allowed origin list

This commit is contained in:
Miguel Grinberg
2025-12-21 10:48:29 +00:00
parent a99b658c3f
commit ba6893ca0f
2 changed files with 27 additions and 17 deletions

View File

@@ -56,6 +56,7 @@ class CSRF:
) or request.route in self.protected_routes: ) or request.route in self.protected_routes:
allow = False allow = False
sfs = request.headers.get('Sec-Fetch-Site') sfs = request.headers.get('Sec-Fetch-Site')
origin = request.headers.get('Origin')
if sfs: if sfs:
# if the Sec-Fetch-Site header was given, ensure it is not # if the Sec-Fetch-Site header was given, ensure it is not
# cross-site # cross-site
@@ -63,14 +64,11 @@ class CSRF:
allow = True allow = True
elif sfs == 'same-site' and self.allow_subdomains: elif sfs == 'same-site' and self.allow_subdomains:
allow = True allow = True
elif self.cors and self.cors.allowed_origins != '*': if not allow and origin and self.cors and \
# if there is no Sec-Fetch-Site header but we have a list self.cors.allowed_origins != '*':
# of allowed origins, then we can validate the origin # if we have a list of allowed origins, then we can
origin = request.headers.get('Origin') # validate the origin
if origin is None: if not self.allow_subdomains:
# origin wasn't given so this isn't a browser
allow = True
elif not self.allow_subdomains:
allow = origin in self.cors.allowed_origins allow = origin in self.cors.allowed_origins
else: else:
origin_scheme, origin_host = origin.split('://', 1) origin_scheme, origin_host = origin.split('://', 1)
@@ -83,7 +81,7 @@ class CSRF:
): ):
allow = True allow = True
break break
else: if not allow and not sfs and not origin:
allow = True # no headers to check allow = True # no headers to check
if not allow: if not allow:

View File

@@ -203,10 +203,6 @@ class TestCSRF(unittest.TestCase):
res = self._run(client.get('/')) res = self._run(client.get('/'))
self.assertEqual(res.status_code, 204) self.assertEqual(res.status_code, 204)
res = self._run(client.get(
'/', headers={'Origin': 'foo.com'}
))
self.assertEqual(res.status_code, 204)
res = self._run(client.get( res = self._run(client.get(
'/', headers={'Origin': 'http://foo.com'} '/', headers={'Origin': 'http://foo.com'}
)) ))
@@ -230,6 +226,26 @@ class TestCSRF(unittest.TestCase):
'/submit', headers={'Origin': 'http://bar.com:8888'} '/submit', headers={'Origin': 'http://bar.com:8888'}
)) ))
self.assertEqual(res.status_code, 403) self.assertEqual(res.status_code, 403)
res = self._run(client.post(
'/submit', headers={
'Sec-Fetch-Site': 'cross-site',
'Origin': 'https://bar.com:8888',
},
))
self.assertEqual(res.status_code, 204)
res = self._run(client.post(
'/submit', headers={
'Sec-Fetch-Site': 'cross-site',
'Origin': 'https://bar.com:8889',
},
))
self.assertEqual(res.status_code, 403)
res = self._run(client.post(
'/submit', headers={
'Sec-Fetch-Site': 'cross-site',
},
))
self.assertEqual(res.status_code, 403)
res = self._run(client.post( res = self._run(client.post(
'/submit', headers={'Origin': 'https://x.y.bar.com:8888'} '/submit', headers={'Origin': 'https://x.y.bar.com:8888'}
)) ))
@@ -257,10 +273,6 @@ class TestCSRF(unittest.TestCase):
res = self._run(client.get('/')) res = self._run(client.get('/'))
self.assertEqual(res.status_code, 204) self.assertEqual(res.status_code, 204)
res = self._run(client.get(
'/', headers={'Origin': 'foo.com'}
))
self.assertEqual(res.status_code, 204)
res = self._run(client.get( res = self._run(client.get(
'/', headers={'Origin': 'http://foo.com'} '/', headers={'Origin': 'http://foo.com'}
)) ))