Immutable QueryParams (#1600)
* Tweak QueryParams implementation * Immutable QueryParams
This commit is contained in:
parent
8fe32c52de
commit
2abb2f214a
@ -385,7 +385,7 @@ class BaseClient:
|
||||
"""
|
||||
if params or self.params:
|
||||
merged_queryparams = QueryParams(self.params)
|
||||
merged_queryparams.update(params)
|
||||
merged_queryparams = merged_queryparams.merge(params)
|
||||
return merged_queryparams
|
||||
return params
|
||||
|
||||
|
||||
172
httpx/_models.py
172
httpx/_models.py
@ -6,7 +6,7 @@ import typing
|
||||
import urllib.request
|
||||
from collections.abc import MutableMapping
|
||||
from http.cookiejar import Cookie, CookieJar
|
||||
from urllib.parse import parse_qsl, quote, unquote, urlencode
|
||||
from urllib.parse import parse_qs, quote, unquote, urlencode
|
||||
|
||||
import idna
|
||||
import rfc3986
|
||||
@ -48,7 +48,6 @@ from ._types import (
|
||||
URLTypes,
|
||||
)
|
||||
from ._utils import (
|
||||
flatten_queryparams,
|
||||
guess_json_utf,
|
||||
is_known_encoding,
|
||||
normalize_header_key,
|
||||
@ -148,8 +147,7 @@ class URL:
|
||||
# Add any query parameters, merging with any in the URL if needed.
|
||||
if params:
|
||||
if self._uri_reference.query:
|
||||
url_params = QueryParams(self._uri_reference.query)
|
||||
url_params.update(params)
|
||||
url_params = QueryParams(self._uri_reference.query).merge(params)
|
||||
query_string = str(url_params)
|
||||
else:
|
||||
query_string = str(QueryParams(params))
|
||||
@ -450,7 +448,7 @@ class URL:
|
||||
|
||||
url = httpx.URL("https://www.example.com/test")
|
||||
url = url.join("/new/path")
|
||||
assert url == "https://www.example.com/test/new/path"
|
||||
assert url == "https://www.example.com/new/path"
|
||||
"""
|
||||
if self.is_relative_url:
|
||||
# Workaround to handle relative URLs, which otherwise raise
|
||||
@ -504,38 +502,79 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
|
||||
if value is None or isinstance(value, (str, bytes)):
|
||||
value = value.decode("ascii") if isinstance(value, bytes) else value
|
||||
items = parse_qsl(value)
|
||||
self._dict = parse_qs(value)
|
||||
elif isinstance(value, QueryParams):
|
||||
items = value.multi_items()
|
||||
elif isinstance(value, (list, tuple)):
|
||||
items = value
|
||||
self._dict = {k: list(v) for k, v in value._dict.items()}
|
||||
else:
|
||||
items = flatten_queryparams(value)
|
||||
|
||||
self._dict: typing.Dict[str, typing.List[str]] = {}
|
||||
for item in items:
|
||||
k, v = item
|
||||
if str(k) not in self._dict:
|
||||
self._dict[str(k)] = [primitive_value_to_str(v)]
|
||||
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
|
||||
if isinstance(value, (list, tuple)):
|
||||
# Convert list inputs like:
|
||||
# [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
# To a dict representation, like:
|
||||
# {"a": ["123", "456"], "b": ["789"]}
|
||||
for item in value:
|
||||
dict_value.setdefault(item[0], []).append(item[1])
|
||||
else:
|
||||
self._dict[str(k)].append(primitive_value_to_str(v))
|
||||
# Convert dict inputs like:
|
||||
# {"a": "123", "b": ["456", "789"]}
|
||||
# To dict inputs where values are always lists, like:
|
||||
# {"a": ["123"], "b": ["456", "789"]}
|
||||
dict_value = {
|
||||
k: list(v) if isinstance(v, (list, tuple)) else [v]
|
||||
for k, v in value.items()
|
||||
}
|
||||
|
||||
# Ensure that keys and values are neatly coerced to strings.
|
||||
# We coerce values `True` and `False` to JSON-like "true" and "false"
|
||||
# representations, and coerce `None` values to the empty string.
|
||||
self._dict = {
|
||||
str(k): [primitive_value_to_str(item) for item in v]
|
||||
for k, v in dict_value.items()
|
||||
}
|
||||
|
||||
def keys(self) -> typing.KeysView:
|
||||
"""
|
||||
Return all the keys in the query params.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert list(q.keys()) == ["a", "b"]
|
||||
"""
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView:
|
||||
"""
|
||||
Return all the values in the query params. If a key occurs more than once
|
||||
only the first item for that key is returned.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert list(q.values()) == ["123", "789"]
|
||||
"""
|
||||
return {k: v[0] for k, v in self._dict.items()}.values()
|
||||
|
||||
def items(self) -> typing.ItemsView:
|
||||
"""
|
||||
Return all items in the query params. If a key occurs more than once
|
||||
only the first item for that key is returned.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert list(q.items()) == [("a", "123"), ("b", "789")]
|
||||
"""
|
||||
return {k: v[0] for k, v in self._dict.items()}.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
"""
|
||||
Return all items in the query params. Allow duplicate keys to occur.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
"""
|
||||
multi_items: typing.List[typing.Tuple[str, str]] = []
|
||||
for k, v in self._dict.items():
|
||||
@ -546,31 +585,93 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
"""
|
||||
Get a value from the query param for a given key. If the key occurs
|
||||
more than once, then only the first value is returned.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert q.get("a") == "123"
|
||||
"""
|
||||
if key in self._dict:
|
||||
return self._dict[key][0]
|
||||
return self._dict[str(key)][0]
|
||||
return default
|
||||
|
||||
def get_list(self, key: typing.Any) -> typing.List[str]:
|
||||
"""
|
||||
Get all values from the query param for a given key.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123&a=456&b=789")
|
||||
assert q.get_list("a") == ["123", "456"]
|
||||
"""
|
||||
return list(self._dict.get(key, []))
|
||||
return list(self._dict.get(str(key), []))
|
||||
|
||||
def update(self, params: QueryParamTypes = None) -> None:
|
||||
if not params:
|
||||
return
|
||||
def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
|
||||
"""
|
||||
Return a new QueryParams instance, setting the value of a key.
|
||||
|
||||
params = QueryParams(params)
|
||||
for k in params.keys():
|
||||
self._dict[k] = params.get_list(k)
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.set("a", "456")
|
||||
assert q == httpx.QueryParams("a=456")
|
||||
"""
|
||||
q = QueryParams()
|
||||
q._dict = dict(self._dict)
|
||||
q._dict[str(key)] = [primitive_value_to_str(value)]
|
||||
return q
|
||||
|
||||
def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
|
||||
"""
|
||||
Return a new QueryParams instance, setting or appending the value of a key.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.add("a", "456")
|
||||
assert q == httpx.QueryParams("a=123&a=456")
|
||||
"""
|
||||
q = QueryParams()
|
||||
q._dict = dict(self._dict)
|
||||
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
|
||||
return q
|
||||
|
||||
def remove(self, key: typing.Any) -> "QueryParams":
|
||||
"""
|
||||
Return a new QueryParams instance, removing the value of a key.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.remove("a")
|
||||
assert q == httpx.QueryParams("")
|
||||
"""
|
||||
q = QueryParams()
|
||||
q._dict = dict(self._dict)
|
||||
q._dict.pop(str(key), None)
|
||||
return q
|
||||
|
||||
def merge(self, params: QueryParamTypes = None) -> "QueryParams":
|
||||
"""
|
||||
Return a new QueryParams instance, updated with.
|
||||
|
||||
Usage:
|
||||
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.merge({"b": "456"})
|
||||
assert q == httpx.QueryParams("a=123&b=456")
|
||||
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.merge({"a": "456", "b": "789"})
|
||||
assert q == httpx.QueryParams("a=456&b=789")
|
||||
"""
|
||||
q = QueryParams(params)
|
||||
q._dict = {**self._dict, **q._dict}
|
||||
return q
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> str:
|
||||
return self._dict[key][0]
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
self._dict[key] = [value]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
@ -580,6 +681,9 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
@ -593,6 +697,18 @@ class QueryParams(typing.Mapping[str, str]):
|
||||
query_string = str(self)
|
||||
return f"{class_name}({query_string!r})"
|
||||
|
||||
def update(self, params: QueryParamTypes = None) -> None:
|
||||
raise RuntimeError(
|
||||
"QueryParams are immutable since 0.18.0. "
|
||||
"Use `q = q.merge(...)` to create an updated copy."
|
||||
)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
raise RuntimeError(
|
||||
"QueryParams are immutable since 0.18.0. "
|
||||
"Use `q = q.set(key, value)` to create an updated copy."
|
||||
)
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import codecs
|
||||
import collections
|
||||
import logging
|
||||
import mimetypes
|
||||
import netrc
|
||||
@ -369,31 +368,6 @@ def peek_filelike_length(stream: typing.IO) -> int:
|
||||
return os.fstat(fd).st_size
|
||||
|
||||
|
||||
def flatten_queryparams(
|
||||
queryparams: typing.Mapping[
|
||||
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
|
||||
]
|
||||
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
|
||||
"""
|
||||
Convert a mapping of query params into a flat list of two-tuples
|
||||
representing each item.
|
||||
|
||||
Example:
|
||||
>>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
|
||||
[("q", "httpx), ("tag", "python"), ("tag", "dev")]
|
||||
"""
|
||||
items = []
|
||||
|
||||
for k, v in queryparams.items():
|
||||
if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
|
||||
for u in v:
|
||||
items.append((k, u))
|
||||
else:
|
||||
items.append((k, typing.cast("PrimitiveData", v)))
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class Timer:
|
||||
async def _get_time(self) -> float:
|
||||
library = sniffio.current_async_library()
|
||||
|
||||
@ -76,19 +76,50 @@ def test_queryparam_types():
|
||||
assert str(q) == "a=1&a=2"
|
||||
|
||||
|
||||
def test_queryparam_setters():
|
||||
q = httpx.QueryParams({"a": 1})
|
||||
q.update([])
|
||||
def test_queryparam_update_is_hard_deprecated():
|
||||
q = httpx.QueryParams("a=123")
|
||||
with pytest.raises(RuntimeError):
|
||||
q.update({"a": "456"})
|
||||
|
||||
assert str(q) == "a=1"
|
||||
|
||||
q = httpx.QueryParams([("a", 1), ("a", 2)])
|
||||
q["a"] = "3"
|
||||
assert str(q) == "a=3"
|
||||
def test_queryparam_setter_is_hard_deprecated():
|
||||
q = httpx.QueryParams("a=123")
|
||||
with pytest.raises(RuntimeError):
|
||||
q["a"] = "456"
|
||||
|
||||
q = httpx.QueryParams([("a", 1), ("b", 1)])
|
||||
u = httpx.QueryParams([("b", 2), ("b", 3)])
|
||||
q.update(u)
|
||||
|
||||
assert str(q) == "a=1&b=2&b=3"
|
||||
assert q["b"] == u["b"]
|
||||
def test_queryparam_set():
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.set("a", "456")
|
||||
assert q == httpx.QueryParams("a=456")
|
||||
|
||||
|
||||
def test_queryparam_add():
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.add("a", "456")
|
||||
assert q == httpx.QueryParams("a=123&a=456")
|
||||
|
||||
|
||||
def test_queryparam_remove():
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.remove("a")
|
||||
assert q == httpx.QueryParams("")
|
||||
|
||||
|
||||
def test_queryparam_merge():
|
||||
q = httpx.QueryParams("a=123")
|
||||
q = q.merge({"b": "456"})
|
||||
assert q == httpx.QueryParams("a=123&b=456")
|
||||
q = q.merge({"a": "000", "c": "789"})
|
||||
assert q == httpx.QueryParams("a=000&b=456&c=789")
|
||||
|
||||
|
||||
def test_queryparams_are_hashable():
|
||||
params = (
|
||||
httpx.QueryParams("a=123"),
|
||||
httpx.QueryParams({"a": 123}),
|
||||
httpx.QueryParams("b=456"),
|
||||
httpx.QueryParams({"b": 456}),
|
||||
)
|
||||
|
||||
assert len(set(params)) == 2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user