Use a case insensitive dict for headers
This commit is contained in:
@@ -91,6 +91,47 @@ def urldecode_bytes(s):
|
|||||||
return b''.join(result).decode()
|
return b''.join(result).decode()
|
||||||
|
|
||||||
|
|
||||||
|
class NoCaseDict(dict):
|
||||||
|
"""A subclass of dictionary that holds case-insensitive keys.
|
||||||
|
|
||||||
|
:param initial_dict: an initial dictionary of key/value pairs to
|
||||||
|
initialize this object with.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> d = NoCaseDict()
|
||||||
|
>>> d['Content-Type'] = 'text/html'
|
||||||
|
>>> print(d['Content-Type'])
|
||||||
|
text/html
|
||||||
|
>>> print(d['content-type'])
|
||||||
|
text/html
|
||||||
|
>>> print(d['CONTENT-TYPE'])
|
||||||
|
text/html
|
||||||
|
>>> del d['cOnTeNt-TyPe']
|
||||||
|
>>> print(d)
|
||||||
|
{}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_dict=None):
|
||||||
|
super().__init__(initial_dict or {})
|
||||||
|
self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k}
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
key = self.keymap.get(key.lower(), key)
|
||||||
|
if key.lower() != key:
|
||||||
|
self.keymap[key.lower()] = key
|
||||||
|
super().__setitem__(key, value)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return super().__getitem__(self.keymap.get(key.lower(), key))
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
super().__delitem__(self.keymap.get(key.lower(), key))
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return self.keymap.get(key.lower(), key) in self.keys()
|
||||||
|
|
||||||
|
|
||||||
class MultiDict(dict):
|
class MultiDict(dict):
|
||||||
"""A subclass of dictionary that can hold multiple values for the same
|
"""A subclass of dictionary that can hold multiple values for the same
|
||||||
key. It is used to hold key/value pairs decoded from query strings and
|
key. It is used to hold key/value pairs decoded from query strings and
|
||||||
@@ -248,16 +289,14 @@ class Request():
|
|||||||
self.path, self.query_string = self.path.split('?', 1)
|
self.path, self.query_string = self.path.split('?', 1)
|
||||||
self.args = self._parse_urlencoded(self.query_string)
|
self.args = self._parse_urlencoded(self.query_string)
|
||||||
|
|
||||||
for header, value in self.headers.items():
|
if 'Content-Length' in self.headers:
|
||||||
header = header.lower()
|
self.content_length = int(self.headers['Content-Length'])
|
||||||
if header == 'content-length':
|
if 'Content-Type' in self.headers:
|
||||||
self.content_length = int(value)
|
self.content_type = self.headers['Content-Type']
|
||||||
elif header == 'content-type':
|
if 'Cookie' in self.headers:
|
||||||
self.content_type = value
|
for cookie in self.headers['Cookie'].split(';'):
|
||||||
elif header == 'cookie':
|
name, value = cookie.strip().split('=', 1)
|
||||||
for cookie in value.split(';'):
|
self.cookies[name] = value
|
||||||
name, value = cookie.strip().split('=', 1)
|
|
||||||
self.cookies[name] = value
|
|
||||||
|
|
||||||
self._body = body
|
self._body = body
|
||||||
self.body_used = False
|
self.body_used = False
|
||||||
@@ -289,7 +328,7 @@ class Request():
|
|||||||
http_version = http_version.split('/', 1)[1]
|
http_version = http_version.split('/', 1)[1]
|
||||||
|
|
||||||
# headers
|
# headers
|
||||||
headers = {}
|
headers = NoCaseDict()
|
||||||
while True:
|
while True:
|
||||||
line = Request._safe_readline(client_stream).strip().decode()
|
line = Request._safe_readline(client_stream).strip().decode()
|
||||||
if line == '':
|
if line == '':
|
||||||
@@ -437,7 +476,7 @@ class Response():
|
|||||||
body = ''
|
body = ''
|
||||||
status_code = 204
|
status_code = 204
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.headers = headers.copy() if headers else {}
|
self.headers = NoCaseDict(headers or {})
|
||||||
self.reason = reason
|
self.reason = reason
|
||||||
if isinstance(body, (dict, list)):
|
if isinstance(body, (dict, list)):
|
||||||
self.body = json.dumps(body).encode()
|
self.body = json.dumps(body).encode()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import signal
|
|||||||
from microdot_asyncio import * # noqa: F401, F403
|
from microdot_asyncio import * # noqa: F401, F403
|
||||||
from microdot_asyncio import Microdot as BaseMicrodot
|
from microdot_asyncio import Microdot as BaseMicrodot
|
||||||
from microdot_asyncio import Request
|
from microdot_asyncio import Request
|
||||||
|
from microdot import NoCaseDict
|
||||||
|
|
||||||
|
|
||||||
class _BodyStream: # pragma: no cover
|
class _BodyStream: # pragma: no cover
|
||||||
@@ -55,7 +56,7 @@ class Microdot(BaseMicrodot):
|
|||||||
path = scope['path']
|
path = scope['path']
|
||||||
if 'query_string' in scope and scope['query_string']:
|
if 'query_string' in scope and scope['query_string']:
|
||||||
path += '?' + scope['query_string'].decode()
|
path += '?' + scope['query_string'].decode()
|
||||||
headers = {}
|
headers = NoCaseDict()
|
||||||
content_length = 0
|
content_length = 0
|
||||||
for key, value in scope.get('headers', []):
|
for key, value in scope.get('headers', []):
|
||||||
headers[key] = value
|
headers[key] = value
|
||||||
|
|||||||
@@ -17,9 +17,10 @@ except ImportError:
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from microdot import Microdot as BaseMicrodot
|
from microdot import Microdot as BaseMicrodot
|
||||||
from microdot import print_exception
|
from microdot import NoCaseDict
|
||||||
from microdot import Request as BaseRequest
|
from microdot import Request as BaseRequest
|
||||||
from microdot import Response as BaseResponse
|
from microdot import Response as BaseResponse
|
||||||
|
from microdot import print_exception
|
||||||
from microdot import HTTPException
|
from microdot import HTTPException
|
||||||
from microdot import MUTED_SOCKET_ERRORS
|
from microdot import MUTED_SOCKET_ERRORS
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ class Request(BaseRequest):
|
|||||||
http_version = http_version.split('/', 1)[1]
|
http_version = http_version.split('/', 1)[1]
|
||||||
|
|
||||||
# headers
|
# headers
|
||||||
headers = {}
|
headers = NoCaseDict()
|
||||||
content_length = 0
|
content_length = 0
|
||||||
while True:
|
while True:
|
||||||
line = (await Request._safe_readline(
|
line = (await Request._safe_readline(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import json
|
import json
|
||||||
from microdot import Request, Response
|
from microdot import Request, Response, NoCaseDict
|
||||||
try:
|
try:
|
||||||
from microdot_websocket import WebSocket
|
from microdot_websocket import WebSocket
|
||||||
except: # pragma: no cover # noqa: E722
|
except: # pragma: no cover # noqa: E722
|
||||||
@@ -46,11 +46,10 @@ class TestResponse:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _process_json_body(self):
|
def _process_json_body(self):
|
||||||
for name, value in self.headers.items(): # pragma: no branch
|
if 'Content-Type' in self.headers: # pragma: no branch
|
||||||
if name.lower() == 'content-type':
|
content_type = self.headers['Content-Type']
|
||||||
if value.lower().split(';')[0] == 'application/json':
|
if content_type.split(';')[0] == 'application/json':
|
||||||
self.json = json.loads(self.text)
|
self.json = json.loads(self.text)
|
||||||
break
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, res):
|
def create(cls, res):
|
||||||
@@ -97,13 +96,11 @@ class TestClient:
|
|||||||
body = b''
|
body = b''
|
||||||
elif isinstance(body, (dict, list)):
|
elif isinstance(body, (dict, list)):
|
||||||
body = json.dumps(body).encode()
|
body = json.dumps(body).encode()
|
||||||
if 'Content-Type' not in headers and \
|
if 'Content-Type' not in headers: # pragma: no cover
|
||||||
'content-type' not in headers: # pragma: no cover
|
|
||||||
headers['Content-Type'] = 'application/json'
|
headers['Content-Type'] = 'application/json'
|
||||||
elif isinstance(body, str):
|
elif isinstance(body, str):
|
||||||
body = body.encode()
|
body = body.encode()
|
||||||
if body and 'Content-Length' not in headers and \
|
if body and 'Content-Length' not in headers:
|
||||||
'content-length' not in headers:
|
|
||||||
headers['Content-Length'] = str(len(body))
|
headers['Content-Length'] = str(len(body))
|
||||||
if 'Host' not in headers: # pragma: no branch
|
if 'Host' not in headers: # pragma: no branch
|
||||||
headers['Host'] = 'example.com:1234'
|
headers['Host'] = 'example.com:1234'
|
||||||
@@ -132,29 +129,28 @@ class TestClient:
|
|||||||
return request_bytes
|
return request_bytes
|
||||||
|
|
||||||
def _update_cookies(self, res):
|
def _update_cookies(self, res):
|
||||||
for name, value in res.headers.items():
|
cookies = res.headers.get('Set-Cookie', [])
|
||||||
if name.lower() == 'set-cookie':
|
for cookie in cookies:
|
||||||
for cookie in value:
|
cookie_name, cookie_value = cookie.split('=', 1)
|
||||||
cookie_name, cookie_value = cookie.split('=', 1)
|
cookie_options = cookie_value.split(';')
|
||||||
cookie_options = cookie_value.split(';')
|
delete = False
|
||||||
delete = False
|
for option in cookie_options[1:]:
|
||||||
for option in cookie_options[1:]:
|
if option.strip().lower().startswith('expires='):
|
||||||
if option.strip().lower().startswith('expires='):
|
_, e = option.strip().split('=', 1)
|
||||||
_, e = option.strip().split('=', 1)
|
# this is a very limited parser for cookie expiry
|
||||||
# this is a very limited parser for cookie expiry
|
# that only detects a cookie deletion request when
|
||||||
# that only detects a cookie deletion request when
|
# the date is 1/1/1970
|
||||||
# the date is 1/1/1970
|
if '1 jan 1970' in e.lower(): # pragma: no branch
|
||||||
if '1 jan 1970' in e.lower(): # pragma: no branch
|
delete = True
|
||||||
delete = True
|
break
|
||||||
break
|
if delete:
|
||||||
if delete:
|
if cookie_name in self.cookies: # pragma: no branch
|
||||||
if cookie_name in self.cookies: # pragma: no branch
|
del self.cookies[cookie_name]
|
||||||
del self.cookies[cookie_name]
|
else:
|
||||||
else:
|
self.cookies[cookie_name] = cookie_options[0]
|
||||||
self.cookies[cookie_name] = cookie_options[0]
|
|
||||||
|
|
||||||
def request(self, method, path, headers=None, body=None, sock=None):
|
def request(self, method, path, headers=None, body=None, sock=None):
|
||||||
headers = headers or {}
|
headers = NoCaseDict(headers or {})
|
||||||
body, headers = self._process_body(body, headers)
|
body, headers = self._process_body(body, headers)
|
||||||
cookies, headers = self._process_cookies(headers)
|
cookies, headers = self._process_cookies(headers)
|
||||||
request_bytes = self._render_request(method, path, headers, body)
|
request_bytes = self._render_request(method, path, headers, body)
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
from microdot import * # noqa: F401, F403
|
from microdot import * # noqa: F401, F403
|
||||||
from microdot import Microdot as BaseMicrodot
|
from microdot import Microdot as BaseMicrodot, Request, NoCaseDict
|
||||||
from microdot import Request
|
|
||||||
|
|
||||||
|
|
||||||
class Microdot(BaseMicrodot):
|
class Microdot(BaseMicrodot):
|
||||||
@@ -15,7 +14,7 @@ class Microdot(BaseMicrodot):
|
|||||||
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
|
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
|
||||||
if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
|
if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
|
||||||
path += '?' + environ['QUERY_STRING']
|
path += '?' + environ['QUERY_STRING']
|
||||||
headers = {}
|
headers = NoCaseDict()
|
||||||
for k, v in environ.items():
|
for k, v in environ.items():
|
||||||
if k.startswith('HTTP_'):
|
if k.startswith('HTTP_'):
|
||||||
h = '-'.join([p.title() for p in k[5:].split('_')])
|
h = '-'.join([p.title() for p in k[5:].split('_')])
|
||||||
|
|||||||
@@ -1,31 +1,52 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from microdot import MultiDict
|
from microdot import MultiDict, NoCaseDict
|
||||||
|
|
||||||
|
|
||||||
class TestMultiDict(unittest.TestCase):
|
class TestMultiDict(unittest.TestCase):
|
||||||
def test_multidict(self):
|
def test_multidict(self):
|
||||||
d = MultiDict()
|
d = MultiDict()
|
||||||
|
|
||||||
assert dict(d) == {}
|
self.assertEqual(dict(d), {})
|
||||||
assert d.get('zero') is None
|
self.assertIsNone(d.get('zero'))
|
||||||
assert d.get('zero', default=0) == 0
|
self.assertEqual(d.get('zero', default=0), 0)
|
||||||
assert d.getlist('zero') == []
|
self.assertEqual(d.getlist('zero'), [])
|
||||||
assert d.getlist('zero', type=int) == []
|
self.assertEqual(d.getlist('zero', type=int), [])
|
||||||
|
|
||||||
d['one'] = 1
|
d['one'] = 1
|
||||||
assert d['one'] == 1
|
self.assertEqual(d['one'], 1)
|
||||||
assert d.get('one') == 1
|
self.assertEqual(d.get('one'), 1)
|
||||||
assert d.get('one', default=2) == 1
|
self.assertEqual(d.get('one', default=2), 1)
|
||||||
assert d.get('one', type=int) == 1
|
self.assertEqual(d.get('one', type=int), 1)
|
||||||
assert d.get('one', type=str) == '1'
|
self.assertEqual(d.get('one', type=str), '1')
|
||||||
|
|
||||||
d['two'] = 1
|
d['two'] = 1
|
||||||
d['two'] = 2
|
d['two'] = 2
|
||||||
assert d['two'] == 1
|
self.assertEqual(d['two'], 1)
|
||||||
assert d.get('two') == 1
|
self.assertEqual(d.get('two'), 1)
|
||||||
assert d.get('two', default=2) == 1
|
self.assertEqual(d.get('two', default=2), 1)
|
||||||
assert d.get('two', type=int) == 1
|
self.assertEqual(d.get('two', type=int), 1)
|
||||||
assert d.get('two', type=str) == '1'
|
self.assertEqual(d.get('two', type=str), '1')
|
||||||
assert d.getlist('two') == [1, 2]
|
self.assertEqual(d.getlist('two'), [1, 2])
|
||||||
assert d.getlist('two', type=int) == [1, 2]
|
self.assertEqual(d.getlist('two', type=int), [1, 2])
|
||||||
assert d.getlist('two', type=str) == ['1', '2']
|
self.assertEqual(d.getlist('two', type=str), ['1', '2'])
|
||||||
|
|
||||||
|
def test_case_insensitive_dict(self):
|
||||||
|
d = NoCaseDict()
|
||||||
|
|
||||||
|
d['One'] = 1
|
||||||
|
d['one'] = 2
|
||||||
|
d['ONE'] = 3
|
||||||
|
d['One'] = 4
|
||||||
|
d['two'] = 5
|
||||||
|
self.assertEqual(d['one'], 4)
|
||||||
|
self.assertEqual(d['One'], 4)
|
||||||
|
self.assertEqual(d['ONE'], 4)
|
||||||
|
self.assertEqual(d['onE'], 4)
|
||||||
|
self.assertIn(('One', 4), list(d.items()))
|
||||||
|
self.assertIn(('two', 5), list(d.items()))
|
||||||
|
self.assertIn(4, list(d.values()))
|
||||||
|
self.assertIn(5, list(d.values()))
|
||||||
|
|
||||||
|
del d['oNE']
|
||||||
|
self.assertEqual(list(d.items()), [('two', 5)])
|
||||||
|
self.assertEqual(list(d.values()), [5])
|
||||||
|
|||||||
Reference in New Issue
Block a user