Use a case insensitive dict for headers

This commit is contained in:
Miguel Grinberg
2022-09-21 23:29:01 +01:00
parent cbefb6bf3a
commit b0fd6c4323
6 changed files with 125 additions and 68 deletions

View File

@@ -91,6 +91,47 @@ def urldecode_bytes(s):
return b''.join(result).decode()
class NoCaseDict(dict):
"""A subclass of dictionary that holds case-insensitive keys.
:param initial_dict: an initial dictionary of key/value pairs to
initialize this object with.
Example::
>>> d = NoCaseDict()
>>> d['Content-Type'] = 'text/html'
>>> print(d['Content-Type'])
text/html
>>> print(d['content-type'])
text/html
>>> print(d['CONTENT-TYPE'])
text/html
>>> del d['cOnTeNt-TyPe']
>>> print(d)
{}
"""
def __init__(self, initial_dict=None):
super().__init__(initial_dict or {})
self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k}
def __setitem__(self, key, value):
key = self.keymap.get(key.lower(), key)
if key.lower() != key:
self.keymap[key.lower()] = key
super().__setitem__(key, value)
def __getitem__(self, key):
return super().__getitem__(self.keymap.get(key.lower(), key))
def __delitem__(self, key):
super().__delitem__(self.keymap.get(key.lower(), key))
def __contains__(self, key):
return self.keymap.get(key.lower(), key) in self.keys()
class MultiDict(dict):
"""A subclass of dictionary that can hold multiple values for the same
key. It is used to hold key/value pairs decoded from query strings and
@@ -248,14 +289,12 @@ class Request():
self.path, self.query_string = self.path.split('?', 1)
self.args = self._parse_urlencoded(self.query_string)
for header, value in self.headers.items():
header = header.lower()
if header == 'content-length':
self.content_length = int(value)
elif header == 'content-type':
self.content_type = value
elif header == 'cookie':
for cookie in value.split(';'):
if 'Content-Length' in self.headers:
self.content_length = int(self.headers['Content-Length'])
if 'Content-Type' in self.headers:
self.content_type = self.headers['Content-Type']
if 'Cookie' in self.headers:
for cookie in self.headers['Cookie'].split(';'):
name, value = cookie.strip().split('=', 1)
self.cookies[name] = value
@@ -289,7 +328,7 @@ class Request():
http_version = http_version.split('/', 1)[1]
# headers
headers = {}
headers = NoCaseDict()
while True:
line = Request._safe_readline(client_stream).strip().decode()
if line == '':
@@ -437,7 +476,7 @@ class Response():
body = ''
status_code = 204
self.status_code = status_code
self.headers = headers.copy() if headers else {}
self.headers = NoCaseDict(headers or {})
self.reason = reason
if isinstance(body, (dict, list)):
self.body = json.dumps(body).encode()

View File

@@ -4,6 +4,7 @@ import signal
from microdot_asyncio import * # noqa: F401, F403
from microdot_asyncio import Microdot as BaseMicrodot
from microdot_asyncio import Request
from microdot import NoCaseDict
class _BodyStream: # pragma: no cover
@@ -55,7 +56,7 @@ class Microdot(BaseMicrodot):
path = scope['path']
if 'query_string' in scope and scope['query_string']:
path += '?' + scope['query_string'].decode()
headers = {}
headers = NoCaseDict()
content_length = 0
for key, value in scope.get('headers', []):
headers[key] = value

View File

@@ -17,9 +17,10 @@ except ImportError:
import io
from microdot import Microdot as BaseMicrodot
from microdot import print_exception
from microdot import NoCaseDict
from microdot import Request as BaseRequest
from microdot import Response as BaseResponse
from microdot import print_exception
from microdot import HTTPException
from microdot import MUTED_SOCKET_ERRORS
@@ -74,7 +75,7 @@ class Request(BaseRequest):
http_version = http_version.split('/', 1)[1]
# headers
headers = {}
headers = NoCaseDict()
content_length = 0
while True:
line = (await Request._safe_readline(

View File

@@ -1,6 +1,6 @@
from io import BytesIO
import json
from microdot import Request, Response
from microdot import Request, Response, NoCaseDict
try:
from microdot_websocket import WebSocket
except: # pragma: no cover # noqa: E722
@@ -46,11 +46,10 @@ class TestResponse:
pass
def _process_json_body(self):
for name, value in self.headers.items(): # pragma: no branch
if name.lower() == 'content-type':
if value.lower().split(';')[0] == 'application/json':
if 'Content-Type' in self.headers: # pragma: no branch
content_type = self.headers['Content-Type']
if content_type.split(';')[0] == 'application/json':
self.json = json.loads(self.text)
break
@classmethod
def create(cls, res):
@@ -97,13 +96,11 @@ class TestClient:
body = b''
elif isinstance(body, (dict, list)):
body = json.dumps(body).encode()
if 'Content-Type' not in headers and \
'content-type' not in headers: # pragma: no cover
if 'Content-Type' not in headers: # pragma: no cover
headers['Content-Type'] = 'application/json'
elif isinstance(body, str):
body = body.encode()
if body and 'Content-Length' not in headers and \
'content-length' not in headers:
if body and 'Content-Length' not in headers:
headers['Content-Length'] = str(len(body))
if 'Host' not in headers: # pragma: no branch
headers['Host'] = 'example.com:1234'
@@ -132,9 +129,8 @@ class TestClient:
return request_bytes
def _update_cookies(self, res):
for name, value in res.headers.items():
if name.lower() == 'set-cookie':
for cookie in value:
cookies = res.headers.get('Set-Cookie', [])
for cookie in cookies:
cookie_name, cookie_value = cookie.split('=', 1)
cookie_options = cookie_value.split(';')
delete = False
@@ -154,7 +150,7 @@ class TestClient:
self.cookies[cookie_name] = cookie_options[0]
def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {}
headers = NoCaseDict(headers or {})
body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers)
request_bytes = self._render_request(method, path, headers, body)

View File

@@ -1,8 +1,7 @@
import os
import signal
from microdot import * # noqa: F401, F403
from microdot import Microdot as BaseMicrodot
from microdot import Request
from microdot import Microdot as BaseMicrodot, Request, NoCaseDict
class Microdot(BaseMicrodot):
@@ -15,7 +14,7 @@ class Microdot(BaseMicrodot):
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
path += '?' + environ['QUERY_STRING']
headers = {}
headers = NoCaseDict()
for k, v in environ.items():
if k.startswith('HTTP_'):
h = '-'.join([p.title() for p in k[5:].split('_')])

View File

@@ -1,31 +1,52 @@
import unittest
from microdot import MultiDict
from microdot import MultiDict, NoCaseDict
class TestMultiDict(unittest.TestCase):
def test_multidict(self):
d = MultiDict()
assert dict(d) == {}
assert d.get('zero') is None
assert d.get('zero', default=0) == 0
assert d.getlist('zero') == []
assert d.getlist('zero', type=int) == []
self.assertEqual(dict(d), {})
self.assertIsNone(d.get('zero'))
self.assertEqual(d.get('zero', default=0), 0)
self.assertEqual(d.getlist('zero'), [])
self.assertEqual(d.getlist('zero', type=int), [])
d['one'] = 1
assert d['one'] == 1
assert d.get('one') == 1
assert d.get('one', default=2) == 1
assert d.get('one', type=int) == 1
assert d.get('one', type=str) == '1'
self.assertEqual(d['one'], 1)
self.assertEqual(d.get('one'), 1)
self.assertEqual(d.get('one', default=2), 1)
self.assertEqual(d.get('one', type=int), 1)
self.assertEqual(d.get('one', type=str), '1')
d['two'] = 1
d['two'] = 2
assert d['two'] == 1
assert d.get('two') == 1
assert d.get('two', default=2) == 1
assert d.get('two', type=int) == 1
assert d.get('two', type=str) == '1'
assert d.getlist('two') == [1, 2]
assert d.getlist('two', type=int) == [1, 2]
assert d.getlist('two', type=str) == ['1', '2']
self.assertEqual(d['two'], 1)
self.assertEqual(d.get('two'), 1)
self.assertEqual(d.get('two', default=2), 1)
self.assertEqual(d.get('two', type=int), 1)
self.assertEqual(d.get('two', type=str), '1')
self.assertEqual(d.getlist('two'), [1, 2])
self.assertEqual(d.getlist('two', type=int), [1, 2])
self.assertEqual(d.getlist('two', type=str), ['1', '2'])
def test_case_insensitive_dict(self):
d = NoCaseDict()
d['One'] = 1
d['one'] = 2
d['ONE'] = 3
d['One'] = 4
d['two'] = 5
self.assertEqual(d['one'], 4)
self.assertEqual(d['One'], 4)
self.assertEqual(d['ONE'], 4)
self.assertEqual(d['onE'], 4)
self.assertIn(('One', 4), list(d.items()))
self.assertIn(('two', 5), list(d.items()))
self.assertIn(4, list(d.values()))
self.assertIn(5, list(d.values()))
del d['oNE']
self.assertEqual(list(d.items()), [('two', 5)])
self.assertEqual(list(d.values()), [5])