diff --git a/src/microdot/login.py b/src/microdot/login.py index aa894e9..53cdb5f 100644 --- a/src/microdot/login.py +++ b/src/microdot/login.py @@ -82,6 +82,7 @@ class Login: session['_user_id'] = user.id session['_fresh'] = True session.save() + request.g.current_user = user if remember: days = 30 if remember is True else int(remember) @@ -104,9 +105,21 @@ class Login: session.pop('_user_id', None) session.pop('_fresh', None) session.save() + request.g.current_user = None if '_remember' in request.cookies: self._update_remember_cookie(request, 0) + async def get_current_user(self, request): + """Return the currently logged in user.""" + if not hasattr(request.g, 'current_user'): + user_id = self._get_user_id_from_session(request) + if user_id: + request.g.current_user = await invoke_handler( + self.user_loader_callback, user_id) + else: + request.g.current_user = None + return request.g.current_user + def __call__(self, f): """Decorator to protect a route with authentication. @@ -124,12 +137,8 @@ class Login: """ async def wrapper(request, *args, **kwargs): - user_id = self._get_user_id_from_session(request) - if not user_id: - return await self._redirect_to_login(request) - request.g.current_user = await invoke_handler( - self.user_loader_callback, user_id) - if not request.g.current_user: + user = await self.get_current_user(request) + if not user: return await self._redirect_to_login(request) return await invoke_handler(f, request, *args, **kwargs) diff --git a/src/microdot/session.py b/src/microdot/session.py index cbfb0a8..07de334 100644 --- a/src/microdot/session.py +++ b/src/microdot/session.py @@ -23,7 +23,7 @@ class SessionDict(dict): class Session: - """ + """Session handling :param app: The application instance. :param secret_key: The secret key, as a string or bytes object. :param cookie_options: A dictionary with cookie options to pass as diff --git a/tests/test_login.py b/tests/test_login.py index 3199b76..2f33059 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -32,7 +32,9 @@ class TestLogin(unittest.TestCase): @app.get('/') @login - def index(request): + async def index(request): + assert await login.get_current_user(request) == \ + request.g.current_user return request.g.current_user.name @app.post('/login') diff --git a/tox.ini b/tox.ini index 86fab3e..0b76f35 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ python = [testenv] commands= pip install -e . - pytest -p no:logging --cov=src --cov-config=.coveragerc --cov-branch --cov-report=term-missing --cov-report=xml tests + pytest -p no:logging --cov=src --cov-config=.coveragerc --cov-branch --cov-report=term-missing --cov-report=xml {posargs} deps= pytest pytest-cov