Add support for Mount API (#1362)

* Add support for Mount API

* Add test cases

* Add test case for all: mounted transport

* Use 'transport' variable, in preference to 'proxy'

* Add docs for mounted transports
This commit is contained in:
Tom Christie 2020-11-24 10:35:51 +00:00 committed by GitHub
parent 2961f267fd
commit c4d2e6fa28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 233 additions and 43 deletions

View File

@ -1040,12 +1040,13 @@ class HelloWorldTransport(httpcore.SyncHTTPTransport):
A mock transport that always returns a JSON "Hello, world!" response.
"""
def request(self, method, url, headers=None, stream=None, timeout=None):
def request(self, method, url, headers=None, stream=None, ext=None):
message = {"text": "Hello, world!"}
content = json.dumps(message).encode("utf-8")
stream = httpcore.PlainByteStream(content)
headers = [(b"content-type", b"application/json")]
return b"HTTP/1.1", 200, b"OK", headers, stream
ext = {"http_version": b"HTTP/1.1"}
return 200, headers, stream, ext
```
Which we can use in the same way:
@ -1057,3 +1058,54 @@ Which we can use in the same way:
>>> response.json()
{"text": "Hello, world!"}
```
### Mounting transports
You can also mount transports against given schemes or domains, to control
which transport an outgoing request should be routed via, with [the same style
used for specifying proxy routing](#routing).
```python
import httpcore
import httpx
class HTTPSRedirectTransport(httpcore.SyncHTTPTransport):
"""
A transport that always redirects to HTTPS.
"""
def request(self, method, url, headers=None, stream=None, ext=None):
scheme, host, port, path = url
if port is None:
location = b"https://%s%s" % (host, path)
else:
location = b"https://%s:%d%s" % (host, port, path)
stream = httpcore.PlainByteStream(b"")
headers = [(b"location", location)]
ext = {"http_version": b"HTTP/1.1"}
return 303, headers, stream, ext
# A client where any `http` requests are always redirected to `https`
mounts = {'http://': HTTPSRedirectTransport()}
client = httpx.Client(mounts=mounts)
```
A couple of other sketches of how you might take advantage of mounted transports...
Mocking requests to a given domain:
```python
# All requests to "example.org" should be mocked out.
# Other requests occur as usual.
mounts = {"all://example.org": MockTransport()}
client = httpx.Client(mounts=mounts)
```
Adding support for custom schemes:
```python
# Support URLs like "file:///Users/sylvia_green/websites/new_client/index.html"
mounts = {"file://": FileSystemTransport()}
client = httpx.Client(mounts=mounts)
```

View File

@ -87,7 +87,7 @@ class BaseClient:
cookies: CookieTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
trust_env: bool = True,
):
@ -561,11 +561,12 @@ class Client(BaseClient):
cert: CertTypes = None,
http2: bool = False,
proxies: ProxiesTypes = None,
mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
transport: httpcore.SyncHTTPTransport = None,
app: typing.Callable = None,
@ -611,7 +612,7 @@ class Client(BaseClient):
app=app,
trust_env=trust_env,
)
self._proxies: typing.Dict[
self._mounts: typing.Dict[
URLPattern, typing.Optional[httpcore.SyncHTTPTransport]
] = {
URLPattern(key): None
@ -626,7 +627,12 @@ class Client(BaseClient):
)
for key, proxy in proxy_map.items()
}
self._proxies = dict(sorted(self._proxies.items()))
if mounts is not None:
self._mounts.update(
{URLPattern(key): transport for key, transport in mounts.items()}
)
self._mounts = dict(sorted(self._mounts.items()))
def _init_transport(
self,
@ -681,7 +687,7 @@ class Client(BaseClient):
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
for pattern, transport in self._proxies.items():
for pattern, transport in self._mounts.items():
if pattern.matches(url):
return self._transport if transport is None else transport
@ -1109,17 +1115,17 @@ class Client(BaseClient):
self._state = ClientState.CLOSED
self._transport.close()
for proxy in self._proxies.values():
if proxy is not None:
proxy.close()
for transport in self._mounts.values():
if transport is not None:
transport.close()
def __enter__(self: T) -> T:
self._state = ClientState.OPENED
self._transport.__enter__()
for proxy in self._proxies.values():
if proxy is not None:
proxy.__enter__()
for transport in self._mounts.values():
if transport is not None:
transport.__enter__()
return self
def __exit__(
@ -1131,9 +1137,9 @@ class Client(BaseClient):
self._state = ClientState.CLOSED
self._transport.__exit__(exc_type, exc_value, traceback)
for proxy in self._proxies.values():
if proxy is not None:
proxy.__exit__(exc_type, exc_value, traceback)
for transport in self._mounts.values():
if transport is not None:
transport.__exit__(exc_type, exc_value, traceback)
def __del__(self) -> None:
self.close()
@ -1198,11 +1204,12 @@ class AsyncClient(BaseClient):
cert: CertTypes = None,
http2: bool = False,
proxies: ProxiesTypes = None,
mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
transport: httpcore.AsyncHTTPTransport = None,
app: typing.Callable = None,
@ -1249,7 +1256,7 @@ class AsyncClient(BaseClient):
trust_env=trust_env,
)
self._proxies: typing.Dict[
self._mounts: typing.Dict[
URLPattern, typing.Optional[httpcore.AsyncHTTPTransport]
] = {
URLPattern(key): None
@ -1264,7 +1271,11 @@ class AsyncClient(BaseClient):
)
for key, proxy in proxy_map.items()
}
self._proxies = dict(sorted(self._proxies.items()))
if mounts is not None:
self._mounts.update(
{URLPattern(key): transport for key, transport in mounts.items()}
)
self._mounts = dict(sorted(self._mounts.items()))
def _init_transport(
self,
@ -1319,7 +1330,7 @@ class AsyncClient(BaseClient):
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
for pattern, transport in self._proxies.items():
for pattern, transport in self._mounts.items():
if pattern.matches(url):
return self._transport if transport is None else transport
@ -1499,7 +1510,7 @@ class AsyncClient(BaseClient):
await timer.async_start()
with map_exceptions(HTTPCORE_EXC_MAP, request=request):
(status_code, headers, stream, ext,) = await transport.arequest(
(status_code, headers, stream, ext) = await transport.arequest(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
@ -1750,7 +1761,7 @@ class AsyncClient(BaseClient):
self._state = ClientState.CLOSED
await self._transport.aclose()
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.aclose()
@ -1758,7 +1769,7 @@ class AsyncClient(BaseClient):
self._state = ClientState.OPENED
await self._transport.__aenter__()
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.__aenter__()
return self
@ -1772,7 +1783,7 @@ class AsyncClient(BaseClient):
self._state = ClientState.CLOSED
await self._transport.__aexit__(exc_type, exc_value, traceback)
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.__aexit__(exc_type, exc_value, traceback)

View File

@ -1,3 +1,4 @@
import typing
from datetime import timedelta
import httpcore
@ -188,15 +189,8 @@ async def test_context_managed_transport():
await super().__aexit__(*args)
self.events.append("transport.__aexit__")
# Note that we're including 'proxies' here to *also* run through the
# proxy context management, although we can't easily test that at the
# moment, since we can't add proxies as transport instances.
#
# Once we have a more generalised Mount API we'll be able to remove this
# in favour of ensuring all mounts are context managed, which will
# also neccessarily include proxies.
transport = Transport()
async with httpx.AsyncClient(transport=transport, proxies="http://www.example.com"):
async with httpx.AsyncClient(transport=transport):
pass
assert transport.events == [
@ -206,6 +200,47 @@ async def test_context_managed_transport():
]
@pytest.mark.usefixtures("async_environment")
async def test_context_managed_transport_and_mount():
class Transport(httpcore.AsyncHTTPTransport):
def __init__(self, name: str):
self.name: str = name
self.events: typing.List[str] = []
async def aclose(self):
# The base implementation of httpcore.AsyncHTTPTransport just
# calls into `.aclose`, so simple transport cases can just override
# this method for any cleanup, where more complex cases
# might want to additionally override `__aenter__`/`__aexit__`.
self.events.append(f"{self.name}.aclose")
async def __aenter__(self):
await super().__aenter__()
self.events.append(f"{self.name}.__aenter__")
async def __aexit__(self, *args):
await super().__aexit__(*args)
self.events.append(f"{self.name}.__aexit__")
transport = Transport(name="transport")
mounted = Transport(name="mounted")
async with httpx.AsyncClient(
transport=transport, mounts={"http://www.example.org": mounted}
):
pass
assert transport.events == [
"transport.__aenter__",
"transport.aclose",
"transport.__aexit__",
]
assert mounted.events == [
"mounted.__aenter__",
"mounted.aclose",
"mounted.__aexit__",
]
def hello_world(request):
return httpx.Response(200, text="Hello, world!")
@ -242,3 +277,28 @@ async def test_deleting_unclosed_async_client_causes_warning():
await client.get("http://example.com")
with pytest.warns(UserWarning):
del client
def unmounted(request: httpx.Request) -> httpx.Response:
data = {"app": "unmounted"}
return httpx.Response(200, json=data)
def mounted(request: httpx.Request) -> httpx.Response:
data = {"app": "mounted"}
return httpx.Response(200, json=data)
@pytest.mark.usefixtures("async_environment")
async def test_mounted_transport():
transport = MockTransport(unmounted)
mounts = {"custom://": MockTransport(mounted)}
async with httpx.AsyncClient(transport=transport, mounts=mounts) as client:
response = await client.get("https://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "unmounted"}
response = await client.get("custom://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "mounted"}

View File

@ -1,3 +1,4 @@
import typing
from datetime import timedelta
import httpcore
@ -227,15 +228,8 @@ def test_context_managed_transport():
super().__exit__(*args)
self.events.append("transport.__exit__")
# Note that we're including 'proxies' here to *also* run through the
# proxy context management, although we can't easily test that at the
# moment, since we can't add proxies as transport instances.
#
# Once we have a more generalised Mount API we'll be able to remove this
# in favour of ensuring all mounts are context managed, which will
# also neccessarily include proxies.
transport = Transport()
with httpx.Client(transport=transport, proxies="http://www.example.com"):
with httpx.Client(transport=transport):
pass
assert transport.events == [
@ -245,6 +239,44 @@ def test_context_managed_transport():
]
def test_context_managed_transport_and_mount():
class Transport(httpcore.SyncHTTPTransport):
def __init__(self, name: str):
self.name: str = name
self.events: typing.List[str] = []
def close(self):
# The base implementation of httpcore.SyncHTTPTransport just
# calls into `.close`, so simple transport cases can just override
# this method for any cleanup, where more complex cases
# might want to additionally override `__enter__`/`__exit__`.
self.events.append(f"{self.name}.close")
def __enter__(self):
super().__enter__()
self.events.append(f"{self.name}.__enter__")
def __exit__(self, *args):
super().__exit__(*args)
self.events.append(f"{self.name}.__exit__")
transport = Transport(name="transport")
mounted = Transport(name="mounted")
with httpx.Client(transport=transport, mounts={"http://www.example.org": mounted}):
pass
assert transport.events == [
"transport.__enter__",
"transport.close",
"transport.__exit__",
]
assert mounted.events == [
"mounted.__enter__",
"mounted.close",
"mounted.__exit__",
]
def hello_world(request):
return httpx.Response(200, text="Hello, world!")
@ -300,3 +332,38 @@ def test_raw_client_header():
["User-Agent", f"python-httpx/{httpx.__version__}"],
["Example-Header", "example-value"],
]
def unmounted(request: httpx.Request) -> httpx.Response:
data = {"app": "unmounted"}
return httpx.Response(200, json=data)
def mounted(request: httpx.Request) -> httpx.Response:
data = {"app": "mounted"}
return httpx.Response(200, json=data)
def test_mounted_transport():
transport = MockTransport(unmounted)
mounts = {"custom://": MockTransport(mounted)}
client = httpx.Client(transport=transport, mounts=mounts)
response = client.get("https://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "unmounted"}
response = client.get("custom://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "mounted"}
def test_all_mounted_transport():
mounts = {"all://": MockTransport(mounted)}
client = httpx.Client(mounts=mounts)
response = client.get("https://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "mounted"}

View File

@ -41,12 +41,12 @@ def test_proxies_parameter(proxies, expected_proxies):
for proxy_key, url in expected_proxies:
pattern = URLPattern(proxy_key)
assert pattern in client._proxies
proxy = client._proxies[pattern]
assert pattern in client._mounts
proxy = client._mounts[pattern]
assert isinstance(proxy, httpcore.SyncHTTPProxy)
assert proxy.proxy_origin == url_to_origin(url)
assert len(expected_proxies) == len(client._proxies)
assert len(expected_proxies) == len(client._mounts)
PROXY_URL = "http://[::1]"