Use a case insensitive dict for headers

This commit is contained in:
Miguel Grinberg
2022-09-21 23:29:01 +01:00
parent cbefb6bf3a
commit b0fd6c4323
6 changed files with 125 additions and 68 deletions

View File

@@ -91,6 +91,47 @@ def urldecode_bytes(s):
return b''.join(result).decode() return b''.join(result).decode()
class NoCaseDict(dict):
"""A subclass of dictionary that holds case-insensitive keys.
:param initial_dict: an initial dictionary of key/value pairs to
initialize this object with.
Example::
>>> d = NoCaseDict()
>>> d['Content-Type'] = 'text/html'
>>> print(d['Content-Type'])
text/html
>>> print(d['content-type'])
text/html
>>> print(d['CONTENT-TYPE'])
text/html
>>> del d['cOnTeNt-TyPe']
>>> print(d)
{}
"""
def __init__(self, initial_dict=None):
super().__init__(initial_dict or {})
self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k}
def __setitem__(self, key, value):
key = self.keymap.get(key.lower(), key)
if key.lower() != key:
self.keymap[key.lower()] = key
super().__setitem__(key, value)
def __getitem__(self, key):
return super().__getitem__(self.keymap.get(key.lower(), key))
def __delitem__(self, key):
super().__delitem__(self.keymap.get(key.lower(), key))
def __contains__(self, key):
return self.keymap.get(key.lower(), key) in self.keys()
class MultiDict(dict): class MultiDict(dict):
"""A subclass of dictionary that can hold multiple values for the same """A subclass of dictionary that can hold multiple values for the same
key. It is used to hold key/value pairs decoded from query strings and key. It is used to hold key/value pairs decoded from query strings and
@@ -248,16 +289,14 @@ class Request():
self.path, self.query_string = self.path.split('?', 1) self.path, self.query_string = self.path.split('?', 1)
self.args = self._parse_urlencoded(self.query_string) self.args = self._parse_urlencoded(self.query_string)
for header, value in self.headers.items(): if 'Content-Length' in self.headers:
header = header.lower() self.content_length = int(self.headers['Content-Length'])
if header == 'content-length': if 'Content-Type' in self.headers:
self.content_length = int(value) self.content_type = self.headers['Content-Type']
elif header == 'content-type': if 'Cookie' in self.headers:
self.content_type = value for cookie in self.headers['Cookie'].split(';'):
elif header == 'cookie': name, value = cookie.strip().split('=', 1)
for cookie in value.split(';'): self.cookies[name] = value
name, value = cookie.strip().split('=', 1)
self.cookies[name] = value
self._body = body self._body = body
self.body_used = False self.body_used = False
@@ -289,7 +328,7 @@ class Request():
http_version = http_version.split('/', 1)[1] http_version = http_version.split('/', 1)[1]
# headers # headers
headers = {} headers = NoCaseDict()
while True: while True:
line = Request._safe_readline(client_stream).strip().decode() line = Request._safe_readline(client_stream).strip().decode()
if line == '': if line == '':
@@ -437,7 +476,7 @@ class Response():
body = '' body = ''
status_code = 204 status_code = 204
self.status_code = status_code self.status_code = status_code
self.headers = headers.copy() if headers else {} self.headers = NoCaseDict(headers or {})
self.reason = reason self.reason = reason
if isinstance(body, (dict, list)): if isinstance(body, (dict, list)):
self.body = json.dumps(body).encode() self.body = json.dumps(body).encode()

View File

@@ -4,6 +4,7 @@ import signal
from microdot_asyncio import * # noqa: F401, F403 from microdot_asyncio import * # noqa: F401, F403
from microdot_asyncio import Microdot as BaseMicrodot from microdot_asyncio import Microdot as BaseMicrodot
from microdot_asyncio import Request from microdot_asyncio import Request
from microdot import NoCaseDict
class _BodyStream: # pragma: no cover class _BodyStream: # pragma: no cover
@@ -55,7 +56,7 @@ class Microdot(BaseMicrodot):
path = scope['path'] path = scope['path']
if 'query_string' in scope and scope['query_string']: if 'query_string' in scope and scope['query_string']:
path += '?' + scope['query_string'].decode() path += '?' + scope['query_string'].decode()
headers = {} headers = NoCaseDict()
content_length = 0 content_length = 0
for key, value in scope.get('headers', []): for key, value in scope.get('headers', []):
headers[key] = value headers[key] = value

