Add support for Mount API (#1362)
* Add support for Mount API * Add test cases * Add test case for all: mounted transport * Use 'transport' variable, in preference to 'proxy' * Add docs for mounted transports
This commit is contained in:
parent
2961f267fd
commit
c4d2e6fa28
@ -1040,12 +1040,13 @@ class HelloWorldTransport(httpcore.SyncHTTPTransport):
|
||||
A mock transport that always returns a JSON "Hello, world!" response.
|
||||
"""
|
||||
|
||||
def request(self, method, url, headers=None, stream=None, timeout=None):
|
||||
def request(self, method, url, headers=None, stream=None, ext=None):
|
||||
message = {"text": "Hello, world!"}
|
||||
content = json.dumps(message).encode("utf-8")
|
||||
stream = httpcore.PlainByteStream(content)
|
||||
headers = [(b"content-type", b"application/json")]
|
||||
return b"HTTP/1.1", 200, b"OK", headers, stream
|
||||
ext = {"http_version": b"HTTP/1.1"}
|
||||
return 200, headers, stream, ext
|
||||
```
|
||||
|
||||
Which we can use in the same way:
|
||||
@ -1057,3 +1058,54 @@ Which we can use in the same way:
|
||||
>>> response.json()
|
||||
{"text": "Hello, world!"}
|
||||
```
|
||||
|
||||
### Mounting transports
|
||||
|
||||
You can also mount transports against given schemes or domains, to control
|
||||
which transport an outgoing request should be routed via, with [the same style
|
||||
used for specifying proxy routing](#routing).
|
||||
|
||||
```python
|
||||
import httpcore
|
||||
import httpx
|
||||
|
||||
class HTTPSRedirectTransport(httpcore.SyncHTTPTransport):
|
||||
"""
|
||||
A transport that always redirects to HTTPS.
|
||||
"""
|
||||
|
||||
def request(self, method, url, headers=None, stream=None, ext=None):
|
||||
scheme, host, port, path = url
|
||||
if port is None:
|
||||
location = b"https://%s%s" % (host, path)
|
||||
else:
|
||||
location = b"https://%s:%d%s" % (host, port, path)
|
||||
stream = httpcore.PlainByteStream(b"")
|
||||
headers = [(b"location", location)]
|
||||
ext = {"http_version": b"HTTP/1.1"}
|
||||
return 303, headers, stream, ext
|
||||
|
||||
|
||||
# A client where any `http` requests are always redirected to `https`
|
||||
mounts = {'http://': HTTPSRedirectTransport()}
|
||||
client = httpx.Client(mounts=mounts)
|
||||
```
|
||||
|
||||
A couple of other sketches of how you might take advantage of mounted transports...
|
||||
|
||||
Mocking requests to a given domain:
|
||||
|
||||
```python
|
||||
# All requests to "example.org" should be mocked out.
|
||||
# Other requests occur as usual.
|
||||
mounts = {"all://example.org": MockTransport()}
|
||||
client = httpx.Client(mounts=mounts)
|
||||
```
|
||||
|
||||
Adding support for custom schemes:
|
||||
|
||||
```python
|
||||
# Support URLs like "file:///Users/sylvia_green/websites/new_client/index.html"
|
||||
mounts = {"file://": FileSystemTransport()}
|
||||
client = httpx.Client(mounts=mounts)
|
||||
```
|
||||
|
||||
@ -87,7 +87,7 @@ class BaseClient:
|
||||
cookies: CookieTypes = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
|
||||
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
|
||||
base_url: URLTypes = "",
|
||||
trust_env: bool = True,
|
||||
):
|
||||
@ -561,11 +561,12 @@ class Client(BaseClient):
|
||||
cert: CertTypes = None,
|
||||
http2: bool = False,
|
||||
proxies: ProxiesTypes = None,
|
||||
mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: Limits = DEFAULT_LIMITS,
|
||||
pool_limits: Limits = None,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
|
||||
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
|
||||
base_url: URLTypes = "",
|
||||
transport: httpcore.SyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
@ -611,7 +612,7 @@ class Client(BaseClient):
|
||||
app=app,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
self._proxies: typing.Dict[
|
||||
self._mounts: typing.Dict[
|
||||
URLPattern, typing.Optional[httpcore.SyncHTTPTransport]
|
||||
] = {
|
||||
URLPattern(key): None
|
||||
@ -626,7 +627,12 @@ class Client(BaseClient):
|
||||
)
|
||||
for key, proxy in proxy_map.items()
|
||||
}
|
||||
self._proxies = dict(sorted(self._proxies.items()))
|
||||
if mounts is not None:
|
||||
self._mounts.update(
|
||||
{URLPattern(key): transport for key, transport in mounts.items()}
|
||||
)
|
||||
|
||||
self._mounts = dict(sorted(self._mounts.items()))
|
||||
|
||||
def _init_transport(
|
||||
self,
|
||||
@ -681,7 +687,7 @@ class Client(BaseClient):
|
||||
Returns the transport instance that should be used for a given URL.
|
||||
This will either be the standard connection pool, or a proxy.
|
||||
"""
|
||||
for pattern, transport in self._proxies.items():
|
||||
for pattern, transport in self._mounts.items():
|
||||
if pattern.matches(url):
|
||||
return self._transport if transport is None else transport
|
||||
|
||||
@ -1109,17 +1115,17 @@ class Client(BaseClient):
|
||||
self._state = ClientState.CLOSED
|
||||
|
||||
self._transport.close()
|
||||
for proxy in self._proxies.values():
|
||||
if proxy is not None:
|
||||
proxy.close()
|
||||
for transport in self._mounts.values():
|
||||
if transport is not None:
|
||||
transport.close()
|
||||
|
||||
def __enter__(self: T) -> T:
|
||||
self._state = ClientState.OPENED
|
||||
|
||||
self._transport.__enter__()
|
||||
for proxy in self._proxies.values():
|
||||
if proxy is not None:
|
||||
proxy.__enter__()
|
||||
for transport in self._mounts.values():
|
||||
if transport is not None:
|
||||
transport.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
@ -1131,9 +1137,9 @@ class Client(BaseClient):
|
||||
self._state = ClientState.CLOSED
|
||||
|
||||
self._transport.__exit__(exc_type, exc_value, traceback)
|
||||
for proxy in self._proxies.values():
|
||||
if proxy is not None:
|
||||
proxy.__exit__(exc_type, exc_value, traceback)
|
||||
for transport in self._mounts.values():
|
||||
if transport is not None:
|
||||
transport.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
@ -1198,11 +1204,12 @@ class AsyncClient(BaseClient):
|
||||
cert: CertTypes = None,
|
||||
http2: bool = False,
|
||||
proxies: ProxiesTypes = None,
|
||||
mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: Limits = DEFAULT_LIMITS,
|
||||
pool_limits: Limits = None,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
|
||||
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
|
||||
base_url: URLTypes = "",
|
||||
transport: httpcore.AsyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
@ -1249,7 +1256,7 @@ class AsyncClient(BaseClient):
|
||||
trust_env=trust_env,
|
||||
)
|
||||
|
||||
self._proxies: typing.Dict[
|
||||
self._mounts: typing.Dict[
|
||||
URLPattern, typing.Optional[httpcore.AsyncHTTPTransport]
|
||||
] = {
|
||||
URLPattern(key): None
|
||||
@ -1264,7 +1271,11 @@ class AsyncClient(BaseClient):
|
||||
)
|
||||
for key, proxy in proxy_map.items()
|
||||
}
|
||||
self._proxies = dict(sorted(self._proxies.items()))
|
||||
if mounts is not None:
|
||||
self._mounts.update(
|
||||
{URLPattern(key): transport for key, transport in mounts.items()}
|
||||
)
|
||||
self._mounts = dict(sorted(self._mounts.items()))
|
||||
|
||||
def _init_transport(
|
||||
self,
|
||||
@ -1319,7 +1330,7 @@ class AsyncClient(BaseClient):
|
||||
Returns the transport instance that should be used for a given URL.
|
||||
This will either be the standard connection pool, or a proxy.
|
||||
"""
|
||||
for pattern, transport in self._proxies.items():
|
||||
for pattern, transport in self._mounts.items():
|
||||
if pattern.matches(url):
|
||||
return self._transport if transport is None else transport
|
||||
|
||||
@ -1499,7 +1510,7 @@ class AsyncClient(BaseClient):
|
||||
await timer.async_start()
|
||||
|
||||
with map_exceptions(HTTPCORE_EXC_MAP, request=request):
|
||||
(status_code, headers, stream, ext,) = await transport.arequest(
|
||||
(status_code, headers, stream, ext) = await transport.arequest(
|
||||
request.method.encode(),
|
||||
request.url.raw,
|
||||
headers=request.headers.raw,
|
||||
@ -1750,7 +1761,7 @@ class AsyncClient(BaseClient):
|
||||
self._state = ClientState.CLOSED
|
||||
|
||||
await self._transport.aclose()
|
||||
for proxy in self._proxies.values():
|
||||
for proxy in self._mounts.values():
|
||||
if proxy is not None:
|
||||
await proxy.aclose()
|
||||
|
||||
@ -1758,7 +1769,7 @@ class AsyncClient(BaseClient):
|
||||
self._state = ClientState.OPENED
|
||||
|
||||
await self._transport.__aenter__()
|
||||
for proxy in self._proxies.values():
|
||||
for proxy in self._mounts.values():
|
||||
if proxy is not None:
|
||||
await proxy.__aenter__()
|
||||
return self
|
||||
@ -1772,7 +1783,7 @@ class AsyncClient(BaseClient):
|
||||
self._state = ClientState.CLOSED
|
||||
|
||||
await self._transport.__aexit__(exc_type, exc_value, traceback)
|
||||
for proxy in self._proxies.values():
|
||||
for proxy in self._mounts.values():
|
||||
if proxy is not None:
|
||||
await proxy.__aexit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
import httpcore
|
||||
@ -188,15 +189,8 @@ async def test_context_managed_transport():
|
||||
await super().__aexit__(*args)
|
||||
self.events.append("transport.__aexit__")
|
||||
|
||||
# Note that we're including 'proxies' here to *also* run through the
|
||||
# proxy context management, although we can't easily test that at the
|
||||
# moment, since we can't add proxies as transport instances.
|
||||
#
|
||||
# Once we have a more generalised Mount API we'll be able to remove this
|
||||
# in favour of ensuring all mounts are context managed, which will
|
||||
# also neccessarily include proxies.
|
||||
transport = Transport()
|
||||
async with httpx.AsyncClient(transport=transport, proxies="http://www.example.com"):
|
||||
async with httpx.AsyncClient(transport=transport):
|
||||
pass
|
||||
|
||||
assert transport.events == [
|
||||
@ -206,6 +200,47 @@ async def test_context_managed_transport():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_context_managed_transport_and_mount():
|
||||
class Transport(httpcore.AsyncHTTPTransport):
|
||||
def __init__(self, name: str):
|
||||
self.name: str = name
|
||||
self.events: typing.List[str] = []
|
||||
|
||||
async def aclose(self):
|
||||
# The base implementation of httpcore.AsyncHTTPTransport just
|
||||
# calls into `.aclose`, so simple transport cases can just override
|
||||
# this method for any cleanup, where more complex cases
|
||||
# might want to additionally override `__aenter__`/`__aexit__`.
|
||||
self.events.append(f"{self.name}.aclose")
|
||||
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
self.events.append(f"{self.name}.__aenter__")
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await super().__aexit__(*args)
|
||||
self.events.append(f"{self.name}.__aexit__")
|
||||
|
||||
transport = Transport(name="transport")
|
||||
mounted = Transport(name="mounted")
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, mounts={"http://www.example.org": mounted}
|
||||
):
|
||||
pass
|
||||
|
||||
assert transport.events == [
|
||||
"transport.__aenter__",
|
||||
"transport.aclose",
|
||||
"transport.__aexit__",
|
||||
]
|
||||
assert mounted.events == [
|
||||
"mounted.__aenter__",
|
||||
"mounted.aclose",
|
||||
"mounted.__aexit__",
|
||||
]
|
||||
|
||||
|
||||
def hello_world(request):
|
||||
return httpx.Response(200, text="Hello, world!")
|
||||
|
||||
@ -242,3 +277,28 @@ async def test_deleting_unclosed_async_client_causes_warning():
|
||||
await client.get("http://example.com")
|
||||
with pytest.warns(UserWarning):
|
||||
del client
|
||||
|
||||
|
||||
def unmounted(request: httpx.Request) -> httpx.Response:
|
||||
data = {"app": "unmounted"}
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
def mounted(request: httpx.Request) -> httpx.Response:
|
||||
data = {"app": "mounted"}
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_mounted_transport():
|
||||
transport = MockTransport(unmounted)
|
||||
mounts = {"custom://": MockTransport(mounted)}
|
||||
|
||||
async with httpx.AsyncClient(transport=transport, mounts=mounts) as client:
|
||||
response = await client.get("https://www.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"app": "unmounted"}
|
||||
|
||||
response = await client.get("custom://www.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"app": "mounted"}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
import httpcore
|
||||
@ -227,15 +228,8 @@ def test_context_managed_transport():
|
||||
super().__exit__(*args)
|
||||
self.events.append("transport.__exit__")
|
||||
|
||||
# Note that we're including 'proxies' here to *also* run through the
|
||||
# proxy context management, although we can't easily test that at the
|
||||
# moment, since we can't add proxies as transport instances.
|
||||
#
|
||||
# Once we have a more generalised Mount API we'll be able to remove this
|
||||
# in favour of ensuring all mounts are context managed, which will
|
||||
# also neccessarily include proxies.
|
||||
transport = Transport()
|
||||
with httpx.Client(transport=transport, proxies="http://www.example.com"):
|
||||
with httpx.Client(transport=transport):
|
||||
pass
|
||||
|
||||
assert transport.events == [
|
||||
@ -245,6 +239,44 @@ def test_context_managed_transport():
|
||||
]
|
||||
|
||||
|
||||
def test_context_managed_transport_and_mount():
|
||||
class Transport(httpcore.SyncHTTPTransport):
|
||||
def __init__(self, name: str):
|
||||
self.name: str = name
|
||||
self.events: typing.List[str] = []
|
||||
|
||||
def close(self):
|
||||
# The base implementation of httpcore.SyncHTTPTransport just
|
||||
# calls into `.close`, so simple transport cases can just override
|
||||
# this method for any cleanup, where more complex cases
|
||||
# might want to additionally override `__enter__`/`__exit__`.
|
||||
self.events.append(f"{self.name}.close")
|
||||
|
||||
def __enter__(self):
|
||||
super().__enter__()
|
||||
self.events.append(f"{self.name}.__enter__")
|
||||
|
||||
def __exit__(self, *args):
|
||||
super().__exit__(*args)
|
||||
self.events.append(f"{self.name}.__exit__")
|
||||
|
||||
transport = Transport(name="transport")
|
||||
mounted = Transport(name="mounted")
|
||||
with httpx.Client(transport=transport, mounts={"http://www.example.org": mounted}):
|
||||
pass
|
||||
|
||||
assert transport.events == [
|
||||
"transport.__enter__",
|
||||
"transport.close",
|
||||
"transport.__exit__",
|
||||
]
|
||||
assert mounted.events == [
|
||||
"mounted.__enter__",
|
||||
"mounted.close",
|
||||
"mounted.__exit__",
|
||||
]
|
||||
|
||||
|
||||
def hello_world(request):
|
||||
return httpx.Response(200, text="Hello, world!")
|
||||
|
||||
@ -300,3 +332,38 @@ def test_raw_client_header():
|
||||
["User-Agent", f"python-httpx/{httpx.__version__}"],
|
||||
["Example-Header", "example-value"],
|
||||
]
|
||||
|
||||
|
||||
def unmounted(request: httpx.Request) -> httpx.Response:
|
||||
data = {"app": "unmounted"}
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
def mounted(request: httpx.Request) -> httpx.Response:
|
||||
data = {"app": "mounted"}
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
def test_mounted_transport():
|
||||
transport = MockTransport(unmounted)
|
||||
mounts = {"custom://": MockTransport(mounted)}
|
||||
|
||||
client = httpx.Client(transport=transport, mounts=mounts)
|
||||
|
||||
response = client.get("https://www.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"app": "unmounted"}
|
||||
|
||||
response = client.get("custom://www.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"app": "mounted"}
|
||||
|
||||
|
||||
def test_all_mounted_transport():
|
||||
mounts = {"all://": MockTransport(mounted)}
|
||||
|
||||
client = httpx.Client(mounts=mounts)
|
||||
|
||||
response = client.get("https://www.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"app": "mounted"}
|
||||
|
||||
@ -41,12 +41,12 @@ def test_proxies_parameter(proxies, expected_proxies):
|
||||
|
||||
for proxy_key, url in expected_proxies:
|
||||
pattern = URLPattern(proxy_key)
|
||||
assert pattern in client._proxies
|
||||
proxy = client._proxies[pattern]
|
||||
assert pattern in client._mounts
|
||||
proxy = client._mounts[pattern]
|
||||
assert isinstance(proxy, httpcore.SyncHTTPProxy)
|
||||
assert proxy.proxy_origin == url_to_origin(url)
|
||||
|
||||
assert len(expected_proxies) == len(client._proxies)
|
||||
assert len(expected_proxies) == len(client._mounts)
|
||||
|
||||
|
||||
PROXY_URL = "http://[::1]"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user