Base URL improvements (#1130)

* URL.join(url=...), not URL.join(relative_url=...)

* Fix URL.join()

* Support no argument 'httpx.URL()' usage

* Support client.base_url as a property

* Resolve base_url joining behaviour

* Fix coverage

* Update _client.py
This commit is contained in:
Tom Christie 2020-08-05 18:56:25 +01:00 committed by GitHub
parent 7279ed4658
commit a3392c6ea7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 13 deletions

View File

@ -66,13 +66,10 @@ class BaseClient:
cookies: CookieTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_url: URLTypes = None,
base_url: URLTypes = "",
trust_env: bool = True,
):
if base_url is None:
self.base_url = URL("")
else:
self.base_url = URL(base_url)
self._base_url = self._enforce_trailing_slash(URL(base_url))
self.auth = auth
self._params = QueryParams(params)
@ -87,6 +84,11 @@ class BaseClient:
def trust_env(self) -> bool:
return self._trust_env
def _enforce_trailing_slash(self, url: URL) -> URL:
if url.path.endswith("/"):
return url
return url.copy_with(path=url.path + "/")
def _get_proxy_map(
self, proxies: typing.Optional[ProxiesTypes], allow_env_proxies: bool,
) -> typing.Dict[str, typing.Optional[Proxy]]:
@ -107,6 +109,17 @@ class BaseClient:
proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies
return {"all": proxy}
@property
def base_url(self) -> URL:
"""
Base URL to use when sending requests with relative URLs.
"""
return self._base_url
@base_url.setter
def base_url(self, url: URLTypes) -> None:
self._base_url = self._enforce_trailing_slash(URL(url))
@property
def headers(self) -> Headers:
"""
@ -208,7 +221,13 @@ class BaseClient:
Merge a URL argument together with any 'base_url' on the client,
to create the URL used for the outgoing request.
"""
return self.base_url.join(url)
merge_url = URL(url)
if merge_url.is_relative_url:
# We always ensure the base_url paths include the trailing '/',
# and always strip any leading '/' from the merge URL.
merge_url = merge_url.copy_with(path=merge_url.path.lstrip("/"))
return self.base_url.join(merge_url)
return merge_url
def _merge_cookies(
self, cookies: CookieTypes = None
@ -441,7 +460,7 @@ class Client(BaseClient):
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_url: URLTypes = None,
base_url: URLTypes = "",
transport: httpcore.SyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,
@ -972,7 +991,7 @@ class AsyncClient(BaseClient):
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_url: URLTypes = None,
base_url: URLTypes = "",
transport: httpcore.AsyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,

View File

@ -55,7 +55,7 @@ from ._utils import (
class URL:
def __init__(self, url: URLTypes, params: QueryParamTypes = None) -> None:
def __init__(self, url: URLTypes = "", params: QueryParamTypes = None) -> None:
if isinstance(url, str):
self._uri_reference = rfc3986.api.iri_reference(url).encode()
else:

View File

@ -174,13 +174,31 @@ def test_base_url(server):
assert response.url == base_url
def test_merge_url():
def test_merge_absolute_url():
client = httpx.Client(base_url="https://www.example.com/")
request = client.build_request("GET", "http://www.example.com")
assert request.url.scheme == "http"
request = client.build_request("GET", "http://www.example.com/")
assert request.url == httpx.URL("http://www.example.com/")
assert not request.url.is_ssl
def test_merge_relative_url():
client = httpx.Client(base_url="https://www.example.com/")
request = client.build_request("GET", "/testing/123")
assert request.url == httpx.URL("https://www.example.com/testing/123")
def test_merge_relative_url_with_path():
client = httpx.Client(base_url="https://www.example.com/some/path")
request = client.build_request("GET", "/testing/123")
assert request.url == httpx.URL("https://www.example.com/some/path/testing/123")
def test_merge_relative_url_with_dotted_path():
client = httpx.Client(base_url="https://www.example.com/some/path")
request = client.build_request("GET", "../testing/123")
assert request.url == httpx.URL("https://www.example.com/some/testing/123")
def test_pool_limits_deprecated():
limits = httpx.Limits()

View File

@ -1,4 +1,25 @@
from httpx import AsyncClient, Cookies, Headers
from httpx import URL, AsyncClient, Cookies, Headers
def test_client_base_url():
client = AsyncClient()
client.base_url = "https://www.example.org/" # type: ignore
assert isinstance(client.base_url, URL)
assert client.base_url == URL("https://www.example.org/")
def test_client_base_url_without_trailing_slash():
client = AsyncClient()
client.base_url = "https://www.example.org/path" # type: ignore
assert isinstance(client.base_url, URL)
assert client.base_url == URL("https://www.example.org/path/")
def test_client_base_url_with_trailing_slash():
client = AsyncClient()
client.base_url = "https://www.example.org/path/" # type: ignore
assert isinstance(client.base_url, URL)
assert client.base_url == URL("https://www.example.org/path/")
def test_client_headers():