Configurable session cookie options (Fixes #242)

This commit is contained in:
Miguel Grinberg
2024-06-18 00:09:44 +01:00
parent 4204db61e5
commit 0151611fc8
3 changed files with 83 additions and 13 deletions

View File

@@ -29,14 +29,21 @@ class Session:
"""
secret_key = None
def __init__(self, app=None, secret_key=None):
def __init__(self, app=None, secret_key=None, cookie_options=None):
self.secret_key = secret_key
self.cookie_options = cookie_options or {}
if app is not None:
self.initialize(app)
def initialize(self, app, secret_key=None):
def initialize(self, app, secret_key=None, cookie_options=None):
if secret_key is not None:
self.secret_key = secret_key
if cookie_options is not None:
self.cookie_options = cookie_options
if 'path' not in self.cookie_options:
self.cookie_options['path'] = '/'
if 'http_only' not in self.cookie_options:
self.cookie_options['http_only'] = True
app._session = self
def get(self, request):
@@ -86,7 +93,8 @@ class Session:
@request.after_request
def _update_session(request, response):
response.set_cookie('session', encoded_session, http_only=True)
response.set_cookie('session', encoded_session,
**self.cookie_options)
return response
def delete(self, request):
@@ -109,8 +117,7 @@ class Session:
"""
@request.after_request
def _delete_session(request, response):
response.set_cookie('session', '', http_only=True,
expires='Thu, 01 Jan 1970 00:00:01 GMT')
response.delete_cookie('session')
return response
def encode(self, payload, secret_key=None):

View File

@@ -112,9 +112,13 @@ class TestClient:
headers['Host'] = 'example.com:1234'
return body, headers
def _process_cookies(self, headers):
def _process_cookies(self, path, headers):
cookies = ''
for name, value in self.cookies.items():
if isinstance(value, tuple):
value, cookie_path = value
if not path.startswith(cookie_path):
continue
if cookies:
cookies += '; '
cookies += name + '=' + value
@@ -123,7 +127,7 @@ class TestClient:
headers['Cookie'] += '; ' + cookies
else:
headers['Cookie'] = cookies
return cookies, headers
return headers
def _render_request(self, method, path, headers, body):
request_bytes = '{method} {path} HTTP/1.0\n'.format(
@@ -139,11 +143,13 @@ class TestClient:
for cookie in cookies:
cookie_name, cookie_value = cookie.split('=', 1)
cookie_options = cookie_value.split(';')
path = '/'
delete = False
for option in cookie_options[1:]:
if option.strip().lower().startswith(
option = option.strip().lower()
if option.startswith(
'max-age='): # pragma: no cover
_, age = option.strip().split('=', 1)
_, age = option.split('=', 1)
try:
age = int(age)
except ValueError: # pragma: no cover
@@ -151,24 +157,29 @@ class TestClient:
if age <= 0:
delete = True
break
elif option.strip().lower().startswith('expires='):
_, e = option.strip().split('=', 1)
elif option.startswith('expires='):
_, e = option.split('=', 1)
# this is a very limited parser for cookie expiry
# that only detects a cookie deletion request when
# the date is 1/1/1970
if '1 jan 1970' in e.lower(): # pragma: no branch
delete = True
break
elif option.startswith('path='):
_, path = option.split('=', 1)
if delete:
if cookie_name in self.cookies: # pragma: no branch
del self.cookies[cookie_name]
else:
self.cookies[cookie_name] = cookie_options[0]
if path == '/':
self.cookies[cookie_name] = cookie_options[0]
else:
self.cookies[cookie_name] = (cookie_options[0], path)
async def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {}
body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers)
headers = self._process_cookies(path, headers)
request_bytes = self._render_request(method, path, headers, body)
if sock:
reader = sock[0]

View File

@@ -82,3 +82,55 @@ class TestSession(unittest.TestCase):
res = self._run(client.get('/'))
self.assertEqual(res.status_code, 200)
def test_session_default_path(self):
app = Microdot()
session_ext.initialize(app, secret_key='some-other-secret')
client = TestClient(app)
@app.get('/')
@with_session
def index(req, session):
session['foo'] = 'bar'
session.save()
return ''
@app.get('/child')
@with_session
def child(req, session):
return str(session.get('foo'))
res = self._run(client.get('/'))
self.assertEqual(res.status_code, 200)
res = self._run(client.get('/child'))
self.assertEqual(res.text, 'bar')
def test_session_custom_path(self):
app = Microdot()
session_ext.initialize(app, secret_key='some-other-secret',
cookie_options={'path': '/child'})
client = TestClient(app)
@app.get('/')
@with_session
def index(req, session):
return str(session.get('foo'))
@app.get('/child')
@with_session
def child(req, session):
session['foo'] = 'bar'
session.save()
return ''
@app.get('/child/foo')
@with_session
def foo(req, session):
return str(session.get('foo'))
res = self._run(client.get('/child'))
self.assertEqual(res.status_code, 200)
res = self._run(client.get('/'))
self.assertEqual(res.text, 'None')
res = self._run(client.get('/child/foo'))
self.assertEqual(res.text, 'bar')