Use a case insensitive dict for headers
This commit is contained in:
@@ -91,6 +91,47 @@ def urldecode_bytes(s):
|
||||
return b''.join(result).decode()
|
||||
|
||||
|
||||
class NoCaseDict(dict):
|
||||
"""A subclass of dictionary that holds case-insensitive keys.
|
||||
|
||||
:param initial_dict: an initial dictionary of key/value pairs to
|
||||
initialize this object with.
|
||||
|
||||
Example::
|
||||
|
||||
>>> d = NoCaseDict()
|
||||
>>> d['Content-Type'] = 'text/html'
|
||||
>>> print(d['Content-Type'])
|
||||
text/html
|
||||
>>> print(d['content-type'])
|
||||
text/html
|
||||
>>> print(d['CONTENT-TYPE'])
|
||||
text/html
|
||||
>>> del d['cOnTeNt-TyPe']
|
||||
>>> print(d)
|
||||
{}
|
||||
"""
|
||||
|
||||
def __init__(self, initial_dict=None):
|
||||
super().__init__(initial_dict or {})
|
||||
self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k}
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
key = self.keymap.get(key.lower(), key)
|
||||
if key.lower() != key:
|
||||
self.keymap[key.lower()] = key
|
||||
super().__setitem__(key, value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return super().__getitem__(self.keymap.get(key.lower(), key))
|
||||
|
||||
def __delitem__(self, key):
|
||||
super().__delitem__(self.keymap.get(key.lower(), key))
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.keymap.get(key.lower(), key) in self.keys()
|
||||
|
||||
|
||||
class MultiDict(dict):
|
||||
"""A subclass of dictionary that can hold multiple values for the same
|
||||
key. It is used to hold key/value pairs decoded from query strings and
|
||||
@@ -248,14 +289,12 @@ class Request():
|
||||
self.path, self.query_string = self.path.split('?', 1)
|
||||
self.args = self._parse_urlencoded(self.query_string)
|
||||
|
||||
for header, value in self.headers.items():
|
||||
header = header.lower()
|
||||
if header == 'content-length':
|
||||
self.content_length = int(value)
|
||||
elif header == 'content-type':
|
||||
self.content_type = value
|
||||
elif header == 'cookie':
|
||||
for cookie in value.split(';'):
|
||||
if 'Content-Length' in self.headers:
|
||||
self.content_length = int(self.headers['Content-Length'])
|
||||
if 'Content-Type' in self.headers:
|
||||
self.content_type = self.headers['Content-Type']
|
||||
if 'Cookie' in self.headers:
|
||||
for cookie in self.headers['Cookie'].split(';'):
|
||||
name, value = cookie.strip().split('=', 1)
|
||||
self.cookies[name] = value
|
||||
|
||||
@@ -289,7 +328,7 @@ class Request():
|
||||
http_version = http_version.split('/', 1)[1]
|
||||
|
||||
# headers
|
||||
headers = {}
|
||||
headers = NoCaseDict()
|
||||
while True:
|
||||
line = Request._safe_readline(client_stream).strip().decode()
|
||||
if line == '':
|
||||
@@ -437,7 +476,7 @@ class Response():
|
||||
body = ''
|
||||
status_code = 204
|
||||
self.status_code = status_code
|
||||
self.headers = headers.copy() if headers else {}
|
||||
self.headers = NoCaseDict(headers or {})
|
||||
self.reason = reason
|
||||
if isinstance(body, (dict, list)):
|
||||
self.body = json.dumps(body).encode()
|
||||
|
||||
@@ -4,6 +4,7 @@ import signal
|
||||
from microdot_asyncio import * # noqa: F401, F403
|
||||
from microdot_asyncio import Microdot as BaseMicrodot
|
||||
from microdot_asyncio import Request
|
||||
from microdot import NoCaseDict
|
||||
|
||||
|
||||
class _BodyStream: # pragma: no cover
|
||||
@@ -55,7 +56,7 @@ class Microdot(BaseMicrodot):
|
||||
path = scope['path']
|
||||
if 'query_string' in scope and scope['query_string']:
|
||||
path += '?' + scope['query_string'].decode()
|
||||
headers = {}
|
||||
headers = NoCaseDict()
|
||||
content_length = 0
|
||||
for key, value in scope.get('headers', []):
|
||||
headers[key] = value
|
||||
|
||||
@@ -17,9 +17,10 @@ except ImportError:
|
||||
import io
|
||||
|
||||
from microdot import Microdot as BaseMicrodot
|
||||
from microdot import print_exception
|
||||
from microdot import NoCaseDict
|
||||
from microdot import Request as BaseRequest
|
||||
from microdot import Response as BaseResponse
|
||||
from microdot import print_exception
|
||||
from microdot import HTTPException
|
||||
from microdot import MUTED_SOCKET_ERRORS
|
||||
|
||||
@@ -74,7 +75,7 @@ class Request(BaseRequest):
|
||||
http_version = http_version.split('/', 1)[1]
|
||||
|
||||
# headers
|
||||
headers = {}
|
||||
headers = NoCaseDict()
|
||||
content_length = 0
|
||||
while True:
|
||||
line = (await Request._safe_readline(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from io import BytesIO
|
||||
import json
|
||||
from microdot import Request, Response
|
||||
from microdot import Request, Response, NoCaseDict
|
||||
try:
|
||||
from microdot_websocket import WebSocket
|
||||
except: # pragma: no cover # noqa: E722
|
||||
@@ -46,11 +46,10 @@ class TestResponse:
|
||||
pass
|
||||
|
||||
def _process_json_body(self):
|
||||
for name, value in self.headers.items(): # pragma: no branch
|
||||
if name.lower() == 'content-type':
|
||||
if value.lower().split(';')[0] == 'application/json':
|
||||
if 'Content-Type' in self.headers: # pragma: no branch
|
||||
content_type = self.headers['Content-Type']
|
||||
if content_type.split(';')[0] == 'application/json':
|
||||
self.json = json.loads(self.text)
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def create(cls, res):
|
||||
@@ -97,13 +96,11 @@ class TestClient:
|
||||
body = b''
|
||||
elif isinstance(body, (dict, list)):
|
||||
body = json.dumps(body).encode()
|
||||
if 'Content-Type' not in headers and \
|
||||
'content-type' not in headers: # pragma: no cover
|
||||
if 'Content-Type' not in headers: # pragma: no cover
|
||||
headers['Content-Type'] = 'application/json'
|
||||
elif isinstance(body, str):
|
||||
body = body.encode()
|
||||
if body and 'Content-Length' not in headers and \
|
||||
'content-length' not in headers:
|
||||
if body and 'Content-Length' not in headers:
|
||||
headers['Content-Length'] = str(len(body))
|
||||
if 'Host' not in headers: # pragma: no branch
|
||||
headers['Host'] = 'example.com:1234'
|
||||
@@ -132,9 +129,8 @@ class TestClient:
|
||||
return request_bytes
|
||||
|
||||
def _update_cookies(self, res):
|
||||
for name, value in res.headers.items():
|
||||
if name.lower() == 'set-cookie':
|
||||
for cookie in value:
|
||||
cookies = res.headers.get('Set-Cookie', [])
|
||||
for cookie in cookies:
|
||||
cookie_name, cookie_value = cookie.split('=', 1)
|
||||
cookie_options = cookie_value.split(';')
|
||||
delete = False
|
||||
@@ -154,7 +150,7 @@ class TestClient:
|
||||
self.cookies[cookie_name] = cookie_options[0]
|
||||
|
||||
def request(self, method, path, headers=None, body=None, sock=None):
|
||||
headers = headers or {}
|
||||
headers = NoCaseDict(headers or {})
|
||||
body, headers = self._process_body(body, headers)
|
||||
cookies, headers = self._process_cookies(headers)
|
||||
request_bytes = self._render_request(method, path, headers, body)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import os
|
||||
import signal
|
||||
from microdot import * # noqa: F401, F403
|
||||
from microdot import Microdot as BaseMicrodot
|
||||
from microdot import Request
|
||||
from microdot import Microdot as BaseMicrodot, Request, NoCaseDict
|
||||
|
||||
|
||||
class Microdot(BaseMicrodot):
|
||||
@@ -15,7 +14,7 @@ class Microdot(BaseMicrodot):
|
||||
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
|
||||
if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
|
||||
path += '?' + environ['QUERY_STRING']
|
||||
headers = {}
|
||||
headers = NoCaseDict()
|
||||
for k, v in environ.items():
|
||||
if k.startswith('HTTP_'):
|
||||
h = '-'.join([p.title() for p in k[5:].split('_')])
|
||||
|
||||
@@ -1,31 +1,52 @@
|
||||
import unittest
|
||||
from microdot import MultiDict
|
||||
from microdot import MultiDict, NoCaseDict
|
||||
|
||||
|
||||
class TestMultiDict(unittest.TestCase):
|
||||
def test_multidict(self):
|
||||
d = MultiDict()
|
||||
|
||||
assert dict(d) == {}
|
||||
assert d.get('zero') is None
|
||||
assert d.get('zero', default=0) == 0
|
||||
assert d.getlist('zero') == []
|
||||
assert d.getlist('zero', type=int) == []
|
||||
self.assertEqual(dict(d), {})
|
||||
self.assertIsNone(d.get('zero'))
|
||||
self.assertEqual(d.get('zero', default=0), 0)
|
||||
self.assertEqual(d.getlist('zero'), [])
|
||||
self.assertEqual(d.getlist('zero', type=int), [])
|
||||
|
||||
d['one'] = 1
|
||||
assert d['one'] == 1
|
||||
assert d.get('one') == 1
|
||||
assert d.get('one', default=2) == 1
|
||||
assert d.get('one', type=int) == 1
|
||||
assert d.get('one', type=str) == '1'
|
||||
self.assertEqual(d['one'], 1)
|
||||
self.assertEqual(d.get('one'), 1)
|
||||
self.assertEqual(d.get('one', default=2), 1)
|
||||
self.assertEqual(d.get('one', type=int), 1)
|
||||
self.assertEqual(d.get('one', type=str), '1')
|
||||
|
||||
d['two'] = 1
|
||||
d['two'] = 2
|
||||
assert d['two'] == 1
|
||||
assert d.get('two') == 1
|
||||
assert d.get('two', default=2) == 1
|
||||
assert d.get('two', type=int) == 1
|
||||
assert d.get('two', type=str) == '1'
|
||||
assert d.getlist('two') == [1, 2]
|
||||
assert d.getlist('two', type=int) == [1, 2]
|
||||
assert d.getlist('two', type=str) == ['1', '2']
|
||||
self.assertEqual(d['two'], 1)
|
||||
self.assertEqual(d.get('two'), 1)
|
||||
self.assertEqual(d.get('two', default=2), 1)
|
||||
self.assertEqual(d.get('two', type=int), 1)
|
||||
self.assertEqual(d.get('two', type=str), '1')
|
||||
self.assertEqual(d.getlist('two'), [1, 2])
|
||||
self.assertEqual(d.getlist('two', type=int), [1, 2])
|
||||
self.assertEqual(d.getlist('two', type=str), ['1', '2'])
|
||||
|
||||
def test_case_insensitive_dict(self):
|
||||
d = NoCaseDict()
|
||||
|
||||
d['One'] = 1
|
||||
d['one'] = 2
|
||||
d['ONE'] = 3
|
||||
d['One'] = 4
|
||||
d['two'] = 5
|
||||
self.assertEqual(d['one'], 4)
|
||||
self.assertEqual(d['One'], 4)
|
||||
self.assertEqual(d['ONE'], 4)
|
||||
self.assertEqual(d['onE'], 4)
|
||||
self.assertIn(('One', 4), list(d.items()))
|
||||
self.assertIn(('two', 5), list(d.items()))
|
||||
self.assertIn(4, list(d.values()))
|
||||
self.assertIn(5, list(d.values()))
|
||||
|
||||
del d['oNE']
|
||||
self.assertEqual(list(d.items()), [('two', 5)])
|
||||
self.assertEqual(list(d.values()), [5])
|
||||
|
||||
Reference in New Issue
Block a user