Expand URL interface (#1601)

* Expand URL interface

* Add URL query param manipulation methods
This commit is contained in:
Tom Christie 2021-04-27 09:01:14 +01:00 committed by GitHub
parent 2abb2f214a
commit e67b0dd15b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 18 deletions

View File

@ -112,7 +112,7 @@ class URL:
"""
def __init__(
self, url: typing.Union["URL", str, RawURL] = "", params: QueryParamTypes = None
self, url: typing.Union["URL", str, RawURL] = "", **kwargs: typing.Any
) -> None:
if isinstance(url, (str, tuple)):
if isinstance(url, tuple):
@ -144,14 +144,8 @@ class URL:
f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}"
)
# Add any query parameters, merging with any in the URL if needed.
if params:
if self._uri_reference.query:
url_params = QueryParams(self._uri_reference.query).merge(params)
query_string = str(url_params)
else:
query_string = str(QueryParams(params))
self._uri_reference = self._uri_reference.copy_with(query=query_string)
if kwargs:
self._uri_reference = self.copy_with(**kwargs)._uri_reference
@property
def scheme(self) -> str:
@ -293,12 +287,27 @@ class URL:
def query(self) -> bytes:
"""
The URL query string, as raw bytes, excluding the leading b"?".
Note that URL decoding can only be applied on URL query strings
at the point of decoding the individual parameter names/values.
This is neccessarily a bytewise interface, because we cannot
perform URL decoding of this representation until we've parsed
the keys and values into a QueryParams instance.
For example:
url = httpx.URL("https://example.com/?filter=some%20search%20terms")
assert url.query == b"filter=some%20search%20terms"
"""
query = self._uri_reference.query or ""
return query.encode("ascii")
@property
def params(self) -> "QueryParams":
"""
The URL query parameters, neatly parsed and packaged into an immutable
multidict representation.
"""
return QueryParams(self._uri_reference.query)
@property
def raw_path(self) -> bytes:
"""
@ -382,6 +391,7 @@ class URL:
"query": bytes,
"raw_path": bytes,
"fragment": str,
"params": object,
}
for key, value in kwargs.items():
if key not in allowed:
@ -434,12 +444,28 @@ class URL:
if kwargs.get("path") is not None:
kwargs["path"] = quote(kwargs["path"])
# Ensure query=<str> for rfc3986
if kwargs.get("query") is not None:
# Ensure query=<str> for rfc3986
kwargs["query"] = kwargs["query"].decode("ascii")
if "params" in kwargs:
params = kwargs.pop("params")
kwargs["query"] = None if not params else str(QueryParams(params))
return URL(self._uri_reference.copy_with(**kwargs).unsplit())
def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
return self.copy_with(params=self.params.set(key, value))
def copy_add_param(self, key: str, value: typing.Any = None) -> "URL":
return self.copy_with(params=self.params.add(key, value))
def copy_remove_param(self, key: str) -> "URL":
return self.copy_with(params=self.params.remove(key))
def copy_merge_params(self, params: QueryParamTypes) -> "URL":
return self.copy_with(params=self.params.merge(params))
def join(self, url: URLTypes) -> "URL":
"""
Return an absolute URL, using this URL as the base.
@ -595,7 +621,7 @@ class QueryParams(typing.Mapping[str, str]):
return self._dict[str(key)][0]
return default
def get_list(self, key: typing.Any) -> typing.List[str]:
def get_list(self, key: str) -> typing.List[str]:
"""
Get all values from the query param for a given key.
@ -606,7 +632,7 @@ class QueryParams(typing.Mapping[str, str]):
"""
return list(self._dict.get(str(key), []))
def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
def set(self, key: str, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting the value of a key.
@ -621,7 +647,7 @@ class QueryParams(typing.Mapping[str, str]):
q._dict[str(key)] = [primitive_value_to_str(value)]
return q
def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
def add(self, key: str, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting or appending the value of a key.
@ -636,7 +662,7 @@ class QueryParams(typing.Mapping[str, str]):
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q
def remove(self, key: typing.Any) -> "QueryParams":
def remove(self, key: str) -> "QueryParams":
"""
Return a new QueryParams instance, removing the value of a key.
@ -681,6 +707,9 @@ class QueryParams(typing.Mapping[str, str]):
def __len__(self) -> int:
return len(self._dict)
def __bool__(self) -> bool:
return bool(self._dict)
def __hash__(self) -> int:
return hash(str(self))
@ -971,7 +1000,9 @@ class Request:
self.method = method.decode("ascii").upper()
else:
self.method = method.upper()
self.url = URL(url, params=params)
self.url = URL(url)
if params is not None:
self.url = self.url.copy_merge_params(params=params)
self.headers = Headers(headers)
if cookies:
Cookies(cookies).set_cookie_header(self)

View File

@ -100,11 +100,13 @@ def test_url_eq_str():
def test_url_params():
url = httpx.URL("https://example.org:123/path/to/somewhere", params={"a": "123"})
assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
assert url.params == httpx.QueryParams({"a": "123"})
url = httpx.URL(
"https://example.org:123/path/to/somewhere?b=456", params={"a": "123"}
)
assert str(url) == "https://example.org:123/path/to/somewhere?b=456&a=123"
assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
assert url.params == httpx.QueryParams({"a": "123"})
def test_url_join():
@ -122,6 +124,38 @@ def test_url_join():
assert url.join("../../somewhere-else") == "https://example.org:123/somewhere-else"
def test_url_set_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_set_param("a", "456") == "https://example.org:123/?a=456"
def test_url_add_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_add_param("a", "456") == "https://example.org:123/?a=123&a=456"
def test_url_remove_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_remove_param("a") == "https://example.org:123/"
def test_url_merge_params_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_merge_params({"b": "456"}) == "https://example.org:123/?a=123&b=456"
def test_relative_url_join():
url = httpx.URL("/path/to/somewhere")
assert url.join("/somewhere-else") == "/somewhere-else"