Immutable QueryParams (#1600)

* Tweak QueryParams implementation

* Immutable QueryParams
This commit is contained in:
Tom Christie 2021-04-26 14:57:02 +01:00 committed by GitHub
parent 8fe32c52de
commit 2abb2f214a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 188 additions and 67 deletions

View File

@ -385,7 +385,7 @@ class BaseClient:
"""
if params or self.params:
merged_queryparams = QueryParams(self.params)
merged_queryparams.update(params)
merged_queryparams = merged_queryparams.merge(params)
return merged_queryparams
return params

View File

@ -6,7 +6,7 @@ import typing
import urllib.request
from collections.abc import MutableMapping
from http.cookiejar import Cookie, CookieJar
from urllib.parse import parse_qsl, quote, unquote, urlencode
from urllib.parse import parse_qs, quote, unquote, urlencode
import idna
import rfc3986
@ -48,7 +48,6 @@ from ._types import (
URLTypes,
)
from ._utils import (
flatten_queryparams,
guess_json_utf,
is_known_encoding,
normalize_header_key,
@ -148,8 +147,7 @@ class URL:
# Add any query parameters, merging with any in the URL if needed.
if params:
if self._uri_reference.query:
url_params = QueryParams(self._uri_reference.query)
url_params.update(params)
url_params = QueryParams(self._uri_reference.query).merge(params)
query_string = str(url_params)
else:
query_string = str(QueryParams(params))
@ -450,7 +448,7 @@ class URL:
url = httpx.URL("https://www.example.com/test")
url = url.join("/new/path")
assert url == "https://www.example.com/test/new/path"
assert url == "https://www.example.com/new/path"
"""
if self.is_relative_url:
# Workaround to handle relative URLs, which otherwise raise
@ -504,38 +502,79 @@ class QueryParams(typing.Mapping[str, str]):
items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
if value is None or isinstance(value, (str, bytes)):
value = value.decode("ascii") if isinstance(value, bytes) else value
items = parse_qsl(value)
self._dict = parse_qs(value)
elif isinstance(value, QueryParams):
items = value.multi_items()
elif isinstance(value, (list, tuple)):
items = value
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
items = flatten_queryparams(value)
self._dict: typing.Dict[str, typing.List[str]] = {}
for item in items:
k, v = item
if str(k) not in self._dict:
self._dict[str(k)] = [primitive_value_to_str(v)]
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
if isinstance(value, (list, tuple)):
# Convert list inputs like:
# [("a", "123"), ("a", "456"), ("b", "789")]
# To a dict representation, like:
# {"a": ["123", "456"], "b": ["789"]}
for item in value:
dict_value.setdefault(item[0], []).append(item[1])
else:
self._dict[str(k)].append(primitive_value_to_str(v))
# Convert dict inputs like:
# {"a": "123", "b": ["456", "789"]}
# To dict inputs where values are always lists, like:
# {"a": ["123"], "b": ["456", "789"]}
dict_value = {
k: list(v) if isinstance(v, (list, tuple)) else [v]
for k, v in value.items()
}
# Ensure that keys and values are neatly coerced to strings.
# We coerce values `True` and `False` to JSON-like "true" and "false"
# representations, and coerce `None` values to the empty string.
self._dict = {
str(k): [primitive_value_to_str(item) for item in v]
for k, v in dict_value.items()
}
def keys(self) -> typing.KeysView:
"""
Return all the keys in the query params.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.keys()) == ["a", "b"]
"""
return self._dict.keys()
def values(self) -> typing.ValuesView:
"""
Return all the values in the query params. If a key occurs more than once
only the first item for that key is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.values()) == ["123", "789"]
"""
return {k: v[0] for k, v in self._dict.items()}.values()
def items(self) -> typing.ItemsView:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.items()) == [("a", "123"), ("b", "789")]
"""
return {k: v[0] for k, v in self._dict.items()}.items()
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: typing.List[typing.Tuple[str, str]] = []
for k, v in self._dict.items():
@ -546,31 +585,93 @@ class QueryParams(typing.Mapping[str, str]):
"""
Get a value from the query param for a given key. If the key occurs
more than once, then only the first value is returned.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get("a") == "123"
"""
if key in self._dict:
return self._dict[key][0]
return self._dict[str(key)][0]
return default
def get_list(self, key: typing.Any) -> typing.List[str]:
"""
Get all values from the query param for a given key.
Usage:
q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get_list("a") == ["123", "456"]
"""
return list(self._dict.get(key, []))
return list(self._dict.get(str(key), []))
def update(self, params: QueryParamTypes = None) -> None:
if not params:
return
def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting the value of a key.
params = QueryParams(params)
for k in params.keys():
self._dict[k] = params.get_list(k)
Usage:
q = httpx.QueryParams("a=123")
q = q.set("a", "456")
assert q == httpx.QueryParams("a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = [primitive_value_to_str(value)]
return q
def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting or appending the value of a key.
Usage:
q = httpx.QueryParams("a=123")
q = q.add("a", "456")
assert q == httpx.QueryParams("a=123&a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q
def remove(self, key: typing.Any) -> "QueryParams":
"""
Return a new QueryParams instance, removing the value of a key.
Usage:
q = httpx.QueryParams("a=123")
q = q.remove("a")
assert q == httpx.QueryParams("")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict.pop(str(key), None)
return q
def merge(self, params: QueryParamTypes = None) -> "QueryParams":
"""
Return a new QueryParams instance, updated with.
Usage:
q = httpx.QueryParams("a=123")
q = q.merge({"b": "456"})
assert q == httpx.QueryParams("a=123&b=456")
q = httpx.QueryParams("a=123")
q = q.merge({"a": "456", "b": "789"})
assert q == httpx.QueryParams("a=456&b=789")
"""
q = QueryParams(params)
q._dict = {**self._dict, **q._dict}
return q
def __getitem__(self, key: typing.Any) -> str:
return self._dict[key][0]
def __setitem__(self, key: str, value: str) -> None:
self._dict[key] = [value]
def __contains__(self, key: typing.Any) -> bool:
return key in self._dict
@ -580,6 +681,9 @@ class QueryParams(typing.Mapping[str, str]):
def __len__(self) -> int:
return len(self._dict)
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, self.__class__):
return False
@ -593,6 +697,18 @@ class QueryParams(typing.Mapping[str, str]):
query_string = str(self)
return f"{class_name}({query_string!r})"
def update(self, params: QueryParamTypes = None) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.merge(...)` to create an updated copy."
)
def __setitem__(self, key: str, value: str) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.set(key, value)` to create an updated copy."
)
class Headers(typing.MutableMapping[str, str]):
"""

View File

@ -1,5 +1,4 @@
import codecs
import collections
import logging
import mimetypes
import netrc
@ -369,31 +368,6 @@ def peek_filelike_length(stream: typing.IO) -> int:
return os.fstat(fd).st_size
def flatten_queryparams(
queryparams: typing.Mapping[
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
]
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
"""
Convert a mapping of query params into a flat list of two-tuples
representing each item.
Example:
>>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
[("q", "httpx), ("tag", "python"), ("tag", "dev")]
"""
items = []
for k, v in queryparams.items():
if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
for u in v:
items.append((k, u))
else:
items.append((k, typing.cast("PrimitiveData", v)))
return items
class Timer:
async def _get_time(self) -> float:
library = sniffio.current_async_library()

View File

@ -76,19 +76,50 @@ def test_queryparam_types():
assert str(q) == "a=1&a=2"
def test_queryparam_setters():
q = httpx.QueryParams({"a": 1})
q.update([])
def test_queryparam_update_is_hard_deprecated():
q = httpx.QueryParams("a=123")
with pytest.raises(RuntimeError):
q.update({"a": "456"})
assert str(q) == "a=1"
q = httpx.QueryParams([("a", 1), ("a", 2)])
q["a"] = "3"
assert str(q) == "a=3"
def test_queryparam_setter_is_hard_deprecated():
q = httpx.QueryParams("a=123")
with pytest.raises(RuntimeError):
q["a"] = "456"
q = httpx.QueryParams([("a", 1), ("b", 1)])
u = httpx.QueryParams([("b", 2), ("b", 3)])
q.update(u)
assert str(q) == "a=1&b=2&b=3"
assert q["b"] == u["b"]
def test_queryparam_set():
q = httpx.QueryParams("a=123")
q = q.set("a", "456")
assert q == httpx.QueryParams("a=456")
def test_queryparam_add():
q = httpx.QueryParams("a=123")
q = q.add("a", "456")
assert q == httpx.QueryParams("a=123&a=456")
def test_queryparam_remove():
q = httpx.QueryParams("a=123")
q = q.remove("a")
assert q == httpx.QueryParams("")
def test_queryparam_merge():
q = httpx.QueryParams("a=123")
q = q.merge({"b": "456"})
assert q == httpx.QueryParams("a=123&b=456")
q = q.merge({"a": "000", "c": "789"})
assert q == httpx.QueryParams("a=000&b=456&c=789")
def test_queryparams_are_hashable():
params = (
httpx.QueryParams("a=123"),
httpx.QueryParams({"a": 123}),
httpx.QueryParams("b=456"),
httpx.QueryParams({"b": 456}),
)
assert len(set(params)) == 2