View File

@@ -17,9 +17,10 @@ except ImportError:
import io import io
from microdot import Microdot as BaseMicrodot from microdot import Microdot as BaseMicrodot
from microdot import print_exception from microdot import NoCaseDict
from microdot import Request as BaseRequest from microdot import Request as BaseRequest
from microdot import Response as BaseResponse from microdot import Response as BaseResponse
from microdot import print_exception
from microdot import HTTPException from microdot import HTTPException
from microdot import MUTED_SOCKET_ERRORS from microdot import MUTED_SOCKET_ERRORS
@@ -74,7 +75,7 @@ class Request(BaseRequest):
http_version = http_version.split('/', 1)[1] http_version = http_version.split('/', 1)[1]
# headers # headers
headers = {} headers = NoCaseDict()
content_length = 0 content_length = 0
while True: while True:
line = (await Request._safe_readline( line = (await Request._safe_readline(

View File

@@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
import json import json
from microdot import Request, Response from microdot import Request, Response, NoCaseDict
try: try:
from microdot_websocket import WebSocket from microdot_websocket import WebSocket
except: # pragma: no cover # noqa: E722 except: # pragma: no cover # noqa: E722
@@ -46,11 +46,10 @@ class TestResponse:
pass pass
def _process_json_body(self): def _process_json_body(self):
for name, value in self.headers.items(): # pragma: no branch if 'Content-Type' in self.headers: # pragma: no branch
if name.lower() == 'content-type': content_type = self.headers['Content-Type']
if value.lower().split(';')[0] == 'application/json': if content_type.split(';')[0] == 'application/json':
self.json = json.loads(self.text) self.json = json.loads(self.text)
break
@classmethod @classmethod
def create(cls, res): def create(cls, res):
@@ -97,13 +96,11 @@ class TestClient:
body = b'' body = b''
elif isinstance(body, (dict, list)): elif isinstance(body, (dict, list)):
body = json.dumps(body).encode() body = json.dumps(body).encode()
if 'Content-Type' not in headers and \ if 'Content-Type' not in headers: # pragma: no cover
'content-type' not in headers: # pragma: no cover
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'
elif isinstance(body, str): elif isinstance(body, str):
body = body.encode() body = body.encode()
if body and 'Content-Length' not in headers and \ if body and 'Content-Length' not in headers:
'content-length' not in headers:
headers['Content-Length'] = str(len(body)) headers['Content-Length'] = str(len(body))
if 'Host' not in headers: # pragma: no branch if 'Host' not in headers: # pragma: no branch
headers['Host'] = 'example.com:1234' headers['Host'] = 'example.com:1234'
@@ -132,29 +129,28 @@ class TestClient:
return request_bytes return request_bytes
def _update_cookies(self, res): def _update_cookies(self, res):
for name, value in res.headers.items(): cookies = res.headers.get('Set-Cookie', [])
if name.lower() == 'set-cookie': for cookie in cookies:
for cookie in value: cookie_name, cookie_value = cookie.split('=', 1)
cookie_name, cookie_value = cookie.split('=', 1) cookie_options = cookie_value.split(';')
cookie_options = cookie_value.split(';') delete = False
delete = False for option in cookie_options[1:]:
for option in cookie_options[1:]: if option.strip().lower().startswith('expires='):
if option.strip().lower().startswith('expires='): _, e = option.strip().split('=', 1)
_, e = option.strip().split('=', 1) # this is a very limited parser for cookie expiry
# this is a very limited parser for cookie expiry # that only detects a cookie deletion request when
# that only detects a cookie deletion request when # the date is 1/1/1970
# the date is 1/1/1970 if '1 jan 1970' in e.lower(): # pragma: no branch
if '1 jan 1970' in e.lower(): # pragma: no branch delete = True
delete = True break
break if delete:
if delete: if cookie_name in self.cookies: # pragma: no branch
if cookie_name in self.cookies: # pragma: no branch del self.cookies[cookie_name]
del self.cookies[cookie_name] else:
else: self.cookies[cookie_name] = cookie_options[0]
self.cookies[cookie_name] = cookie_options[0]
def request(self, method, path, headers=None, body=None, sock=None): def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {} headers = NoCaseDict(headers or {})
body, headers = self._process_body(body, headers) body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers) cookies, headers = self._process_cookies(headers)
request_bytes = self._render_request(method, path, headers, body) request_bytes = self._render_request(method, path, headers, body)

View File

@@ -1,8 +1,7 @@
import os import os
import signal import signal
from microdot import * # noqa: F401, F403 from microdot import * # noqa: F401, F403
from microdot import Microdot as BaseMicrodot from microdot import Microdot as BaseMicrodot, Request, NoCaseDict
from microdot import Request
class Microdot(BaseMicrodot): class Microdot(BaseMicrodot):
@@ -15,7 +14,7 @@ class Microdot(BaseMicrodot):
path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '') path = environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', '')
if 'QUERY_STRING' in environ and environ['QUERY_STRING']: if 'QUERY_STRING' in environ and environ['QUERY_STRING']:
path += '?' + environ['QUERY_STRING'] path += '?' + environ['QUERY_STRING']
headers = {} headers = NoCaseDict()
for k, v in environ.items(): for k, v in environ.items():
if k.startswith('HTTP_'): if k.startswith('HTTP_'):
h = '-'.join([p.title() for p in k[5:].split('_')]) h = '-'.join([p.title() for p in k[5:].split('_')])

View File

@@ -1,31 +1,52 @@
import unittest import unittest
from microdot import MultiDict from microdot import MultiDict, NoCaseDict
class TestMultiDict(unittest.TestCase): class TestMultiDict(unittest.TestCase):
def test_multidict(self): def test_multidict(self):
d = MultiDict() d = MultiDict()
assert dict(d) == {} self.assertEqual(dict(d), {})
assert d.get('zero') is None self.assertIsNone(d.get('zero'))
assert d.get('zero', default=0) == 0 self.assertEqual(d.get('zero', default=0), 0)
assert d.getlist('zero') == [] self.assertEqual(d.getlist('zero'), [])
assert d.getlist('zero', type=int) == [] self.assertEqual(d.getlist('zero', type=int), [])
d['one'] = 1 d['one'] = 1
assert d['one'] == 1 self.assertEqual(d['one'], 1)
assert d.get('one') == 1 self.assertEqual(d.get('one'), 1)
assert d.get('one', default=2) == 1 self.assertEqual(d.get('one', default=2), 1)
assert d.get('one', type=int) == 1 self.assertEqual(d.get('one', type=int), 1)
assert d.get('one', type=str) == '1' self.assertEqual(d.get('one', type=str), '1')
d['two'] = 1 d['two'] = 1
d['two'] = 2 d['two'] = 2
assert d['two'] == 1 self.assertEqual(d['two'], 1)
assert d.get('two') == 1 self.assertEqual(d.get('two'), 1)
assert d.get('two', default=2) == 1 self.assertEqual(d.get('two', default=2), 1)
assert d.get('two', type=int) == 1 self.assertEqual(d.get('two', type=int), 1)
assert d.get('two', type=str) == '1' self.assertEqual(d.get('two', type=str), '1')
assert d.getlist('two') == [1, 2] self.assertEqual(d.getlist('two'), [1, 2])
assert d.getlist('two', type=int) == [1, 2] self.assertEqual(d.getlist('two', type=int), [1, 2])
assert d.getlist('two', type=str) == ['1', '2'] self.assertEqual(d.getlist('two', type=str), ['1', '2'])
def test_case_insensitive_dict(self):
d = NoCaseDict()
d['One'] = 1
d['one'] = 2
d['ONE'] = 3
d['One'] = 4
d['two'] = 5
self.assertEqual(d['one'], 4)
self.assertEqual(d['One'], 4)
self.assertEqual(d['ONE'], 4)
self.assertEqual(d['onE'], 4)
self.assertIn(('One', 4), list(d.items()))
self.assertIn(('two', 5), list(d.items()))
self.assertIn(4, list(d.values()))
self.assertIn(5, list(d.values()))
del d['oNE']
self.assertEqual(list(d.items()), [('two', 5)])
self.assertEqual(list(d.values()), [5])