Httpcore interface (#804)

* First pass as switching dispatchers over to httpcore interface

* Updates for httpcore interface

* headers in dispatch API as plain list of bytes

* Integrate against httpcore 0.6

* Integrate against httpcore interface

* Drop UDS, since not supported by httpcore

* Fix base class for mock dispatchers in tests

* Merge master and mark as potential '0.13.dev0' release
This commit is contained in:
Tom Christie 2020-04-08 13:32:10 +01:00 committed by GitHub
parent 631ba97635
commit 3046e920ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 562 additions and 2458 deletions

View File

@ -6,7 +6,6 @@ from ._config import PoolLimits, Proxy, Timeout
from ._dispatch.asgi import ASGIDispatch
from ._dispatch.wsgi import WSGIDispatch
from ._exceptions import (
ConnectionClosed,
ConnectTimeout,
CookieConflict,
DecodingError,
@ -23,7 +22,6 @@ from ._exceptions import (
ResponseClosed,
ResponseNotRead,
StreamConsumed,
TimeoutException,
TooManyRedirects,
WriteTimeout,
)
@ -56,7 +54,6 @@ __all__ = [
"Timeout",
"ConnectTimeout",
"CookieConflict",
"ConnectionClosed",
"DecodingError",
"HTTPError",
"InvalidURL",
@ -79,7 +76,6 @@ __all__ = [
"Headers",
"QueryParams",
"Request",
"TimeoutException",
"Response",
"DigestAuth",
"WSGIDispatch",

View File

@ -1,3 +1,3 @@
__title__ = "httpx"
__description__ = "A next generation HTTP client, for Python 3."
__version__ = "0.12.1"
__version__ = "0.13.dev0"

View File

@ -3,6 +3,7 @@ import typing
from types import TracebackType
import hstspreload
import httpcore
from ._auth import Auth, AuthTypes, BasicAuth, FunctionAuth
from ._config import (
@ -14,6 +15,7 @@ from ._config import (
PoolLimits,
ProxiesTypes,
Proxy,
SSLConfig,
Timeout,
TimeoutTypes,
UnsetType,
@ -21,10 +23,6 @@ from ._config import (
)
from ._content_streams import ContentStream
from ._dispatch.asgi import ASGIDispatch
from ._dispatch.base import AsyncDispatcher, SyncDispatcher
from ._dispatch.connection_pool import ConnectionPool
from ._dispatch.proxy_http import HTTPProxy
from ._dispatch.urllib3 import URLLib3Dispatcher
from ._dispatch.wsgi import WSGIDispatch
from ._exceptions import HTTPError, InvalidURL, RequestBodyUnavailable, TooManyRedirects
from ._models import (
@ -96,7 +94,7 @@ class BaseClient:
elif isinstance(proxies, (str, URL, Proxy)):
proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies
return {"all": proxy}
elif isinstance(proxies, AsyncDispatcher): # pragma: nocover
elif isinstance(proxies, httpcore.AsyncHTTPTransport): # pragma: nocover
raise RuntimeError(
"Passing a dispatcher instance to 'proxies=' is no longer "
"supported. Use `httpx.Proxy() instead.`"
@ -107,7 +105,7 @@ class BaseClient:
if isinstance(value, (str, URL, Proxy)):
proxy = Proxy(url=value) if isinstance(value, (str, URL)) else value
new_proxies[str(key)] = proxy
elif isinstance(value, AsyncDispatcher): # pragma: nocover
elif isinstance(value, httpcore.AsyncHTTPTransport): # pragma: nocover
raise RuntimeError(
"Passing a dispatcher instance to 'proxies=' is "
"no longer supported. Use `httpx.Proxy() instead.`"
@ -446,7 +444,7 @@ class Client(BaseClient):
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_url: URLTypes = None,
dispatch: SyncDispatcher = None,
dispatch: httpcore.SyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,
):
@ -471,7 +469,7 @@ class Client(BaseClient):
app=app,
trust_env=trust_env,
)
self.proxies: typing.Dict[str, SyncDispatcher] = {
self.proxies: typing.Dict[str, httpcore.SyncHTTPTransport] = {
key: self.init_proxy_dispatch(
proxy,
verify=verify,
@ -487,18 +485,26 @@ class Client(BaseClient):
verify: VerifyTypes = True,
cert: CertTypes = None,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
dispatch: SyncDispatcher = None,
dispatch: httpcore.SyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,
) -> SyncDispatcher:
) -> httpcore.SyncHTTPTransport:
if dispatch is not None:
return dispatch
if app is not None:
return WSGIDispatch(app=app)
return URLLib3Dispatcher(
verify=verify, cert=cert, pool_limits=pool_limits, trust_env=trust_env,
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
max_keepalive = pool_limits.soft_limit
max_connections = pool_limits.hard_limit
return httpcore.SyncConnectionPool(
ssl_context=ssl_context,
max_keepalive=max_keepalive,
max_connections=max_connections,
)
def init_proxy_dispatch(
@ -508,18 +514,25 @@ class Client(BaseClient):
cert: CertTypes = None,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
trust_env: bool = True,
) -> SyncDispatcher:
return URLLib3Dispatcher(
proxy=proxy,
verify=verify,
cert=cert,
pool_limits=pool_limits,
trust_env=trust_env,
) -> httpcore.SyncHTTPTransport:
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
max_keepalive = pool_limits.soft_limit
max_connections = pool_limits.hard_limit
return httpcore.SyncHTTPProxy(
proxy_origin=proxy.url.raw[:3],
proxy_headers=proxy.headers.raw,
proxy_mode=proxy.mode,
ssl_context=ssl_context,
max_keepalive=max_keepalive,
max_connections=max_connections,
)
def dispatcher_for_url(self, url: URL) -> SyncDispatcher:
def dispatcher_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
"""
Returns the SyncDispatcher instance that should be used for a given URL.
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
if self.proxies and not should_not_be_proxied(url):
@ -667,7 +680,7 @@ class Client(BaseClient):
request = next_request
history.append(response)
def send_single_request(self, request: Request, timeout: Timeout,) -> Response:
def send_single_request(self, request: Request, timeout: Timeout) -> Response:
"""
Sends a single request, without handling any redirections.
"""
@ -675,7 +688,19 @@ class Client(BaseClient):
dispatcher = self.dispatcher_for_url(request.url)
try:
response = dispatcher.send(request, timeout=timeout)
(
http_version,
status_code,
reason_phrase,
headers,
stream,
) = dispatcher.request(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream,
timeout=timeout.as_dict(),
)
except HTTPError as exc:
# Add the original request to any HTTPError unless
# there'a already a request attached in the case of
@ -683,6 +708,13 @@ class Client(BaseClient):
if exc._request is None:
exc._request = request
raise
response = Response(
status_code,
http_version=http_version.decode("ascii"),
headers=headers,
stream=stream, # type: ignore
request=request,
)
self.cookies.extract_cookies(response)
@ -928,7 +960,6 @@ class AsyncClient(BaseClient):
rather than sending actual network requests.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
"""
def __init__(
@ -946,10 +977,9 @@ class AsyncClient(BaseClient):
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_url: URLTypes = None,
dispatch: AsyncDispatcher = None,
dispatch: httpcore.AsyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,
uds: str = None,
):
super().__init__(
auth=auth,
@ -972,9 +1002,8 @@ class AsyncClient(BaseClient):
dispatch=dispatch,
app=app,
trust_env=trust_env,
uds=uds,
)
self.proxies: typing.Dict[str, AsyncDispatcher] = {
self.proxies: typing.Dict[str, httpcore.AsyncHTTPTransport] = {
key: self.init_proxy_dispatch(
proxy,
verify=verify,
@ -992,24 +1021,27 @@ class AsyncClient(BaseClient):
cert: CertTypes = None,
http2: bool = False,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
dispatch: AsyncDispatcher = None,
dispatch: httpcore.AsyncHTTPTransport = None,
app: typing.Callable = None,
trust_env: bool = True,
uds: str = None,
) -> AsyncDispatcher:
) -> httpcore.AsyncHTTPTransport:
if dispatch is not None:
return dispatch
if app is not None:
return ASGIDispatch(app=app)
return ConnectionPool(
verify=verify,
cert=cert,
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
max_keepalive = pool_limits.soft_limit
max_connections = pool_limits.hard_limit
return httpcore.AsyncConnectionPool(
ssl_context=ssl_context,
max_keepalive=max_keepalive,
max_connections=max_connections,
http2=http2,
pool_limits=pool_limits,
trust_env=trust_env,
uds=uds,
)
def init_proxy_dispatch(
@ -1020,21 +1052,25 @@ class AsyncClient(BaseClient):
http2: bool = False,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
trust_env: bool = True,
) -> AsyncDispatcher:
return HTTPProxy(
proxy_url=proxy.url,
proxy_headers=proxy.headers,
) -> httpcore.AsyncHTTPTransport:
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
max_keepalive = pool_limits.soft_limit
max_connections = pool_limits.hard_limit
return httpcore.AsyncHTTPProxy(
proxy_origin=proxy.url.raw[:3],
proxy_headers=proxy.headers.raw,
proxy_mode=proxy.mode,
verify=verify,
cert=cert,
http2=http2,
pool_limits=pool_limits,
trust_env=trust_env,
ssl_context=ssl_context,
max_keepalive=max_keepalive,
max_connections=max_connections,
)
def dispatcher_for_url(self, url: URL) -> AsyncDispatcher:
def dispatcher_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
"""
Returns the AsyncDispatcher instance that should be used for a given URL.
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
if self.proxies and not should_not_be_proxied(url):
@ -1193,7 +1229,19 @@ class AsyncClient(BaseClient):
dispatcher = self.dispatcher_for_url(request.url)
try:
response = await dispatcher.send(request, timeout=timeout)
(
http_version,
status_code,
reason_phrase,
headers,
stream,
) = await dispatcher.request(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream,
timeout=timeout.as_dict(),
)
except HTTPError as exc:
# Add the original request to any HTTPError unless
# there'a already a request attached in the case of
@ -1201,6 +1249,13 @@ class AsyncClient(BaseClient):
if exc._request is None:
exc._request = request
raise
response = Response(
status_code,
http_version=http_version.decode("ascii"),
headers=headers,
stream=stream, # type: ignore
request=request,
)
self.cookies.extract_cookies(response)
@ -1383,9 +1438,9 @@ class AsyncClient(BaseClient):
)
async def aclose(self) -> None:
await self.dispatch.close()
await self.dispatch.aclose()
for proxy in self.proxies.values():
await proxy.close()
await proxy.aclose()
async def __aenter__(self) -> "AsyncClient":
return self

View File

@ -253,6 +253,14 @@ class Timeout:
timeout if isinstance(pool_timeout, UnsetType) else pool_timeout
)
def as_dict(self) -> typing.Dict[str, typing.Optional[float]]:
return {
"connect": self.connect_timeout,
"read": self.read_timeout,
"write": self.write_timeout,
"pool": self.pool_timeout,
}
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)

View File

@ -7,6 +7,8 @@ from json import dumps as json_dumps
from pathlib import Path
from urllib.parse import urlencode
import httpcore
from ._exceptions import StreamConsumed
from ._types import StrOrBytes
from ._utils import format_form_param
@ -35,7 +37,7 @@ RequestFiles = typing.Dict[
]
class ContentStream:
class ContentStream(httpcore.AsyncByteStream, httpcore.SyncByteStream):
def get_headers(self) -> typing.Dict[str, str]:
"""
Return a dictionary of headers that are implied by the encoding.

View File

@ -1,12 +1,11 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
import httpcore
from .._config import TimeoutTypes
from .._content_streams import ByteStream
from .._models import Request, Response
from .base import AsyncDispatcher
class ASGIDispatch(AsyncDispatcher):
class ASGIDispatch(httpcore.AsyncHTTPTransport):
"""
A custom AsyncDispatcher that handles sending requests directly to an ASGI app.
The simplest way to use this functionality is to use the `app` argument.
@ -41,37 +40,49 @@ class ASGIDispatch(AsyncDispatcher):
def __init__(
self,
app: typing.Callable,
app: Callable,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
client: Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
async def request(
self,
method: bytes,
url: Tuple[bytes, bytes, int, bytes],
headers: List[Tuple[bytes, bytes]] = None,
stream: httpcore.AsyncByteStream = None,
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": request.headers.raw,
"scheme": request.url.scheme,
"path": request.url.path,
"query_string": request.url.query.encode("ascii"),
"server": request.url.host,
"method": method.decode(),
"headers": headers,
"scheme": scheme.decode("ascii"),
"path": path.decode("ascii"),
"query_string": query,
"server": (host.decode("ascii"), port),
"client": self.client,
"root_path": self.root_path,
}
status_code = None
headers = None
response_headers = None
body_parts = []
response_started = False
response_complete = False
request_body_chunks = request.stream.__aiter__()
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
request_body_chunks = stream.__aiter__()
async def receive() -> dict:
try:
@ -81,14 +92,14 @@ class ASGIDispatch(AsyncDispatcher):
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
nonlocal status_code, headers, body_parts
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
if message["type"] == "http.response.start":
assert not response_started
status_code = message["status"]
headers = message.get("headers", [])
response_headers = message.get("headers", [])
response_started = True
elif message["type"] == "http.response.body":
@ -96,7 +107,7 @@ class ASGIDispatch(AsyncDispatcher):
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body and request.method != "HEAD":
if body and method != b"HEAD":
body_parts.append(body)
if not more_body:
@ -110,14 +121,8 @@ class ASGIDispatch(AsyncDispatcher):
assert response_complete
assert status_code is not None
assert headers is not None
assert response_headers is not None
stream = ByteStream(b"".join(body_parts))
return Response(
status_code=status_code,
http_version="HTTP/1.1",
headers=headers,
stream=stream,
request=request,
)
return (b"HTTP/1.1", status_code, b"", response_headers, stream)

View File

@ -1,64 +0,0 @@
import typing
from types import TracebackType
from .._config import Timeout
from .._models import (
HeaderTypes,
QueryParamTypes,
Request,
RequestData,
Response,
URLTypes,
)
class SyncDispatcher:
"""
Base class for Dispatcher classes, that handle sending the request.
"""
def send(self, request: Request, timeout: Timeout = None) -> Response:
raise NotImplementedError() # pragma: nocover
def close(self) -> None:
pass # pragma: nocover
class AsyncDispatcher:
"""
Base class for AsyncDispatcher classes, that handle sending the request.
Stubs out the interface, as well as providing a `.request()` convenience
implementation, to make it easy to use or test stand-alone AsyncDispatchers,
without requiring a complete `AsyncClient` instance.
"""
async def request(
self,
method: str,
url: URLTypes,
*,
data: RequestData = b"",
params: QueryParamTypes = None,
headers: HeaderTypes = None,
timeout: Timeout = None,
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
return await self.send(request, timeout=timeout)
async def send(self, request: Request, timeout: Timeout = None) -> Response:
raise NotImplementedError() # pragma: nocover
async def close(self) -> None:
pass # pragma: nocover
async def __aenter__(self) -> "AsyncDispatcher":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()

View File

@ -1,149 +0,0 @@
import functools
import typing
import h11
from .._backends.base import ConcurrencyBackend, lookup_backend
from .._config import SSLConfig, Timeout
from .._models import URL, Origin, Request, Response
from .._utils import get_logger
from .base import AsyncDispatcher
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
# Callback signature: async def callback(conn: HTTPConnection) -> None
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
logger = get_logger(__name__)
class HTTPConnection(AsyncDispatcher):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = None,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
release_func: typing.Optional[ReleaseCallback] = None,
uds: typing.Optional[str] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig() if ssl is None else ssl
self.backend = lookup_backend(backend)
self.release_func = release_func
self.uds = uds
self.connection: typing.Union[None, HTTP11Connection, HTTP2Connection] = None
self.expires_at: typing.Optional[float] = None
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
if self.connection is None:
self.connection = await self.connect(timeout=timeout)
return await self.connection.send(request, timeout=timeout)
async def connect(
self, timeout: Timeout
) -> typing.Union[HTTP11Connection, HTTP2Connection]:
host = self.origin.host
port = self.origin.port
ssl_context = None if not self.origin.is_ssl else self.ssl.ssl_context
if self.release_func is None:
on_release = None
else:
on_release = functools.partial(self.release_func, self)
if self.uds is None:
logger.trace(
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
)
socket = await self.backend.open_tcp_stream(
host, port, ssl_context, timeout
)
else:
logger.trace(
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
)
socket = await self.backend.open_uds_stream(
self.uds, host, ssl_context, timeout
)
http_version = socket.get_http_version()
logger.trace(f"connected http_version={http_version!r}")
if http_version == "HTTP/2":
return HTTP2Connection(socket, self.backend, on_release=on_release)
return HTTP11Connection(socket, on_release=on_release)
async def tunnel_start_tls(
self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
) -> None:
"""
Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
"""
timeout = Timeout() if timeout is None else timeout
# First, check that we are in the correct state to start TLS, i.e. we've
# just agreed to switch protocols with the server via HTTP/1.1.
assert isinstance(self.connection, HTTP11Connection)
h11_connection = self.connection
assert h11_connection is not None
assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
# Store this information here so that we can transfer
# it to the new internal connection object after
# the old one goes to 'SWITCHED_PROTOCOL'.
# Note that the negotiated 'http_version' may change after the TLS upgrade.
http_version = "HTTP/1.1"
socket = h11_connection.socket
on_release = h11_connection.on_release
if origin.is_ssl:
# Pull the socket stream off the internal HTTP connection object,
# and run start_tls().
ssl_context = self.ssl.ssl_context
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
socket = await socket.start_tls(
hostname=origin.host, ssl_context=ssl_context, timeout=timeout
)
http_version = socket.get_http_version()
logger.trace(
f"tunnel_tls_complete "
f"proxy_url={proxy_url!r} "
f"origin={origin!r} "
f"http_version={http_version!r}"
)
else:
# User requested the use of a tunnel, but they're performing a plain-text
# HTTP request. Don't try to upgrade to TLS in this case.
pass
if http_version == "HTTP/2":
self.connection = HTTP2Connection(
socket, self.backend, on_release=on_release
)
else:
self.connection = HTTP11Connection(socket, on_release=on_release)
async def close(self) -> None:
logger.trace("close_connection")
if self.connection is not None:
await self.connection.close()
@property
def is_http2(self) -> bool:
return self.connection is not None and self.connection.is_http2
@property
def is_closed(self) -> bool:
return self.connection is not None and self.connection.is_closed
def is_connection_dropped(self) -> bool:
return self.connection is not None and self.connection.is_connection_dropped()
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(origin={self.origin!r})"

View File

@ -1,221 +0,0 @@
import typing
from .._backends.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
from .._config import (
DEFAULT_POOL_LIMITS,
CertTypes,
PoolLimits,
SSLConfig,
Timeout,
VerifyTypes,
)
from .._exceptions import PoolTimeout
from .._models import Origin, Request, Response
from .._utils import get_logger
from .base import AsyncDispatcher
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
logger = get_logger(__name__)
class NullSemaphore(BaseSemaphore):
async def acquire(self, timeout: float = None) -> None:
return
def release(self) -> None:
return
class ConnectionStore:
"""
We need to maintain collections of connections in a way that allows us to:
* Lookup connections by origin.
* Iterate over connections by insertion time.
* Return the total number of connections.
"""
def __init__(self) -> None:
self.all: typing.Dict[HTTPConnection, float] = {}
self.by_origin: typing.Dict[Origin, typing.Dict[HTTPConnection, float]] = {}
def pop_by_origin(
self, origin: Origin, http2_only: bool = False
) -> typing.Optional[HTTPConnection]:
try:
connections = self.by_origin[origin]
except KeyError:
return None
connection = next(reversed(list(connections.keys())))
if http2_only and not connection.is_http2:
return None
del connections[connection]
if not connections:
del self.by_origin[origin]
del self.all[connection]
return connection
def add(self, connection: HTTPConnection) -> None:
self.all[connection] = 0.0
try:
self.by_origin[connection.origin][connection] = 0.0
except KeyError:
self.by_origin[connection.origin] = {connection: 0.0}
def remove(self, connection: HTTPConnection) -> None:
del self.all[connection]
del self.by_origin[connection.origin][connection]
if not self.by_origin[connection.origin]:
del self.by_origin[connection.origin]
def clear(self) -> None:
self.all.clear()
self.by_origin.clear()
def __iter__(self) -> typing.Iterator[HTTPConnection]:
return iter(self.all.keys())
def __len__(self) -> int:
return len(self.all)
class ConnectionPool(AsyncDispatcher):
KEEP_ALIVE_EXPIRY = 5.0
def __init__(
self,
*,
verify: VerifyTypes = True,
cert: CertTypes = None,
trust_env: bool = None,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http2: bool = False,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
uds: typing.Optional[str] = None,
):
self.ssl = SSLConfig(verify=verify, cert=cert, trust_env=trust_env, http2=http2)
self.pool_limits = pool_limits
self.is_closed = False
self.uds = uds
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
self.backend = lookup_backend(backend)
self.next_keepalive_check = 0.0
@property
def max_connections(self) -> BaseSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_connections"):
limit = self.pool_limits.hard_limit
if limit:
self._max_connections = self.backend.create_semaphore(
limit, exc_class=PoolTimeout
)
else:
self._max_connections = NullSemaphore()
return self._max_connections
@property
def num_connections(self) -> int:
return len(self.keepalive_connections) + len(self.active_connections)
async def check_keepalive_expiry(self) -> None:
now = self.backend.time()
if now < self.next_keepalive_check:
return
self.next_keepalive_check = now + 1.0
# Iterate through all the keep alive connections.
# We create a list here to avoid any 'changed during iteration' errors.
keepalives = list(self.keepalive_connections.all.keys())
for connection in keepalives:
if connection.expires_at is not None and now > connection.expires_at:
self.keepalive_connections.remove(connection)
self.max_connections.release()
await connection.close()
async def send(self, request: Request, timeout: Timeout = None) -> Response:
await self.check_keepalive_expiry()
connection = await self.acquire_connection(
origin=Origin(request.url), timeout=timeout
)
try:
response = await connection.send(request, timeout=timeout)
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()
raise exc
return response
async def acquire_connection(
self, origin: Origin, timeout: Timeout = None
) -> HTTPConnection:
logger.trace(f"acquire_connection origin={origin!r}")
connection = self.pop_connection(origin)
if connection is None:
pool_timeout = None if timeout is None else timeout.pool_timeout
await self.max_connections.acquire(timeout=pool_timeout)
connection = HTTPConnection(
origin,
ssl=self.ssl,
backend=self.backend,
release_func=self.release_connection,
uds=self.uds,
)
logger.trace(f"new_connection connection={connection!r}")
else:
logger.trace(f"reuse_connection connection={connection!r}")
self.active_connections.add(connection)
return connection
async def release_connection(self, connection: HTTPConnection) -> None:
logger.trace(f"release_connection connection={connection!r}")
if connection.is_closed:
self.active_connections.remove(connection)
self.max_connections.release()
elif (
self.pool_limits.soft_limit is not None
and self.num_connections > self.pool_limits.soft_limit
):
self.active_connections.remove(connection)
self.max_connections.release()
await connection.close()
else:
now = self.backend.time()
connection.expires_at = now + self.KEEP_ALIVE_EXPIRY
self.active_connections.remove(connection)
self.keepalive_connections.add(connection)
async def close(self) -> None:
self.is_closed = True
connections = list(self.keepalive_connections)
self.keepalive_connections.clear()
for connection in connections:
self.max_connections.release()
await connection.close()
def pop_connection(self, origin: Origin) -> typing.Optional[HTTPConnection]:
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)
if connection is not None and connection.is_connection_dropped():
self.max_connections.release()
connection = None
return connection

View File

@ -1,206 +0,0 @@
import typing
import h11
from .._backends.base import BaseSocketStream
from .._config import Timeout
from .._content_streams import AsyncIteratorStream
from .._exceptions import ConnectionClosed, ProtocolError
from .._models import Request, Response
from .._utils import get_logger
H11Event = typing.Union[
h11.Request,
h11.Response,
h11.InformationalResponse,
h11.Data,
h11.EndOfMessage,
h11.ConnectionClosed,
]
# Callback signature: async def callback() -> None
# In practice the callback will be a functools partial, which binds
# the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
logger = get_logger(__name__)
class HTTP11Connection:
READ_NUM_BYTES = 4096
def __init__(
self,
socket: BaseSocketStream,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.socket = socket
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
@property
def is_http2(self) -> bool:
return False
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
await self._send_request(request, timeout)
await self._send_request_body(request, timeout)
http_version, status_code, headers = await self._receive_response(timeout)
stream = AsyncIteratorStream(
aiterator=self._receive_response_data(timeout),
close_func=self.response_closed,
)
return Response(
status_code=status_code,
http_version=http_version,
headers=headers,
stream=stream,
request=request,
)
async def close(self) -> None:
event = h11.ConnectionClosed()
try:
logger.trace(f"send_event event={event!r}")
self.h11_state.send(event)
except h11.LocalProtocolError: # pragma: no cover
# Premature client disconnect
pass
await self.socket.close()
async def _send_request(self, request: Request, timeout: Timeout) -> None:
"""
Send the request method, URL, and headers to the network.
"""
logger.trace(
f"send_headers method={request.method!r} "
f"target={request.url.full_path!r} "
f"headers={request.headers!r}"
)
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
headers = request.headers.raw
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)
async def _send_request_body(self, request: Request, timeout: Timeout) -> None:
"""
Send the request body to the network.
"""
try:
# Send the request body.
async for chunk in request.stream:
logger.trace(f"send_data data=Data(<{len(chunk)} bytes>)")
event = h11.Data(data=chunk)
await self._send_event(event, timeout)
# Finalize sending the request.
event = h11.EndOfMessage()
await self._send_event(event, timeout)
except OSError: # pragma: nocover
# Once we've sent the initial part of the request we don't actually
# care about connection errors that occur when sending the body.
# Ignore these, and defer to any exceptions on reading the response.
self.h11_state.send_failed()
async def _send_event(self, event: H11Event, timeout: Timeout) -> None:
"""
Send a single `h11` event to the network, waiting for the data to
drain before returning.
"""
bytes_to_send = self.h11_state.send(event)
await self.socket.write(bytes_to_send, timeout)
async def _receive_response(
self, timeout: Timeout
) -> typing.Tuple[str, int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Read the response status and headers from the network.
"""
while True:
event = await self._receive_event(timeout)
if isinstance(event, h11.InformationalResponse):
continue
else:
assert isinstance(event, h11.Response)
break # pragma: no cover
http_version = "HTTP/%s" % event.http_version.decode("latin-1", errors="ignore")
return http_version, event.status_code, event.headers
async def _receive_response_data(
self, timeout: Timeout
) -> typing.AsyncIterator[bytes]:
"""
Read the response data from the network.
"""
while True:
event = await self._receive_event(timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
else:
assert isinstance(event, h11.EndOfMessage) or event is h11.PAUSED
break # pragma: no cover
async def _receive_event(self, timeout: Timeout) -> H11Event:
"""
Read a single `h11` event, reading more data from the network if needed.
"""
while True:
try:
event = self.h11_state.next_event()
except h11.RemoteProtocolError as e:
logger.debug(
"h11.RemoteProtocolError exception "
+ f"their_state={self.h11_state.their_state} "
+ f"error_status_hint={e.error_status_hint}"
)
if self.socket.is_connection_dropped():
raise ConnectionClosed(e)
raise ProtocolError(e)
if isinstance(event, h11.Data):
logger.trace(f"receive_event event=Data(<{len(event.data)} bytes>)")
else:
logger.trace(f"receive_event event={event!r}")
if event is h11.NEED_DATA:
try:
data = await self.socket.read(self.READ_NUM_BYTES, timeout)
except OSError: # pragma: nocover
data = b""
self.h11_state.receive_data(data)
else:
assert event is not h11.NEED_DATA
break # pragma: no cover
return event
async def response_closed(self) -> None:
logger.trace(
f"response_closed "
f"our_state={self.h11_state.our_state!r} "
f"their_state={self.h11_state.their_state}"
)
if (
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
):
# Get ready for another request/response cycle.
self.h11_state.start_next_cycle()
else:
await self.close()
if self.on_release is not None:
await self.on_release()
@property
def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
def is_connection_dropped(self) -> bool:
return self.socket.is_connection_dropped()

View File

@ -1,298 +0,0 @@
import typing
import h2.connection
import h2.events
from h2.config import H2Configuration
from h2.settings import SettingCodes, Settings
from .._backends.base import (
BaseLock,
BaseSocketStream,
ConcurrencyBackend,
lookup_backend,
)
from .._config import Timeout
from .._content_streams import AsyncIteratorStream
from .._exceptions import ProtocolError
from .._models import Request, Response
from .._utils import get_logger
logger = get_logger(__name__)
class HTTP2Connection:
READ_NUM_BYTES = 4096
CONFIG = H2Configuration(validate_inbound_headers=False)
def __init__(
self,
socket: BaseSocketStream,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
on_release: typing.Callable = None,
):
self.socket = socket
self.backend = lookup_backend(backend)
self.on_release = on_release
self.state = h2.connection.H2Connection(config=self.CONFIG)
self.streams = {} # type: typing.Dict[int, HTTP2Stream]
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
self.sent_connection_init = False
@property
def is_http2(self) -> bool:
return True
@property
def init_lock(self) -> BaseLock:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_initialization_lock"):
self._initialization_lock = self.backend.create_lock()
return self._initialization_lock
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
async with self.init_lock:
if not self.sent_connection_init:
# The very first stream is responsible for initiating the connection.
await self.send_connection_init(timeout)
self.sent_connection_init = True
stream_id = self.state.get_next_available_stream_id()
stream = HTTP2Stream(stream_id=stream_id, connection=self)
self.streams[stream_id] = stream
self.events[stream_id] = []
return await stream.send(request, timeout)
async def send_connection_init(self, timeout: Timeout) -> None:
"""
The HTTP/2 connection requires some initial setup before we can start
using individual request/response streams on it.
"""
# Need to set these manually here instead of manipulating via
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
# frames in addition to sending the undesired defaults.
self.state.local_settings = Settings(
client=True,
initial_values={
# Disable PUSH_PROMISE frames from the server since we don't do anything
# with them for now. Maybe when we support caching?
SettingCodes.ENABLE_PUSH: 0,
# These two are taken from h2 for safe defaults
SettingCodes.MAX_CONCURRENT_STREAMS: 100,
SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
},
)
# Some websites (*cough* Yahoo *cough*) balk at this setting being
# present in the initial handshake since it's not defined in the original
# RFC despite the RFC mandating ignoring settings you don't know about.
del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
self.state.initiate_connection()
self.state.increment_flow_control_window(2 ** 24)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
@property
def is_closed(self) -> bool:
return False
def is_connection_dropped(self) -> bool:
return self.socket.is_connection_dropped()
async def close(self) -> None:
await self.socket.close()
async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
"""
Returns the maximum allowable outgoing flow for a given stream.
If the allowable flow is zero, then waits on the network until
WindowUpdated frames have increased the flow rate.
https://tools.ietf.org/html/rfc7540#section-6.9
"""
local_flow = self.state.local_flow_control_window(stream_id)
connection_flow = self.state.max_outbound_frame_size
flow = min(local_flow, connection_flow)
while flow == 0:
await self.receive_events(timeout)
local_flow = self.state.local_flow_control_window(stream_id)
connection_flow = self.state.max_outbound_frame_size
flow = min(local_flow, connection_flow)
return flow
async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
"""
Returns the next event for a given stream.
If no events are available yet, then waits on the network until
an event is available.
"""
while not self.events[stream_id]:
await self.receive_events(timeout)
return self.events[stream_id].pop(0)
async def receive_events(self, timeout: Timeout) -> None:
"""
Read some data from the network, and update the H2 state.
"""
data = await self.socket.read(self.READ_NUM_BYTES, timeout)
events = self.state.receive_data(data)
for event in events:
event_stream_id = getattr(event, "stream_id", 0)
logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")
if hasattr(event, "error_code"):
raise ProtocolError(event)
if event_stream_id in self.events:
self.events[event_stream_id].append(event)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def send_headers(
self,
stream_id: int,
headers: typing.List[typing.Tuple[bytes, bytes]],
end_stream: bool,
timeout: Timeout,
) -> None:
self.state.send_headers(stream_id, headers, end_stream=end_stream)
self.state.increment_flow_control_window(2 ** 24, stream_id=stream_id)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def send_data(self, stream_id: int, chunk: bytes, timeout: Timeout) -> None:
self.state.send_data(stream_id, chunk)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
self.state.end_stream(stream_id)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def acknowledge_received_data(
self, stream_id: int, amount: int, timeout: Timeout
) -> None:
self.state.acknowledge_received_data(amount, stream_id)
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def close_stream(self, stream_id: int) -> None:
del self.streams[stream_id]
del self.events[stream_id]
if not self.streams and self.on_release is not None:
await self.on_release()
class HTTP2Stream:
def __init__(self, stream_id: int, connection: HTTP2Connection) -> None:
self.stream_id = stream_id
self.connection = connection
async def send(self, request: Request, timeout: Timeout) -> Response:
# Send the request.
has_body = (
"Content-Length" in request.headers
or "Transfer-Encoding" in request.headers
)
await self.send_headers(request, has_body, timeout)
if has_body:
await self.send_body(request, timeout)
# Receive the response.
status_code, headers = await self.receive_response(timeout)
stream = AsyncIteratorStream(
aiterator=self.body_iter(timeout), close_func=self.close
)
return Response(
status_code=status_code,
http_version="HTTP/2",
headers=headers,
stream=stream,
request=request,
)
async def send_headers(
self, request: Request, has_body: bool, timeout: Timeout
) -> None:
headers = [
(b":method", request.method.encode("ascii")),
(b":authority", request.url.authority.encode("ascii")),
(b":scheme", request.url.scheme.encode("ascii")),
(b":path", request.url.full_path.encode("ascii")),
] + [
(k, v)
for k, v in request.headers.raw
if k not in (b"host", b"transfer-encoding")
]
end_stream = not has_body
logger.trace(
f"send_headers "
f"stream_id={self.stream_id} "
f"method={request.method!r} "
f"target={request.url.full_path!r} "
f"headers={headers!r}"
)
await self.connection.send_headers(self.stream_id, headers, end_stream, timeout)
async def send_body(self, request: Request, timeout: Timeout) -> None:
logger.trace(f"send_body stream_id={self.stream_id}")
async for data in request.stream:
while data:
max_flow = await self.connection.wait_for_outgoing_flow(
self.stream_id, timeout
)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
await self.connection.send_data(self.stream_id, chunk, timeout)
await self.connection.end_stream(self.stream_id, timeout)
async def receive_response(
self, timeout: Timeout
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Read the response status and headers from the network.
"""
while True:
event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.ResponseReceived):
break
status_code = 200
headers = []
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode("ascii", errors="ignore"))
elif not k.startswith(b":"):
headers.append((k, v))
return (status_code, headers)
async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
while True:
event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.DataReceived):
amount = event.flow_controlled_length
await self.connection.acknowledge_received_data(
self.stream_id, amount, timeout
)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
break
async def close(self) -> None:
await self.connection.close_stream(self.stream_id)

View File

@ -1,208 +0,0 @@
import enum
import typing
import warnings
from base64 import b64encode
from .._backends.base import ConcurrencyBackend
from .._config import (
DEFAULT_POOL_LIMITS,
CertTypes,
PoolLimits,
SSLConfig,
Timeout,
VerifyTypes,
)
from .._exceptions import ProxyError
from .._models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
from .._utils import get_logger
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
logger = get_logger(__name__)
class HTTPProxyMode(enum.Enum):
# This enum is pending deprecation in order to reduce API surface area,
# but is currently still around for 0.8 backwards compat.
DEFAULT = "DEFAULT"
FORWARD_ONLY = "FORWARD_ONLY"
TUNNEL_ONLY = "TUNNEL_ONLY"
DEFAULT_MODE = "DEFAULT"
FORWARD_ONLY = "FORWARD_ONLY"
TUNNEL_ONLY = "TUNNEL_ONLY"
class HTTPProxy(ConnectionPool):
"""A proxy that sends requests to the recipient server
on behalf of the connecting client.
"""
def __init__(
self,
proxy_url: URLTypes,
*,
proxy_headers: HeaderTypes = None,
proxy_mode: str = "DEFAULT",
verify: VerifyTypes = True,
cert: CertTypes = None,
trust_env: bool = None,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http2: bool = False,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
):
if isinstance(proxy_mode, HTTPProxyMode): # pragma: nocover
warnings.warn(
"The 'HTTPProxyMode' enum is pending deprecation. "
"Use a plain string instead. proxy_mode='FORWARD_ONLY', or "
"proxy_mode='TUNNEL_ONLY'."
)
proxy_mode = proxy_mode.value
assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
self.tunnel_ssl = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env, http2=False
)
super(HTTPProxy, self).__init__(
verify=verify,
cert=cert,
pool_limits=pool_limits,
backend=backend,
trust_env=trust_env,
http2=http2,
)
self.proxy_url = URL(proxy_url)
self.proxy_mode = proxy_mode
self.proxy_headers = Headers(proxy_headers)
url = self.proxy_url
if url.username or url.password:
self.proxy_headers.setdefault(
"Proxy-Authorization",
self.build_auth_header(url.username, url.password),
)
# Remove userinfo from the URL authority, e.g.:
# 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
credentials, _, authority = url.authority.rpartition("@")
self.proxy_url = url.copy_with(authority=authority)
def build_auth_header(self, username: str, password: str) -> str:
userpass = (username.encode("utf-8"), password.encode("utf-8"))
token = b64encode(b":".join(userpass)).decode().strip()
return f"Basic {token}"
async def acquire_connection(
self, origin: Origin, timeout: Timeout = None
) -> HTTPConnection:
if self.should_forward_origin(origin):
logger.trace(
f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}"
)
return await super().acquire_connection(Origin(self.proxy_url), timeout)
else:
logger.trace(
f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
)
return await self.tunnel_connection(origin, timeout)
async def tunnel_connection(
self, origin: Origin, timeout: Timeout = None
) -> HTTPConnection:
"""Creates a new HTTPConnection via the CONNECT method
usually reserved for proxying HTTPS connections.
"""
connection = self.pop_connection(origin)
if connection is None:
connection = await self.request_tunnel_proxy_connection(origin)
# After we receive the 2XX response from the proxy that our
# tunnel is open we switch the connection's origin
# to the original so the tunnel can be re-used.
self.active_connections.remove(connection)
connection.origin = origin
self.active_connections.add(connection)
await connection.tunnel_start_tls(
origin=origin, proxy_url=self.proxy_url, timeout=timeout,
)
else:
self.active_connections.add(connection)
return connection
async def request_tunnel_proxy_connection(self, origin: Origin) -> HTTPConnection:
"""Creates an HTTPConnection by setting up a TCP tunnel"""
proxy_headers = self.proxy_headers.copy()
proxy_headers.setdefault("Accept", "*/*")
proxy_request = Request(
method="CONNECT", url=self.proxy_url.copy_with(), headers=proxy_headers
)
proxy_request.url.full_path = f"{origin.host}:{origin.port}"
await self.max_connections.acquire()
connection = HTTPConnection(
Origin(self.proxy_url),
ssl=self.tunnel_ssl,
backend=self.backend,
release_func=self.release_connection,
)
self.active_connections.add(connection)
# See if our tunnel has been opened successfully
proxy_response = await connection.send(proxy_request)
logger.trace(
f"tunnel_response "
f"proxy_url={self.proxy_url!r} "
f"origin={origin!r} "
f"response={proxy_response!r}"
)
if not (200 <= proxy_response.status_code <= 299):
await proxy_response.aread()
raise ProxyError(
f"Non-2XX response received from HTTP proxy "
f"({proxy_response.status_code})",
request=proxy_request,
response=proxy_response,
)
else:
# Hack to ingest the response, without closing it.
async for chunk in proxy_response._raw_stream:
pass
return connection
def should_forward_origin(self, origin: Origin) -> bool:
"""Determines if the given origin should
be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'
then the proxy will forward all 'HTTP' requests and
tunnel all 'HTTPS' requests.
"""
return (
self.proxy_mode == DEFAULT_MODE and not origin.is_ssl
) or self.proxy_mode == FORWARD_ONLY
async def send(self, request: Request, timeout: Timeout = None) -> Response:
if self.should_forward_origin(Origin(request.url)):
# Change the request to have the target URL
# as its full_path and switch the proxy URL
# for where the request will be sent.
target_url = str(request.url)
request.url = self.proxy_url.copy_with()
request.url.full_path = target_url
for name, value in self.proxy_headers.items():
request.headers.setdefault(name, value)
return await super().send(request=request, timeout=timeout)
def __repr__(self) -> str:
return (
f"HTTPProxy(proxy_url={self.proxy_url!r} "
f"proxy_headers={self.proxy_headers!r} "
f"proxy_mode={self.proxy_mode!r})"
)

View File

@ -1,8 +1,9 @@
import math
import socket
import ssl
import typing
from typing import Dict, Iterator, List, Optional, Tuple, Union
import httpcore
import urllib3
from urllib3.exceptions import MaxRetryError, SSLError
@ -12,16 +13,13 @@ from .._config import (
PoolLimits,
Proxy,
SSLConfig,
Timeout,
VerifyTypes,
)
from .._content_streams import IteratorStream
from .._models import Request, Response
from .._content_streams import ByteStream, IteratorStream
from .._utils import as_network_error
from .base import SyncDispatcher
class URLLib3Dispatcher(SyncDispatcher):
class URLLib3Dispatcher(httpcore.SyncHTTPTransport):
def __init__(
self,
*,
@ -62,12 +60,12 @@ class URLLib3Dispatcher(SyncDispatcher):
def init_pool_manager(
self,
proxy: typing.Optional[Proxy],
proxy: Optional[Proxy],
ssl_context: ssl.SSLContext,
num_pools: int,
maxsize: int,
block: bool,
) -> typing.Union[urllib3.PoolManager, urllib3.ProxyManager]:
) -> Union[urllib3.PoolManager, urllib3.ProxyManager]:
if proxy is None:
return urllib3.PoolManager(
ssl_context=ssl_context,
@ -85,20 +83,58 @@ class URLLib3Dispatcher(SyncDispatcher):
block=block,
)
def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
def request(
self,
method: bytes,
url: Tuple[bytes, bytes, int, bytes],
headers: List[Tuple[bytes, bytes]] = None,
stream: httpcore.SyncByteStream = None,
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
timeout = {} if timeout is None else timeout
urllib3_timeout = urllib3.util.Timeout(
connect=timeout.connect_timeout, read=timeout.read_timeout
connect=timeout.get("connect"), read=timeout.get("read")
)
chunked = request.headers.get("Transfer-Encoding") == "chunked"
content_length = int(request.headers.get("Content-Length", "0"))
body = request.stream if chunked or content_length else None
chunked = False
content_length = 0
for header_key, header_value in headers:
header_key = header_key.lower()
if header_key == b"transfer-encoding":
chunked = header_value == b"chunked"
if header_key == b"content-length":
content_length = int(header_value.decode("ascii"))
body = stream if chunked or content_length else None
scheme, host, port, path = url
default_scheme = {80: b"http", 443: "https"}.get(port)
if scheme == default_scheme:
url_str = "%s://%s%s" % (
scheme.decode("ascii"),
host.decode("ascii"),
path.decode("ascii"),
)
else:
url_str = "%s://%s:%d%s" % (
scheme.decode("ascii"),
host.decode("ascii"),
port,
path.decode("ascii"),
)
with as_network_error(MaxRetryError, SSLError, socket.error):
conn = self.pool.urlopen(
method=request.method,
url=str(request.url),
headers=dict(request.headers),
method=method.decode(),
url=url_str,
headers=dict(
[
(key.decode("ascii"), value.decode("ascii"))
for key, value in headers
]
),
body=body,
redirect=False,
assert_same_host=False,
@ -106,23 +142,20 @@ class URLLib3Dispatcher(SyncDispatcher):
preload_content=False,
chunked=chunked,
timeout=urllib3_timeout,
pool_timeout=timeout.pool_timeout,
pool_timeout=timeout.get("pool"),
)
def response_bytes() -> typing.Iterator[bytes]:
def response_bytes() -> Iterator[bytes]:
with as_network_error(socket.error):
for chunk in conn.stream(4096, decode_content=False):
yield chunk
return Response(
status_code=conn.status,
http_version="HTTP/1.1",
headers=list(conn.headers.items()),
stream=IteratorStream(
iterator=response_bytes(), close_func=conn.release_conn
),
request=request,
status_code = conn.status
headers = list(conn.headers.items())
response_stream = IteratorStream(
iterator=response_bytes(), close_func=conn.release_conn
)
return (b"HTTP/1.1", status_code, conn.reason, headers, response_stream)
def close(self) -> None:
self.pool.clear()

View File

@ -2,10 +2,9 @@ import io
import itertools
import typing
from .._config import TimeoutTypes
from .._content_streams import IteratorStream
from .._models import Request, Response
from .base import SyncDispatcher
import httpcore
from .._content_streams import ByteStream, IteratorStream
def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
@ -16,9 +15,9 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
return []
class WSGIDispatch(SyncDispatcher):
class WSGIDispatch(httpcore.SyncHTTPTransport):
"""
A custom SyncDispatcher that handles sending requests directly to an WSGI app.
A custom transport that handles sending requests directly to an WSGI app.
The simplest way to use this functionality is to use the `app` argument.
```
@ -61,28 +60,46 @@ class WSGIDispatch(SyncDispatcher):
self.script_name = script_name
self.remote_addr = remote_addr
def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
def request(
self,
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]] = None,
stream: httpcore.SyncByteStream = None,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes,
int,
bytes,
typing.List[typing.Tuple[bytes, bytes]],
httpcore.SyncByteStream,
]:
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
environ = {
"wsgi.version": (1, 0),
"wsgi.url_scheme": request.url.scheme,
"wsgi.input": io.BytesIO(request.read()),
"wsgi.url_scheme": scheme.decode("ascii"),
"wsgi.input": io.BytesIO(b"".join([chunk for chunk in stream])),
"wsgi.errors": io.BytesIO(),
"wsgi.multithread": True,
"wsgi.multiprocess": False,
"wsgi.run_once": False,
"REQUEST_METHOD": request.method,
"REQUEST_METHOD": method.decode(),
"SCRIPT_NAME": self.script_name,
"PATH_INFO": request.url.path,
"QUERY_STRING": request.url.query,
"SERVER_NAME": request.url.host,
"SERVER_PORT": str(request.url.port),
"PATH_INFO": path.decode("ascii"),
"QUERY_STRING": query.decode("ascii"),
"SERVER_NAME": host.decode("ascii"),
"SERVER_PORT": str(port),
"REMOTE_ADDR": self.remote_addr,
}
for key, value in request.headers.items():
key = key.upper().replace("-", "_")
for header_key, header_value in headers:
key = header_key.decode("ascii").upper().replace("-", "_")
if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
key = "HTTP_" + key
environ[key] = value
environ[key] = header_value.decode("ascii")
seen_status = None
seen_response_headers = None
@ -106,10 +123,11 @@ class WSGIDispatch(SyncDispatcher):
if seen_exc_info and self.raise_app_exceptions:
raise seen_exc_info[1]
return Response(
status_code=int(seen_status.split()[0]),
http_version="HTTP/1.1",
headers=seen_response_headers,
stream=IteratorStream(chunk for chunk in result),
request=request,
)
status_code = int(seen_status.split()[0])
headers = [
(key.encode("ascii"), value.encode("ascii"))
for key, value in seen_response_headers
]
stream = IteratorStream(chunk for chunk in result)
return (b"HTTP/1.1", status_code, b"", headers, stream)

View File

@ -1,5 +1,7 @@
import typing
import httpcore
if typing.TYPE_CHECKING:
from ._models import Request, Response # pragma: nocover
@ -26,73 +28,36 @@ class HTTPError(Exception):
# Timeout exceptions...
class TimeoutException(HTTPError):
"""
A base class for all timeouts.
"""
ConnectTimeout = httpcore.ConnectTimeout
ReadTimeout = httpcore.ReadTimeout
WriteTimeout = httpcore.WriteTimeout
PoolTimeout = httpcore.PoolTimeout
class ConnectTimeout(TimeoutException):
"""
Timeout while establishing a connection.
"""
# Core networking exceptions...
NetworkError = httpcore.NetworkError
ReadError = httpcore.ReadError
WriteError = httpcore.WriteError
ConnectError = httpcore.ConnectError
CloseError = httpcore.CloseError
class ReadTimeout(TimeoutException):
"""
Timeout while reading response data.
"""
# Other transport exceptions...
class WriteTimeout(TimeoutException):
"""
Timeout while writing request data.
"""
class PoolTimeout(TimeoutException):
"""
Timeout while waiting to acquire a connection from the pool.
"""
class ProxyError(HTTPError):
"""
Error from within a proxy
"""
ProxyError = httpcore.ProxyError
ProtocolError = httpcore.ProtocolError
# HTTP exceptions...
class ProtocolError(HTTPError):
"""
Malformed HTTP.
"""
class DecodingError(HTTPError):
"""
Decoding of the response failed.
"""
# Network exceptions...
class NetworkError(HTTPError):
"""
A failure occurred while trying to access the network.
"""
class ConnectionClosed(NetworkError):
"""
Expected more data from peer, but connection was closed.
"""
# Redirect exceptions...

View File

@ -174,6 +174,15 @@ class URL:
def fragment(self) -> str:
return self._uri_reference.fragment or ""
@property
def raw(self) -> typing.Tuple[bytes, bytes, int, bytes]:
return (
self.scheme.encode("ascii"),
self.host.encode("ascii"),
self.port,
self.full_path.encode("ascii"),
)
@property
def is_ssl(self) -> bool:
return self.scheme == "https"

View File

@ -11,7 +11,7 @@ combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,tests
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
line_length = 88
multi_line_output = 3

View File

@ -58,12 +58,9 @@ setup(
"certifi",
"hstspreload",
"chardet==3.*",
"h11>=0.8,<0.10",
"h2==3.*",
"idna==2.*",
"rfc3986>=1.3,<2",
"sniffio==1.*",
"urllib3==1.*",
"httpcore==0.7.*",
],
classifiers=[
"Development Status :: 4 - Beta",

View File

@ -148,15 +148,3 @@ async def test_100_continue(server):
assert response.status_code == 200
assert response.content == data
@pytest.mark.usefixtures("async_environment")
async def test_uds(uds_server):
url = uds_server.url
uds = uds_server.config.uds
assert uds is not None
async with httpx.AsyncClient(uds=uds) as client:
response = await client.get(url)
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.encoding == "iso-8859-1"

View File

@ -1,8 +1,8 @@
import hashlib
import json
import os
import typing
import httpcore
import pytest
from httpx import (
@ -10,35 +10,47 @@ from httpx import (
AsyncClient,
Auth,
DigestAuth,
Headers,
ProtocolError,
Request,
RequestBodyUnavailable,
Response,
)
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._dispatch.base import AsyncDispatcher
from httpx._content_streams import ContentStream, JSONStream
class MockDispatch(AsyncDispatcher):
def __init__(self, auth_header: str = "", status_code: int = 200) -> None:
def get_header_value(headers, key, default=None):
lookup = key.encode("ascii").lower()
for header_key, header_value in headers:
if header_key.lower() == lookup:
return header_value.decode("ascii")
return default
class MockDispatch(httpcore.AsyncHTTPTransport):
def __init__(self, auth_header: bytes = b"", status_code: int = 200) -> None:
self.auth_header = auth_header
self.status_code = status_code
async def send(
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
headers = [("www-authenticate", self.auth_header)] if self.auth_header else []
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
return Response(
self.status_code, headers=headers, content=body, request=request
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
authorization = get_header_value(headers, "Authorization")
response_headers = (
[(b"www-authenticate", self.auth_header)] if self.auth_header else []
)
response_stream = JSONStream({"auth": authorization})
return b"HTTP/1.1", self.status_code, b"", response_headers, response_stream
class MockDigestAuthDispatch(AsyncDispatcher):
class MockDigestAuthDispatch(httpcore.AsyncHTTPTransport):
def __init__(
self,
algorithm: str = "SHA-256",
@ -52,20 +64,26 @@ class MockDigestAuthDispatch(AsyncDispatcher):
self._regenerate_nonce = regenerate_nonce
self._response_count = 0
async def send(
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
if self._response_count < self.send_response_after_attempt:
return self.challenge_send(request)
return self.challenge_send(method, url, headers, stream)
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
return Response(200, content=body, request=request)
authorization = get_header_value(headers, "Authorization")
body = JSONStream({"auth": authorization})
return b"HTTP/1.1", 200, b"", [], body
def challenge_send(self, request: Request) -> Response:
def challenge_send(
self, method: bytes, url: URL, headers: Headers, stream: ContentStream,
) -> typing.Tuple[int, bytes, Headers, ContentStream]:
self._response_count += 1
nonce = (
hashlib.sha256(os.urandom(8)).hexdigest()
@ -88,9 +106,12 @@ class MockDigestAuthDispatch(AsyncDispatcher):
)
headers = [
("www-authenticate", 'Digest realm="httpx@example.org", ' + challenge_str)
(
b"www-authenticate",
b'Digest realm="httpx@example.org", ' + challenge_str.encode("ascii"),
)
]
return Response(401, headers=headers, content=b"", request=request)
return b"HTTP/1.1", 401, b"", headers, ContentStream()
@pytest.mark.asyncio
@ -234,7 +255,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
async def test_digest_auth_200_response_including_digest_auth_header() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
auth_header = b'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
client = AsyncClient(
dispatch=MockDispatch(auth_header=auth_header, status_code=200)
@ -251,7 +272,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
client = AsyncClient(dispatch=MockDispatch(auth_header="", status_code=401))
client = AsyncClient(dispatch=MockDispatch(auth_header=b"", status_code=401))
response = await client.get(url, auth=auth)
assert response.status_code == 401
@ -382,16 +403,16 @@ async def test_digest_auth_incorrect_credentials() -> None:
@pytest.mark.parametrize(
"auth_header",
[
'Digest realm="httpx@example.org", qop="auth"', # missing fields
'realm="httpx@example.org", qop="auth"', # not starting with Digest
'DigestZ realm="httpx@example.org", qop="auth"'
'qop="auth,auth-int",nonce="abc",opaque="xyz"',
'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
b'Digest realm="httpx@example.org", qop="auth"', # missing fields
b'realm="httpx@example.org", qop="auth"', # not starting with Digest
b'DigestZ realm="httpx@example.org", qop="auth"'
b'qop="auth,auth-int",nonce="abc",opaque="xyz"',
b'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
],
)
@pytest.mark.asyncio
async def test_digest_auth_raises_protocol_error_on_malformed_header(
auth_header: str,
auth_header: bytes,
) -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
@ -439,7 +460,7 @@ async def test_auth_history() -> None:
url = "https://example.org/"
auth = RepeatAuth(repeat=2)
client = AsyncClient(dispatch=MockDispatch(auth_header="abc"))
client = AsyncClient(dispatch=MockDispatch(auth_header=b"abc"))
response = await client.get(url, auth=auth)
assert response.status_code == 200

View File

@ -1,29 +1,43 @@
import json
import typing
from http.cookiejar import Cookie, CookieJar
import httpcore
import pytest
from httpx import AsyncClient, Cookies, Request, Response
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._dispatch.base import AsyncDispatcher
from httpx import AsyncClient, Cookies
from httpx._content_streams import ByteStream, ContentStream, JSONStream
class MockDispatch(AsyncDispatcher):
async def send(
def get_header_value(headers, key, default=None):
lookup = key.encode("ascii").lower()
for header_key, header_value in headers:
if header_key.lower() == lookup:
return header_value.decode("ascii")
return default
class MockDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path.startswith("/echo_cookies"):
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
return Response(200, content=body, request=request)
elif request.url.path.startswith("/set_cookie"):
headers = {"set-cookie": "example-name=example-value"}
return Response(200, headers=headers, request=request)
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
host, scheme, port, path = url
if path.startswith(b"/echo_cookies"):
cookie = get_header_value(headers, "cookie")
body = JSONStream({"cookies": cookie})
return b"HTTP/1.1", 200, b"OK", [], body
elif path.startswith(b"/set_cookie"):
headers = [(b"set-cookie", b"example-name=example-value")]
body = ByteStream(b"")
return b"HTTP/1.1", 200, b"OK", headers, body
else:
raise NotImplementedError # pragma: no cover
raise NotImplementedError() # pragma: no cover
@pytest.mark.asyncio

View File

@ -1,26 +1,30 @@
#!/usr/bin/env python3
import json
import typing
import httpcore
import pytest
from httpx import AsyncClient, Headers, Request, Response, __version__
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._dispatch.base import AsyncDispatcher
from httpx import AsyncClient, Headers, __version__
from httpx._content_streams import ContentStream, JSONStream
class MockDispatch(AsyncDispatcher):
async def send(
class MockDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path.startswith("/echo_headers"):
request_headers = dict(request.headers.items())
body = json.dumps({"headers": request_headers}).encode()
return Response(200, content=body, request=request)
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
headers_dict = dict(
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
)
body = JSONStream({"headers": headers_dict})
return b"HTTP/1.1", 200, b"OK", [], body
@pytest.mark.asyncio

View File

@ -24,7 +24,7 @@ def test_proxies_parameter(proxies, expected_proxies):
for proxy_key, url in expected_proxies:
assert proxy_key in client.proxies
assert client.proxies[proxy_key].proxy_url == url
assert client.proxies[proxy_key].proxy_origin == httpx.URL(url).raw[:3]
assert len(expected_proxies) == len(client.proxies)
@ -81,7 +81,7 @@ def test_dispatcher_for_request(url, proxies, expected):
if expected is None:
assert dispatcher is client.dispatch
else:
assert dispatcher.proxy_url == expected
assert dispatcher.proxy_origin == httpx.URL(expected).raw[:3]
def test_unsupported_proxy_scheme():
@ -115,4 +115,4 @@ def test_proxies_environ(monkeypatch, url, env, expected):
if expected is None:
assert dispatcher == client.dispatch
else:
assert dispatcher.proxy_url == expected
assert dispatcher.proxy_origin == httpx.URL(expected).raw[:3]

View File

@ -1,23 +1,26 @@
import json
import typing
import httpcore
import pytest
from httpx import URL, AsyncClient, QueryParams, Request, Response
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._dispatch.base import AsyncDispatcher
from httpx import URL, AsyncClient, Headers, QueryParams
from httpx._content_streams import ContentStream, JSONStream
class MockDispatch(AsyncDispatcher):
async def send(
class MockDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path.startswith("/echo_queryparams"):
body = json.dumps({"ok": "ok"}).encode()
return Response(200, content=body, request=request)
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
headers = Headers()
body = JSONStream({"ok": "ok"})
return b"HTTP/1.1", 200, b"OK", headers, body
def test_client_queryparams():

View File

@ -1,118 +1,146 @@
import json
import typing
from urllib.parse import parse_qs
import httpcore
import pytest
from httpx import (
URL,
AsyncClient,
NotRedirectResponse,
Request,
RequestBodyUnavailable,
Response,
TooManyRedirects,
codes,
)
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._content_streams import AsyncIteratorStream
from httpx._dispatch.base import AsyncDispatcher
from httpx._content_streams import AsyncIteratorStream, ByteStream, ContentStream
class MockDispatch(AsyncDispatcher):
async def send(
def get_header_value(headers, key, default=None):
lookup = key.encode("ascii").lower()
for header_key, header_value in headers:
if header_key.lower() == lookup:
return header_value.decode("ascii")
return default
class MockDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path == "/no_redirect":
return Response(codes.OK, request=request)
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
scheme, host, port, path = url
path, _, query = path.partition(b"?")
if path == b"/no_redirect":
return b"HTTP/1.1", codes.OK, b"OK", [], ByteStream(b"")
elif request.url.path == "/redirect_301":
elif path == b"/redirect_301":
async def body():
yield b"<a href='https://example.org/'>here</a>"
status_code = codes.MOVED_PERMANENTLY
headers = {"location": "https://example.org/"}
headers = [(b"location", b"https://example.org/")]
stream = AsyncIteratorStream(aiterator=body())
return Response(
status_code, stream=stream, headers=headers, request=request
)
return b"HTTP/1.1", status_code, b"Moved Permanently", headers, stream
elif request.url.path == "/redirect_302":
elif path == b"/redirect_302":
status_code = codes.FOUND
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
headers = [(b"location", b"https://example.org/")]
return b"HTTP/1.1", status_code, b"Found", headers, ByteStream(b"")
elif request.url.path == "/redirect_303":
elif path == b"/redirect_303":
status_code = codes.SEE_OTHER
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
headers = [(b"location", b"https://example.org/")]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/relative_redirect":
headers = {"location": "/"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif path == b"/relative_redirect":
status_code = codes.SEE_OTHER
headers = [(b"location", b"/")]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/malformed_redirect":
headers = {"location": "https://:443/"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif path == b"/malformed_redirect":
status_code = codes.SEE_OTHER
headers = [(b"location", b"https://:443/")]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/no_scheme_redirect":
headers = {"location": "//example.org/"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif path == b"/no_scheme_redirect":
status_code = codes.SEE_OTHER
headers = [(b"location", b"//example.org/")]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
elif path == b"/multiple_redirects":
params = parse_qs(query.decode("ascii"))
count = int(params.get("count", "0")[0])
redirect_count = count - 1
code = codes.SEE_OTHER if count else codes.OK
location = "/multiple_redirects"
phrase = b"See Other" if count else b"OK"
location = b"/multiple_redirects"
if redirect_count:
location += "?count=" + str(redirect_count)
headers = {"location": location} if count else {}
return Response(code, headers=headers, request=request)
location += b"?count=" + str(redirect_count).encode("ascii")
headers = [(b"location", location)] if count else []
return b"HTTP/1.1", code, phrase, headers, ByteStream(b"")
if request.url.path == "/redirect_loop":
headers = {"location": "/redirect_loop"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
if path == b"/redirect_loop":
code = codes.SEE_OTHER
headers = [(b"location", b"/redirect_loop")]
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/cross_domain":
headers = {"location": "https://example.org/cross_domain_target"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif path == b"/cross_domain":
code = codes.SEE_OTHER
headers = [(b"location", b"https://example.org/cross_domain_target")]
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/cross_domain_target":
headers = dict(request.headers.items())
content = json.dumps({"headers": headers}).encode()
return Response(codes.OK, content=content, request=request)
elif path == b"/cross_domain_target":
headers_dict = dict(
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
)
content = ByteStream(json.dumps({"headers": headers_dict}).encode())
return b"HTTP/1.1", 200, b"OK", [], content
elif request.url.path == "/redirect_body":
body = b"".join([part async for part in request.stream])
headers = {"location": "/redirect_body_target"}
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
elif path == b"/redirect_body":
_ = b"".join([part async for part in stream])
code = codes.PERMANENT_REDIRECT
headers = [(b"location", b"/redirect_body_target")]
return b"HTTP/1.1", code, b"Permanent Redirect", headers, ByteStream(b"")
elif request.url.path == "/redirect_no_body":
content = b"".join([part async for part in request.stream])
headers = {"location": "/redirect_body_target"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif path == b"/redirect_no_body":
_ = b"".join([part async for part in stream])
code = codes.SEE_OTHER
headers = [(b"location", b"/redirect_body_target")]
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
elif request.url.path == "/redirect_body_target":
content = b"".join([part async for part in request.stream])
headers = dict(request.headers.items())
body = json.dumps({"body": content.decode(), "headers": headers}).encode()
return Response(codes.OK, content=body, request=request)
elif path == b"/redirect_body_target":
content = b"".join([part async for part in stream])
headers_dict = dict(
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
)
body = ByteStream(
json.dumps({"body": content.decode(), "headers": headers_dict}).encode()
)
return b"HTTP/1.1", 200, b"OK", [], body
elif request.url.path == "/cross_subdomain":
if request.headers["host"] != "www.example.org":
headers = {"location": "https://www.example.org/cross_subdomain"}
return Response(
codes.PERMANENT_REDIRECT, headers=headers, request=request
elif path == b"/cross_subdomain":
host = get_header_value(headers, "host")
if host != "www.example.org":
headers = [(b"location", b"https://www.example.org/cross_subdomain")]
return (
b"HTTP/1.1",
codes.PERMANENT_REDIRECT,
b"Permanent Redirect",
headers,
ByteStream(b""),
)
else:
return Response(codes.OK, content=b"Hello, world!", request=request)
return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!")
return Response(codes.OK, content=b"Hello, world!", request=request)
return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!")
@pytest.mark.usefixtures("async_environment")
@ -326,43 +354,53 @@ async def test_cross_subdomain_redirect():
assert response.url == URL("https://www.example.org/cross_subdomain")
class MockCookieDispatch(AsyncDispatcher):
async def send(
class MockCookieDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path == "/":
if "cookie" in request.headers:
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]],
stream: ContentStream,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
]:
scheme, host, port, path = url
if path == b"/":
cookie = get_header_value(headers, "Cookie")
if cookie is not None:
content = b"Logged in"
else:
content = b"Not logged in"
return Response(codes.OK, content=content, request=request)
return b"HTTP/1.1", 200, b"OK", [], ByteStream(content)
elif request.url.path == "/login":
elif path == b"/login":
status_code = codes.SEE_OTHER
headers = {
"location": "/",
"set-cookie": (
"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
"httponly; samesite=lax"
headers = [
(b"location", b"/"),
(
b"set-cookie",
(
b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
b"httponly; samesite=lax"
),
),
}
]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/logout":
elif path == b"/logout":
status_code = codes.SEE_OTHER
headers = {
"location": "/",
"set-cookie": (
"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
"httponly; samesite=lax"
headers = [
(b"location", b"/"),
(
b"set-cookie",
(
b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
b"httponly; samesite=lax"
),
),
}
return Response(status_code, headers=headers, request=request)
]
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
@pytest.mark.usefixtures("async_environment")

View File

@ -295,15 +295,6 @@ def server():
yield from serve_in_thread(server)
@pytest.fixture(scope=SERVER_SCOPE)
def uds_server():
uds = "test_server.sock"
config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)
@pytest.fixture(scope=SERVER_SCOPE)
def https_server(cert_pem_file, cert_private_key_file):
config = Config(
@ -317,19 +308,3 @@ def https_server(cert_pem_file, cert_private_key_file):
)
server = TestServer(config=config)
yield from serve_in_thread(server)
@pytest.fixture(scope=SERVER_SCOPE)
def https_uds_server(cert_pem_file, cert_private_key_file):
uds = "https_test_server.sock"
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
uds=uds,
loop="asyncio",
)
server = TestServer(config=config)
yield from serve_in_thread(server)
os.remove(uds)

View File

@ -1,246 +0,0 @@
import pytest
import httpx
from httpx._dispatch.connection_pool import ConnectionPool
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_connections(server):
"""
Connections should default to staying in a keep-alive state.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_timeout(server):
"""
Keep-alive connections should timeout.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
http.next_keepalive_check = 0.0
await http.check_keepalive_expiry()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
async with ConnectionPool() as http:
http.KEEP_ALIVE_EXPIRY = 0.0
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
http.next_keepalive_check = 0.0
await http.check_keepalive_expiry()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@pytest.mark.usefixtures("async_environment")
async def test_differing_connection_keys(server):
"""
Connections to differing connection keys should result in multiple connections.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.usefixtures("async_environment")
async def test_soft_limit(server):
"""
The soft_limit config should limit the maximum number of keep-alive connections.
"""
pool_limits = httpx.PoolLimits(soft_limit=1)
async with ConnectionPool(pool_limits=pool_limits) as http:
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_streaming_response_holds_connection(server):
"""
A streaming request should hold the connection open until the response is read.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_multiple_concurrent_connections(server):
"""
Multiple concurrent requests should open multiple concurrent connections.
"""
async with ConnectionPool() as http:
response_a = await http.request("GET", server.url)
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
response_b = await http.request("GET", server.url)
assert len(http.active_connections) == 2
assert len(http.keepalive_connections) == 0
await response_b.aread()
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 1
await response_a.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.usefixtures("async_environment")
async def test_close_connections(server):
"""
Using a `Connection: close` header should close the connection.
"""
headers = [(b"connection", b"close")]
async with ConnectionPool() as http:
response = await http.request("GET", server.url, headers=headers)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@pytest.mark.usefixtures("async_environment")
async def test_standard_response_close(server):
"""
A standard close should keep the connection open.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
await response.aclose()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_premature_response_close(server):
"""
A premature close should close the connection.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aclose()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
# Shutdown the server to close the keep-alive connection
await server.restart()
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
"""
async with ConnectionPool() as http:
response = await http.request("GET", server.url)
await response.aread()
# Shutdown the server to close the keep-alive connection
await server.restart()
response = await http.request("GET", server.url)
await response.aread()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.usefixtures("async_environment")
async def test_connection_closed_free_semaphore_on_acquire(server):
"""
Verify that max_connections semaphore is released
properly on a disconnected connection.
"""
async with ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
response = await http.request("GET", server.url)
await response.aread()
# Close the connection so we're forced to recycle it
await server.restart()
response = await http.request("GET", server.url)
assert response.status_code == 200
@pytest.mark.usefixtures("async_environment")
async def test_connection_pool_closed_close_keepalive_and_free_semaphore(server):
"""
Closing the connection pool should close remaining keepalive connections and
release the max_connections semaphore.
"""
http = ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1))
async with http:
response = await http.request("GET", server.url)
await response.aread()
assert response.status_code == 200
assert len(http.keepalive_connections) == 1
assert len(http.keepalive_connections) == 0
# Perform a second round of requests to make sure the max_connections semaphore
# was released properly.
async with http:
response = await http.request("GET", server.url)
await response.aread()
assert response.status_code == 200

View File

@ -1,44 +0,0 @@
import pytest
import httpx
from httpx._config import SSLConfig
from httpx._dispatch.connection import HTTPConnection
@pytest.mark.usefixtures("async_environment")
async def test_get(server):
async with HTTPConnection(origin=server.url) as conn:
response = await conn.request("GET", server.url)
await response.aread()
assert response.status_code == 200
assert response.content == b"Hello, world!"
@pytest.mark.usefixtures("async_environment")
async def test_post(server):
async with HTTPConnection(origin=server.url) as conn:
response = await conn.request("GET", server.url, data=b"Hello, world!")
assert response.status_code == 200
@pytest.mark.usefixtures("async_environment")
async def test_premature_close(server):
with pytest.raises(httpx.ConnectionClosed):
async with HTTPConnection(origin=server.url) as conn:
response = await conn.request(
"GET", server.url.copy_with(path="/premature_close")
)
await response.aread()
@pytest.mark.usefixtures("async_environment")
async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
"""
An HTTPS request, with SSL configuration set on the client.
"""
ssl = SSLConfig(verify=ca_cert_pem_file)
async with HTTPConnection(origin=https_server.url, ssl=ssl) as conn:
response = await conn.request("GET", https_server.url)
await response.aread()
assert response.status_code == 200
assert response.content == b"Hello, world!"

View File

@ -1,158 +0,0 @@
import json
import socket
import h2.connection
import h2.events
import pytest
from h2.settings import SettingCodes
from httpx import AsyncClient, Response, TimeoutException
from httpx._dispatch.connection_pool import ConnectionPool
from .utils import MockHTTP2Backend
async def app(request):
method = request.method
path = request.url.path
body = b"".join([part async for part in request.stream])
content = json.dumps(
{"method": method, "path": path, "body": body.decode()}
).encode()
headers = {"Content-Length": str(len(content))}
return Response(200, headers=headers, content=content, request=request)
@pytest.mark.asyncio
async def test_http2_get_request():
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
async with AsyncClient(dispatch=dispatch) as client:
response = await client.get("http://example.org")
assert response.status_code == 200
assert json.loads(response.content) == {"method": "GET", "path": "/", "body": ""}
@pytest.mark.asyncio
async def test_http2_post_request():
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
async with AsyncClient(dispatch=dispatch) as client:
response = await client.post("http://example.org", data=b"<data>")
assert response.status_code == 200
assert json.loads(response.content) == {
"method": "POST",
"path": "/",
"body": "<data>",
}
@pytest.mark.asyncio
async def test_http2_large_post_request():
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
data = b"a" * 100000
async with AsyncClient(dispatch=dispatch) as client:
response = await client.post("http://example.org", data=data)
assert response.status_code == 200
assert json.loads(response.content) == {
"method": "POST",
"path": "/",
"body": data.decode(),
}
@pytest.mark.asyncio
async def test_http2_multiple_requests():
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
async with AsyncClient(dispatch=dispatch) as client:
response_1 = await client.get("http://example.org/1")
response_2 = await client.get("http://example.org/2")
response_3 = await client.get("http://example.org/3")
assert response_1.status_code == 200
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
assert response_2.status_code == 200
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
assert response_3.status_code == 200
assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
@pytest.mark.asyncio
async def test_http2_reconnect():
"""
If a connection has been dropped between requests, then we should
be seemlessly reconnected.
"""
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
async with AsyncClient(dispatch=dispatch) as client:
response_1 = await client.get("http://example.org/1")
backend.server.close_connection = True
response_2 = await client.get("http://example.org/2")
assert response_1.status_code == 200
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
assert response_2.status_code == 200
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
@pytest.mark.asyncio
async def test_http2_settings_in_handshake():
backend = MockHTTP2Backend(app=app)
dispatch = ConnectionPool(backend=backend, http2=True)
async with AsyncClient(dispatch=dispatch) as client:
await client.get("http://example.org")
h2_conn = backend.server.conn
assert isinstance(h2_conn, h2.connection.H2Connection)
expected_settings = {
SettingCodes.HEADER_TABLE_SIZE: 4096,
SettingCodes.ENABLE_PUSH: 0,
SettingCodes.MAX_CONCURRENT_STREAMS: 100,
SettingCodes.INITIAL_WINDOW_SIZE: 65535,
SettingCodes.MAX_FRAME_SIZE: 16384,
SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
# This one's here because h2 helpfully populates remote_settings
# with default values even if the peer doesn't send the setting.
SettingCodes.ENABLE_CONNECT_PROTOCOL: 0,
}
assert dict(h2_conn.remote_settings) == expected_settings
# We don't expect the ENABLE_CONNECT_PROTOCOL to be in the handshake
expected_settings.pop(SettingCodes.ENABLE_CONNECT_PROTOCOL)
assert len(backend.server.settings_changed) == 1
settings = backend.server.settings_changed[0]
assert isinstance(settings, h2.events.RemoteSettingsChanged)
assert len(settings.changed_settings) == len(expected_settings)
for setting_code, changed_setting in settings.changed_settings.items():
assert isinstance(changed_setting, h2.settings.ChangedSetting)
assert changed_setting.new_value == expected_settings[setting_code]
@pytest.mark.usefixtures("async_environment")
async def test_http2_live_request():
async with AsyncClient(http2=True) as client:
try:
resp = await client.get("https://nghttp2.org/httpbin/anything")
except TimeoutException: # pragma: nocover
pytest.xfail(reason="nghttp2.org appears to be unresponsive")
except socket.gaierror: # pragma: nocover
pytest.xfail(reason="You appear to be offline")
assert resp.status_code == 200
assert resp.http_version == "HTTP/2"

View File

@ -1,207 +0,0 @@
import pytest
import httpx
from httpx._dispatch.proxy_http import HTTPProxy
from .utils import MockRawSocketBackend
@pytest.mark.asyncio
async def test_proxy_tunnel_success():
raw_io = MockRawSocketBackend(
data_to_send=(
[
b"HTTP/1.1 200 OK\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: proxy-server\r\n"
b"\r\n",
b"HTTP/1.1 404 Not Found\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: origin-server\r\n"
b"\r\n",
]
),
)
async with HTTPProxy(
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
) as proxy:
response = await proxy.request("GET", "http://example.com")
assert response.status_code == 404
assert response.headers["Server"] == "origin-server"
assert response.request.method == "GET"
assert response.request.url == "http://example.com"
assert response.request.headers["Host"] == "example.com"
recv = raw_io.received_data
assert len(recv) == 3
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
assert recv[1].startswith(
b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
)
assert recv[2].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
@pytest.mark.asyncio
@pytest.mark.parametrize("status_code", [300, 304, 308, 401, 500])
async def test_proxy_tunnel_non_2xx_response(status_code):
raw_io = MockRawSocketBackend(
data_to_send=(
[
b"HTTP/1.1 %d Not Good\r\n" % status_code,
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: proxy-server\r\n"
b"\r\n",
]
),
)
with pytest.raises(httpx.ProxyError) as e:
async with HTTPProxy(
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
) as proxy:
await proxy.request("GET", "http://example.com")
# ProxyError.request should be the CONNECT request not the original request
assert e.value.request.method == "CONNECT"
assert e.value.request.headers["Host"] == "127.0.0.1:8000"
assert e.value.request.url.full_path == "example.com:80"
# ProxyError.response should be the CONNECT response
assert e.value.response.status_code == status_code
assert e.value.response.headers["Server"] == "proxy-server"
# Verify that the request wasn't sent after receiving an error from CONNECT
recv = raw_io.received_data
assert len(recv) == 2
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
assert recv[1].startswith(
b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
)
@pytest.mark.asyncio
async def test_proxy_tunnel_start_tls():
raw_io = MockRawSocketBackend(
data_to_send=(
[
# Tunnel Response
b"HTTP/1.1 200 OK\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: proxy-server\r\n"
b"\r\n",
# Response 1
b"HTTP/1.1 404 Not Found\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: origin-server\r\n"
b"Connection: keep-alive\r\n"
b"Content-Length: 0\r\n"
b"\r\n",
# Response 2
b"HTTP/1.1 200 OK\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: origin-server\r\n"
b"Connection: keep-alive\r\n"
b"Content-Length: 0\r\n"
b"\r\n",
]
),
)
async with HTTPProxy(
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
) as proxy:
resp = await proxy.request("GET", "https://example.com")
assert resp.status_code == 404
assert resp.headers["Server"] == "origin-server"
assert resp.request.method == "GET"
assert resp.request.url == "https://example.com"
assert resp.request.headers["Host"] == "example.com"
await resp.aread()
# Make another request to see that the tunnel is re-used.
resp = await proxy.request("GET", "https://example.com/target")
assert resp.status_code == 200
assert resp.headers["Server"] == "origin-server"
assert resp.request.method == "GET"
assert resp.request.url == "https://example.com/target"
assert resp.request.headers["Host"] == "example.com"
await resp.aread()
recv = raw_io.received_data
assert len(recv) == 5
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
assert recv[1].startswith(
b"CONNECT example.com:443 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
)
assert recv[2] == b"--- START_TLS(example.com) ---"
assert recv[3].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
assert recv[4].startswith(b"GET /target HTTP/1.1\r\nhost: example.com\r\n")
@pytest.mark.asyncio
@pytest.mark.parametrize("proxy_mode", ["FORWARD_ONLY", "DEFAULT"])
async def test_proxy_forwarding(proxy_mode):
raw_io = MockRawSocketBackend(
data_to_send=(
[
b"HTTP/1.1 200 OK\r\n"
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
b"Server: origin-server\r\n"
b"\r\n"
]
),
)
async with HTTPProxy(
proxy_url="http://127.0.0.1:8000",
backend=raw_io,
proxy_mode=proxy_mode,
proxy_headers={"Proxy-Authorization": "test", "Override": "2"},
) as proxy:
response = await proxy.request(
"GET", "http://example.com", headers={"override": "1"}
)
assert response.status_code == 200
assert response.headers["Server"] == "origin-server"
assert response.request.method == "GET"
assert response.request.url == "http://127.0.0.1:8000"
assert response.request.url.full_path == "http://example.com"
assert response.request.headers["Host"] == "example.com"
recv = raw_io.received_data
assert len(recv) == 2
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
assert recv[1].startswith(
b"GET http://example.com HTTP/1.1\r\nhost: example.com\r\n"
)
assert b"proxy-authorization: test" in recv[1]
assert b"override: 1" in recv[1]
def test_proxy_url_with_username_and_password():
proxy = HTTPProxy("http://user:password@example.com:1080")
assert proxy.proxy_url == "http://example.com:1080"
assert proxy.proxy_headers["Proxy-Authorization"] == "Basic dXNlcjpwYXNzd29yZA=="
def test_proxy_repr():
proxy = HTTPProxy(
"http://127.0.0.1:1080",
proxy_headers={"Custom": "Header"},
proxy_mode="DEFAULT",
)
assert repr(proxy) == (
"HTTPProxy(proxy_url=URL('http://127.0.0.1:1080') "
"proxy_headers=Headers({'custom': 'Header'}) "
"proxy_mode='DEFAULT')"
)

View File

@ -1,204 +0,0 @@
import ssl
import typing
import h2.config
import h2.connection
import h2.events
from httpx import Request, Timeout
from httpx._backends.base import BaseSocketStream, lookup_backend
class MockHTTP2Backend:
def __init__(self, app):
self.app = app
self.backend = lookup_backend()
self.server = None
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> BaseSocketStream:
self.server = MockHTTP2Server(self.app, backend=self.backend)
return self.server
# Defer all other attributes and methods to the underlying backend.
def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)
class MockHTTP2Server(BaseSocketStream):
def __init__(self, app, backend: MockHTTP2Backend):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.app = app
self.backend = backend
self.buffer = b""
self.requests = {}
self.close_connection = False
self.return_data = {}
self.settings_changed = []
# Socket stream interface
def get_http_version(self) -> str:
return "HTTP/2"
async def read(self, n, timeout, flag=None) -> bytes:
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
async def write(self, data: bytes, timeout) -> None:
if not data:
return
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
if isinstance(event, h2.events.RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, h2.events.DataReceived):
self.receive_data(event.data, event.stream_id)
# This should send an UPDATE_WINDOW for both the stream and the
# connection increasing it by the amount
# consumed keeping the flow control window constant
flow_control_consumed = event.flow_controlled_length
if flow_control_consumed > 0:
self.conn.increment_flow_control_window(flow_control_consumed)
self.buffer += self.conn.data_to_send()
self.conn.increment_flow_control_window(
flow_control_consumed, event.stream_id
)
self.buffer += self.conn.data_to_send()
elif isinstance(event, h2.events.StreamEnded):
await self.stream_complete(event.stream_id)
elif isinstance(event, h2.events.RemoteSettingsChanged):
self.settings_changed.append(event)
async def close(self) -> None:
pass
def is_connection_dropped(self) -> bool:
return self.close_connection
# Server implementation
def request_received(self, headers, stream_id):
"""
Handler for when the initial part of the HTTP request is received.
"""
if stream_id not in self.requests:
self.requests[stream_id] = []
self.requests[stream_id].append({"headers": headers, "data": b""})
def receive_data(self, data, stream_id):
"""
Handler for when a data part of the HTTP request is received.
"""
self.requests[stream_id][-1]["data"] += data
async def stream_complete(self, stream_id):
"""
Handler for when the HTTP request is completed.
"""
request = self.requests[stream_id].pop(0)
if not self.requests[stream_id]:
del self.requests[stream_id]
headers_dict = dict(request["headers"])
method = headers_dict[b":method"].decode("ascii")
url = "%s://%s%s" % (
headers_dict[b":scheme"].decode("ascii"),
headers_dict[b":authority"].decode("ascii"),
headers_dict[b":path"].decode("ascii"),
)
headers = [(k, v) for k, v in request["headers"] if not k.startswith(b":")]
data = request["data"]
# Call out to the app.
request = Request(method, url, headers=headers, data=data)
response = await self.app(request)
# Write the response to the buffer.
status_code_bytes = str(response.status_code).encode("ascii")
response_headers = [(b":status", status_code_bytes)] + response.headers.raw
self.conn.send_headers(stream_id, response_headers)
self.buffer += self.conn.data_to_send()
self.return_data[stream_id] = response.content
self.send_return_data(stream_id)
def send_return_data(self, stream_id):
while self.return_data[stream_id]:
flow_control = self.conn.local_flow_control_window(stream_id)
chunk_size = min(
len(self.return_data[stream_id]),
flow_control,
self.conn.max_outbound_frame_size,
)
if chunk_size > 0:
chunk, self.return_data[stream_id] = (
self.return_data[stream_id][:chunk_size],
self.return_data[stream_id][chunk_size:],
)
self.conn.send_data(stream_id, chunk)
self.buffer += self.conn.data_to_send()
self.conn.end_stream(stream_id)
self.buffer += self.conn.data_to_send()
class MockRawSocketBackend:
def __init__(self, data_to_send=b""):
self.backend = lookup_backend()
self.data_to_send = data_to_send
self.received_data = []
self.stream = MockRawSocketStream(self)
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> BaseSocketStream:
self.received_data.append(
b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
)
return self.stream
# Defer all other attributes and methods to the underlying backend.
def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)
class MockRawSocketStream(BaseSocketStream):
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
) -> BaseSocketStream:
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return MockRawSocketStream(self.backend)
def get_http_version(self) -> str:
return "HTTP/1.1"
async def write(self, data: bytes, timeout) -> None:
if not data:
return
self.backend.received_data.append(data)
async def read(self, n, timeout, flag=None) -> bytes:
if not self.backend.data_to_send:
return b""
return self.backend.data_to_send.pop(0)
def is_connection_dropped(self) -> bool:
return False
async def close(self) -> None:
pass

View File

@ -56,33 +56,6 @@ async def test_start_tls_on_tcp_socket_stream(https_server):
await stream.close()
@pytest.mark.usefixtures("async_environment")
async def test_start_tls_on_uds_socket_stream(https_uds_server):
backend = lookup_backend()
ctx = SSLConfig().load_ssl_context_no_verify()
timeout = Timeout(5)
stream = await backend.open_uds_stream(
https_uds_server.config.uds, https_uds_server.url.host, None, timeout
)
try:
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is None
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
response = await read_response(stream, timeout, should_contain=b"Hello, world")
assert response.startswith(b"HTTP/1.1 200 OK\r\n")
finally:
await stream.close()
@pytest.mark.usefixtures("async_environment")
async def test_concurrent_read(server):
"""

View File

@ -2,27 +2,34 @@ import binascii
import cgi
import io
import os
import typing
from unittest import mock
import httpcore
import pytest
import httpx
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
from httpx._content_streams import encode
from httpx._dispatch.base import AsyncDispatcher
from httpx._content_streams import AsyncIteratorStream, encode
from httpx._utils import format_form_param
class MockDispatch(AsyncDispatcher):
async def send(
class MockDispatch(httpcore.AsyncHTTPTransport):
async def request(
self,
request: httpx.Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> httpx.Response:
content = b"".join([part async for part in request.stream])
return httpx.Response(200, content=content, request=request)
method: bytes,
url: typing.Tuple[bytes, bytes, int, bytes],
headers: typing.List[typing.Tuple[bytes, bytes]] = None,
stream: httpcore.AsyncByteStream = None,
timeout: typing.Dict[str, typing.Optional[float]] = None,
) -> typing.Tuple[
bytes,
int,
bytes,
typing.List[typing.Tuple[bytes, bytes]],
httpcore.AsyncByteStream,
]:
content = AsyncIteratorStream(aiterator=(part async for part in stream))
return b"HTTP/1.1", 200, b"OK", [], content
@pytest.mark.parametrize(("value,output"), (("abc", b"abc"), (b"abc", b"abc")))

View File

@ -98,7 +98,6 @@ async def test_logs_debug(server, capsys):
assert response.status_code == 200
stderr = capsys.readouterr().err
assert 'HTTP Request: GET http://127.0.0.1:8000/ "HTTP/1.1 200 OK"' in stderr
assert "httpx._dispatch.connection_pool" not in stderr
@pytest.mark.asyncio
@ -109,7 +108,6 @@ async def test_logs_trace(server, capsys):
assert response.status_code == 200
stderr = capsys.readouterr().err
assert 'HTTP Request: GET http://127.0.0.1:8000/ "HTTP/1.1 200 OK"' in stderr
assert "httpx._dispatch.connection_pool" in stderr
@pytest.mark.asyncio