Configurable session cookie options (Fixes #242)
This commit is contained in:
@@ -29,14 +29,21 @@ class Session:
|
|||||||
"""
|
"""
|
||||||
secret_key = None
|
secret_key = None
|
||||||
|
|
||||||
def __init__(self, app=None, secret_key=None):
|
def __init__(self, app=None, secret_key=None, cookie_options=None):
|
||||||
self.secret_key = secret_key
|
self.secret_key = secret_key
|
||||||
|
self.cookie_options = cookie_options or {}
|
||||||
if app is not None:
|
if app is not None:
|
||||||
self.initialize(app)
|
self.initialize(app)
|
||||||
|
|
||||||
def initialize(self, app, secret_key=None):
|
def initialize(self, app, secret_key=None, cookie_options=None):
|
||||||
if secret_key is not None:
|
if secret_key is not None:
|
||||||
self.secret_key = secret_key
|
self.secret_key = secret_key
|
||||||
|
if cookie_options is not None:
|
||||||
|
self.cookie_options = cookie_options
|
||||||
|
if 'path' not in self.cookie_options:
|
||||||
|
self.cookie_options['path'] = '/'
|
||||||
|
if 'http_only' not in self.cookie_options:
|
||||||
|
self.cookie_options['http_only'] = True
|
||||||
app._session = self
|
app._session = self
|
||||||
|
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -86,7 +93,8 @@ class Session:
|
|||||||
|
|
||||||
@request.after_request
|
@request.after_request
|
||||||
def _update_session(request, response):
|
def _update_session(request, response):
|
||||||
response.set_cookie('session', encoded_session, http_only=True)
|
response.set_cookie('session', encoded_session,
|
||||||
|
**self.cookie_options)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def delete(self, request):
|
def delete(self, request):
|
||||||
@@ -109,8 +117,7 @@ class Session:
|
|||||||
"""
|
"""
|
||||||
@request.after_request
|
@request.after_request
|
||||||
def _delete_session(request, response):
|
def _delete_session(request, response):
|
||||||
response.set_cookie('session', '', http_only=True,
|
response.delete_cookie('session')
|
||||||
expires='Thu, 01 Jan 1970 00:00:01 GMT')
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def encode(self, payload, secret_key=None):
|
def encode(self, payload, secret_key=None):
|
||||||
|
|||||||
@@ -112,9 +112,13 @@ class TestClient:
|
|||||||
headers['Host'] = 'example.com:1234'
|
headers['Host'] = 'example.com:1234'
|
||||||
return body, headers
|
return body, headers
|
||||||
|
|
||||||
def _process_cookies(self, headers):
|
def _process_cookies(self, path, headers):
|
||||||
cookies = ''
|
cookies = ''
|
||||||
for name, value in self.cookies.items():
|
for name, value in self.cookies.items():
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
value, cookie_path = value
|
||||||
|
if not path.startswith(cookie_path):
|
||||||
|
continue
|
||||||
if cookies:
|
if cookies:
|
||||||
cookies += '; '
|
cookies += '; '
|
||||||
cookies += name + '=' + value
|
cookies += name + '=' + value
|
||||||
@@ -123,7 +127,7 @@ class TestClient:
|
|||||||
headers['Cookie'] += '; ' + cookies
|
headers['Cookie'] += '; ' + cookies
|
||||||
else:
|
else:
|
||||||
headers['Cookie'] = cookies
|
headers['Cookie'] = cookies
|
||||||
return cookies, headers
|
return headers
|
||||||
|
|
||||||
def _render_request(self, method, path, headers, body):
|
def _render_request(self, method, path, headers, body):
|
||||||
request_bytes = '{method} {path} HTTP/1.0\n'.format(
|
request_bytes = '{method} {path} HTTP/1.0\n'.format(
|
||||||
@@ -139,11 +143,13 @@ class TestClient:
|
|||||||
for cookie in cookies:
|
for cookie in cookies:
|
||||||
cookie_name, cookie_value = cookie.split('=', 1)
|
cookie_name, cookie_value = cookie.split('=', 1)
|
||||||
cookie_options = cookie_value.split(';')
|
cookie_options = cookie_value.split(';')
|
||||||
|
path = '/'
|
||||||
delete = False
|
delete = False
|
||||||
for option in cookie_options[1:]:
|
for option in cookie_options[1:]:
|
||||||
if option.strip().lower().startswith(
|
option = option.strip().lower()
|
||||||
|
if option.startswith(
|
||||||
'max-age='): # pragma: no cover
|
'max-age='): # pragma: no cover
|
||||||
_, age = option.strip().split('=', 1)
|
_, age = option.split('=', 1)
|
||||||
try:
|
try:
|
||||||
age = int(age)
|
age = int(age)
|
||||||
except ValueError: # pragma: no cover
|
except ValueError: # pragma: no cover
|
||||||
@@ -151,24 +157,29 @@ class TestClient:
|
|||||||
if age <= 0:
|
if age <= 0:
|
||||||
delete = True
|
delete = True
|
||||||
break
|
break
|
||||||
elif option.strip().lower().startswith('expires='):
|
elif option.startswith('expires='):
|
||||||
_, e = option.strip().split('=', 1)
|
_, e = option.split('=', 1)
|
||||||
# this is a very limited parser for cookie expiry
|
# this is a very limited parser for cookie expiry
|
||||||
# that only detects a cookie deletion request when
|
# that only detects a cookie deletion request when
|
||||||
# the date is 1/1/1970
|
# the date is 1/1/1970
|
||||||
if '1 jan 1970' in e.lower(): # pragma: no branch
|
if '1 jan 1970' in e.lower(): # pragma: no branch
|
||||||
delete = True
|
delete = True
|
||||||
break
|
break
|
||||||
|
elif option.startswith('path='):
|
||||||
|
_, path = option.split('=', 1)
|
||||||
if delete:
|
if delete:
|
||||||
if cookie_name in self.cookies: # pragma: no branch
|
if cookie_name in self.cookies: # pragma: no branch
|
||||||
del self.cookies[cookie_name]
|
del self.cookies[cookie_name]
|
||||||
else:
|
else:
|
||||||
self.cookies[cookie_name] = cookie_options[0]
|
if path == '/':
|
||||||
|
self.cookies[cookie_name] = cookie_options[0]
|
||||||
|
else:
|
||||||
|
self.cookies[cookie_name] = (cookie_options[0], path)
|
||||||
|
|
||||||
async def request(self, method, path, headers=None, body=None, sock=None):
|
async def request(self, method, path, headers=None, body=None, sock=None):
|
||||||
headers = headers or {}
|
headers = headers or {}
|
||||||
body, headers = self._process_body(body, headers)
|
body, headers = self._process_body(body, headers)
|
||||||
cookies, headers = self._process_cookies(headers)
|
headers = self._process_cookies(path, headers)
|
||||||
request_bytes = self._render_request(method, path, headers, body)
|
request_bytes = self._render_request(method, path, headers, body)
|
||||||
if sock:
|
if sock:
|
||||||
reader = sock[0]
|
reader = sock[0]
|
||||||
|
|||||||
@@ -82,3 +82,55 @@ class TestSession(unittest.TestCase):
|
|||||||
|
|
||||||
res = self._run(client.get('/'))
|
res = self._run(client.get('/'))
|
||||||
self.assertEqual(res.status_code, 200)
|
self.assertEqual(res.status_code, 200)
|
||||||
|
|
||||||
|
def test_session_default_path(self):
|
||||||
|
app = Microdot()
|
||||||
|
session_ext.initialize(app, secret_key='some-other-secret')
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
@app.get('/')
|
||||||
|
@with_session
|
||||||
|
def index(req, session):
|
||||||
|
session['foo'] = 'bar'
|
||||||
|
session.save()
|
||||||
|
return ''
|
||||||
|
|
||||||
|
@app.get('/child')
|
||||||
|
@with_session
|
||||||
|
def child(req, session):
|
||||||
|
return str(session.get('foo'))
|
||||||
|
|
||||||
|
res = self._run(client.get('/'))
|
||||||
|
self.assertEqual(res.status_code, 200)
|
||||||
|
res = self._run(client.get('/child'))
|
||||||
|
self.assertEqual(res.text, 'bar')
|
||||||
|
|
||||||
|
def test_session_custom_path(self):
|
||||||
|
app = Microdot()
|
||||||
|
session_ext.initialize(app, secret_key='some-other-secret',
|
||||||
|
cookie_options={'path': '/child'})
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
@app.get('/')
|
||||||
|
@with_session
|
||||||
|
def index(req, session):
|
||||||
|
return str(session.get('foo'))
|
||||||
|
|
||||||
|
@app.get('/child')
|
||||||
|
@with_session
|
||||||
|
def child(req, session):
|
||||||
|
session['foo'] = 'bar'
|
||||||
|
session.save()
|
||||||
|
return ''
|
||||||
|
|
||||||
|
@app.get('/child/foo')
|
||||||
|
@with_session
|
||||||
|
def foo(req, session):
|
||||||
|
return str(session.get('foo'))
|
||||||
|
|
||||||
|
res = self._run(client.get('/child'))
|
||||||
|
self.assertEqual(res.status_code, 200)
|
||||||
|
res = self._run(client.get('/'))
|
||||||
|
self.assertEqual(res.text, 'None')
|
||||||
|
res = self._run(client.get('/child/foo'))
|
||||||
|
self.assertEqual(res.text, 'bar')
|
||||||
|
|||||||
Reference in New Issue
Block a user