Httpcore interface (#804)
* First pass as switching dispatchers over to httpcore interface * Updates for httpcore interface * headers in dispatch API as plain list of bytes * Integrate against httpcore 0.6 * Integrate against httpcore interface * Drop UDS, since not supported by httpcore * Fix base class for mock dispatchers in tests * Merge master and mark as potential '0.13.dev0' release
This commit is contained in:
parent
631ba97635
commit
3046e920ea
@ -6,7 +6,6 @@ from ._config import PoolLimits, Proxy, Timeout
|
||||
from ._dispatch.asgi import ASGIDispatch
|
||||
from ._dispatch.wsgi import WSGIDispatch
|
||||
from ._exceptions import (
|
||||
ConnectionClosed,
|
||||
ConnectTimeout,
|
||||
CookieConflict,
|
||||
DecodingError,
|
||||
@ -23,7 +22,6 @@ from ._exceptions import (
|
||||
ResponseClosed,
|
||||
ResponseNotRead,
|
||||
StreamConsumed,
|
||||
TimeoutException,
|
||||
TooManyRedirects,
|
||||
WriteTimeout,
|
||||
)
|
||||
@ -56,7 +54,6 @@ __all__ = [
|
||||
"Timeout",
|
||||
"ConnectTimeout",
|
||||
"CookieConflict",
|
||||
"ConnectionClosed",
|
||||
"DecodingError",
|
||||
"HTTPError",
|
||||
"InvalidURL",
|
||||
@ -79,7 +76,6 @@ __all__ = [
|
||||
"Headers",
|
||||
"QueryParams",
|
||||
"Request",
|
||||
"TimeoutException",
|
||||
"Response",
|
||||
"DigestAuth",
|
||||
"WSGIDispatch",
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
__title__ = "httpx"
|
||||
__description__ = "A next generation HTTP client, for Python 3."
|
||||
__version__ = "0.12.1"
|
||||
__version__ = "0.13.dev0"
|
||||
|
||||
157
httpx/_client.py
157
httpx/_client.py
@ -3,6 +3,7 @@ import typing
|
||||
from types import TracebackType
|
||||
|
||||
import hstspreload
|
||||
import httpcore
|
||||
|
||||
from ._auth import Auth, AuthTypes, BasicAuth, FunctionAuth
|
||||
from ._config import (
|
||||
@ -14,6 +15,7 @@ from ._config import (
|
||||
PoolLimits,
|
||||
ProxiesTypes,
|
||||
Proxy,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
TimeoutTypes,
|
||||
UnsetType,
|
||||
@ -21,10 +23,6 @@ from ._config import (
|
||||
)
|
||||
from ._content_streams import ContentStream
|
||||
from ._dispatch.asgi import ASGIDispatch
|
||||
from ._dispatch.base import AsyncDispatcher, SyncDispatcher
|
||||
from ._dispatch.connection_pool import ConnectionPool
|
||||
from ._dispatch.proxy_http import HTTPProxy
|
||||
from ._dispatch.urllib3 import URLLib3Dispatcher
|
||||
from ._dispatch.wsgi import WSGIDispatch
|
||||
from ._exceptions import HTTPError, InvalidURL, RequestBodyUnavailable, TooManyRedirects
|
||||
from ._models import (
|
||||
@ -96,7 +94,7 @@ class BaseClient:
|
||||
elif isinstance(proxies, (str, URL, Proxy)):
|
||||
proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies
|
||||
return {"all": proxy}
|
||||
elif isinstance(proxies, AsyncDispatcher): # pragma: nocover
|
||||
elif isinstance(proxies, httpcore.AsyncHTTPTransport): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Passing a dispatcher instance to 'proxies=' is no longer "
|
||||
"supported. Use `httpx.Proxy() instead.`"
|
||||
@ -107,7 +105,7 @@ class BaseClient:
|
||||
if isinstance(value, (str, URL, Proxy)):
|
||||
proxy = Proxy(url=value) if isinstance(value, (str, URL)) else value
|
||||
new_proxies[str(key)] = proxy
|
||||
elif isinstance(value, AsyncDispatcher): # pragma: nocover
|
||||
elif isinstance(value, httpcore.AsyncHTTPTransport): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Passing a dispatcher instance to 'proxies=' is "
|
||||
"no longer supported. Use `httpx.Proxy() instead.`"
|
||||
@ -446,7 +444,7 @@ class Client(BaseClient):
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
base_url: URLTypes = None,
|
||||
dispatch: SyncDispatcher = None,
|
||||
dispatch: httpcore.SyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
trust_env: bool = True,
|
||||
):
|
||||
@ -471,7 +469,7 @@ class Client(BaseClient):
|
||||
app=app,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
self.proxies: typing.Dict[str, SyncDispatcher] = {
|
||||
self.proxies: typing.Dict[str, httpcore.SyncHTTPTransport] = {
|
||||
key: self.init_proxy_dispatch(
|
||||
proxy,
|
||||
verify=verify,
|
||||
@ -487,18 +485,26 @@ class Client(BaseClient):
|
||||
verify: VerifyTypes = True,
|
||||
cert: CertTypes = None,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
dispatch: SyncDispatcher = None,
|
||||
dispatch: httpcore.SyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
trust_env: bool = True,
|
||||
) -> SyncDispatcher:
|
||||
) -> httpcore.SyncHTTPTransport:
|
||||
if dispatch is not None:
|
||||
return dispatch
|
||||
|
||||
if app is not None:
|
||||
return WSGIDispatch(app=app)
|
||||
|
||||
return URLLib3Dispatcher(
|
||||
verify=verify, cert=cert, pool_limits=pool_limits, trust_env=trust_env,
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
max_keepalive = pool_limits.soft_limit
|
||||
max_connections = pool_limits.hard_limit
|
||||
|
||||
return httpcore.SyncConnectionPool(
|
||||
ssl_context=ssl_context,
|
||||
max_keepalive=max_keepalive,
|
||||
max_connections=max_connections,
|
||||
)
|
||||
|
||||
def init_proxy_dispatch(
|
||||
@ -508,18 +514,25 @@ class Client(BaseClient):
|
||||
cert: CertTypes = None,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
trust_env: bool = True,
|
||||
) -> SyncDispatcher:
|
||||
return URLLib3Dispatcher(
|
||||
proxy=proxy,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
pool_limits=pool_limits,
|
||||
trust_env=trust_env,
|
||||
) -> httpcore.SyncHTTPTransport:
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
max_keepalive = pool_limits.soft_limit
|
||||
max_connections = pool_limits.hard_limit
|
||||
|
||||
return httpcore.SyncHTTPProxy(
|
||||
proxy_origin=proxy.url.raw[:3],
|
||||
proxy_headers=proxy.headers.raw,
|
||||
proxy_mode=proxy.mode,
|
||||
ssl_context=ssl_context,
|
||||
max_keepalive=max_keepalive,
|
||||
max_connections=max_connections,
|
||||
)
|
||||
|
||||
def dispatcher_for_url(self, url: URL) -> SyncDispatcher:
|
||||
def dispatcher_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
|
||||
"""
|
||||
Returns the SyncDispatcher instance that should be used for a given URL.
|
||||
Returns the transport instance that should be used for a given URL.
|
||||
This will either be the standard connection pool, or a proxy.
|
||||
"""
|
||||
if self.proxies and not should_not_be_proxied(url):
|
||||
@ -667,7 +680,7 @@ class Client(BaseClient):
|
||||
request = next_request
|
||||
history.append(response)
|
||||
|
||||
def send_single_request(self, request: Request, timeout: Timeout,) -> Response:
|
||||
def send_single_request(self, request: Request, timeout: Timeout) -> Response:
|
||||
"""
|
||||
Sends a single request, without handling any redirections.
|
||||
"""
|
||||
@ -675,7 +688,19 @@ class Client(BaseClient):
|
||||
dispatcher = self.dispatcher_for_url(request.url)
|
||||
|
||||
try:
|
||||
response = dispatcher.send(request, timeout=timeout)
|
||||
(
|
||||
http_version,
|
||||
status_code,
|
||||
reason_phrase,
|
||||
headers,
|
||||
stream,
|
||||
) = dispatcher.request(
|
||||
request.method.encode(),
|
||||
request.url.raw,
|
||||
headers=request.headers.raw,
|
||||
stream=request.stream,
|
||||
timeout=timeout.as_dict(),
|
||||
)
|
||||
except HTTPError as exc:
|
||||
# Add the original request to any HTTPError unless
|
||||
# there'a already a request attached in the case of
|
||||
@ -683,6 +708,13 @@ class Client(BaseClient):
|
||||
if exc._request is None:
|
||||
exc._request = request
|
||||
raise
|
||||
response = Response(
|
||||
status_code,
|
||||
http_version=http_version.decode("ascii"),
|
||||
headers=headers,
|
||||
stream=stream, # type: ignore
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
|
||||
@ -928,7 +960,6 @@ class AsyncClient(BaseClient):
|
||||
rather than sending actual network requests.
|
||||
* **trust_env** - *(optional)* Enables or disables usage of environment
|
||||
variables for configuration.
|
||||
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -946,10 +977,9 @@ class AsyncClient(BaseClient):
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
base_url: URLTypes = None,
|
||||
dispatch: AsyncDispatcher = None,
|
||||
dispatch: httpcore.AsyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
trust_env: bool = True,
|
||||
uds: str = None,
|
||||
):
|
||||
super().__init__(
|
||||
auth=auth,
|
||||
@ -972,9 +1002,8 @@ class AsyncClient(BaseClient):
|
||||
dispatch=dispatch,
|
||||
app=app,
|
||||
trust_env=trust_env,
|
||||
uds=uds,
|
||||
)
|
||||
self.proxies: typing.Dict[str, AsyncDispatcher] = {
|
||||
self.proxies: typing.Dict[str, httpcore.AsyncHTTPTransport] = {
|
||||
key: self.init_proxy_dispatch(
|
||||
proxy,
|
||||
verify=verify,
|
||||
@ -992,24 +1021,27 @@ class AsyncClient(BaseClient):
|
||||
cert: CertTypes = None,
|
||||
http2: bool = False,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
dispatch: AsyncDispatcher = None,
|
||||
dispatch: httpcore.AsyncHTTPTransport = None,
|
||||
app: typing.Callable = None,
|
||||
trust_env: bool = True,
|
||||
uds: str = None,
|
||||
) -> AsyncDispatcher:
|
||||
) -> httpcore.AsyncHTTPTransport:
|
||||
if dispatch is not None:
|
||||
return dispatch
|
||||
|
||||
if app is not None:
|
||||
return ASGIDispatch(app=app)
|
||||
|
||||
return ConnectionPool(
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
max_keepalive = pool_limits.soft_limit
|
||||
max_connections = pool_limits.hard_limit
|
||||
|
||||
return httpcore.AsyncConnectionPool(
|
||||
ssl_context=ssl_context,
|
||||
max_keepalive=max_keepalive,
|
||||
max_connections=max_connections,
|
||||
http2=http2,
|
||||
pool_limits=pool_limits,
|
||||
trust_env=trust_env,
|
||||
uds=uds,
|
||||
)
|
||||
|
||||
def init_proxy_dispatch(
|
||||
@ -1020,21 +1052,25 @@ class AsyncClient(BaseClient):
|
||||
http2: bool = False,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
trust_env: bool = True,
|
||||
) -> AsyncDispatcher:
|
||||
return HTTPProxy(
|
||||
proxy_url=proxy.url,
|
||||
proxy_headers=proxy.headers,
|
||||
) -> httpcore.AsyncHTTPTransport:
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
max_keepalive = pool_limits.soft_limit
|
||||
max_connections = pool_limits.hard_limit
|
||||
|
||||
return httpcore.AsyncHTTPProxy(
|
||||
proxy_origin=proxy.url.raw[:3],
|
||||
proxy_headers=proxy.headers.raw,
|
||||
proxy_mode=proxy.mode,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
http2=http2,
|
||||
pool_limits=pool_limits,
|
||||
trust_env=trust_env,
|
||||
ssl_context=ssl_context,
|
||||
max_keepalive=max_keepalive,
|
||||
max_connections=max_connections,
|
||||
)
|
||||
|
||||
def dispatcher_for_url(self, url: URL) -> AsyncDispatcher:
|
||||
def dispatcher_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
|
||||
"""
|
||||
Returns the AsyncDispatcher instance that should be used for a given URL.
|
||||
Returns the transport instance that should be used for a given URL.
|
||||
This will either be the standard connection pool, or a proxy.
|
||||
"""
|
||||
if self.proxies and not should_not_be_proxied(url):
|
||||
@ -1193,7 +1229,19 @@ class AsyncClient(BaseClient):
|
||||
dispatcher = self.dispatcher_for_url(request.url)
|
||||
|
||||
try:
|
||||
response = await dispatcher.send(request, timeout=timeout)
|
||||
(
|
||||
http_version,
|
||||
status_code,
|
||||
reason_phrase,
|
||||
headers,
|
||||
stream,
|
||||
) = await dispatcher.request(
|
||||
request.method.encode(),
|
||||
request.url.raw,
|
||||
headers=request.headers.raw,
|
||||
stream=request.stream,
|
||||
timeout=timeout.as_dict(),
|
||||
)
|
||||
except HTTPError as exc:
|
||||
# Add the original request to any HTTPError unless
|
||||
# there'a already a request attached in the case of
|
||||
@ -1201,6 +1249,13 @@ class AsyncClient(BaseClient):
|
||||
if exc._request is None:
|
||||
exc._request = request
|
||||
raise
|
||||
response = Response(
|
||||
status_code,
|
||||
http_version=http_version.decode("ascii"),
|
||||
headers=headers,
|
||||
stream=stream, # type: ignore
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
|
||||
@ -1383,9 +1438,9 @@ class AsyncClient(BaseClient):
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.dispatch.close()
|
||||
await self.dispatch.aclose()
|
||||
for proxy in self.proxies.values():
|
||||
await proxy.close()
|
||||
await proxy.aclose()
|
||||
|
||||
async def __aenter__(self) -> "AsyncClient":
|
||||
return self
|
||||
|
||||
@ -253,6 +253,14 @@ class Timeout:
|
||||
timeout if isinstance(pool_timeout, UnsetType) else pool_timeout
|
||||
)
|
||||
|
||||
def as_dict(self) -> typing.Dict[str, typing.Optional[float]]:
|
||||
return {
|
||||
"connect": self.connect_timeout,
|
||||
"read": self.read_timeout,
|
||||
"write": self.write_timeout,
|
||||
"pool": self.pool_timeout,
|
||||
}
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
|
||||
@ -7,6 +7,8 @@ from json import dumps as json_dumps
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpcore
|
||||
|
||||
from ._exceptions import StreamConsumed
|
||||
from ._types import StrOrBytes
|
||||
from ._utils import format_form_param
|
||||
@ -35,7 +37,7 @@ RequestFiles = typing.Dict[
|
||||
]
|
||||
|
||||
|
||||
class ContentStream:
|
||||
class ContentStream(httpcore.AsyncByteStream, httpcore.SyncByteStream):
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
"""
|
||||
Return a dictionary of headers that are implied by the encoding.
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
import typing
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import httpcore
|
||||
|
||||
from .._config import TimeoutTypes
|
||||
from .._content_streams import ByteStream
|
||||
from .._models import Request, Response
|
||||
from .base import AsyncDispatcher
|
||||
|
||||
|
||||
class ASGIDispatch(AsyncDispatcher):
|
||||
class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
"""
|
||||
A custom AsyncDispatcher that handles sending requests directly to an ASGI app.
|
||||
The simplest way to use this functionality is to use the `app` argument.
|
||||
@ -41,37 +40,49 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Callable,
|
||||
app: Callable,
|
||||
raise_app_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
|
||||
client: Tuple[str, int] = ("127.0.0.1", 123),
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self.root_path = root_path
|
||||
self.client = client
|
||||
|
||||
async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
|
||||
async def request(
|
||||
self,
|
||||
method: bytes,
|
||||
url: Tuple[bytes, bytes, int, bytes],
|
||||
headers: List[Tuple[bytes, bytes]] = None,
|
||||
stream: httpcore.AsyncByteStream = None,
|
||||
timeout: Dict[str, Optional[float]] = None,
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
|
||||
scheme, host, port, full_path = url
|
||||
path, _, query = full_path.partition(b"?")
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"headers": request.headers.raw,
|
||||
"scheme": request.url.scheme,
|
||||
"path": request.url.path,
|
||||
"query_string": request.url.query.encode("ascii"),
|
||||
"server": request.url.host,
|
||||
"method": method.decode(),
|
||||
"headers": headers,
|
||||
"scheme": scheme.decode("ascii"),
|
||||
"path": path.decode("ascii"),
|
||||
"query_string": query,
|
||||
"server": (host.decode("ascii"), port),
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
}
|
||||
status_code = None
|
||||
headers = None
|
||||
response_headers = None
|
||||
body_parts = []
|
||||
response_started = False
|
||||
response_complete = False
|
||||
|
||||
request_body_chunks = request.stream.__aiter__()
|
||||
headers = [] if headers is None else headers
|
||||
stream = ByteStream(b"") if stream is None else stream
|
||||
|
||||
request_body_chunks = stream.__aiter__()
|
||||
|
||||
async def receive() -> dict:
|
||||
try:
|
||||
@ -81,14 +92,14 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
return {"type": "http.request", "body": body, "more_body": True}
|
||||
|
||||
async def send(message: dict) -> None:
|
||||
nonlocal status_code, headers, body_parts
|
||||
nonlocal status_code, response_headers, body_parts
|
||||
nonlocal response_started, response_complete
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert not response_started
|
||||
|
||||
status_code = message["status"]
|
||||
headers = message.get("headers", [])
|
||||
response_headers = message.get("headers", [])
|
||||
response_started = True
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
@ -96,7 +107,7 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
if body and request.method != "HEAD":
|
||||
if body and method != b"HEAD":
|
||||
body_parts.append(body)
|
||||
|
||||
if not more_body:
|
||||
@ -110,14 +121,8 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
|
||||
assert response_complete
|
||||
assert status_code is not None
|
||||
assert headers is not None
|
||||
assert response_headers is not None
|
||||
|
||||
stream = ByteStream(b"".join(body_parts))
|
||||
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
http_version="HTTP/1.1",
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
request=request,
|
||||
)
|
||||
return (b"HTTP/1.1", status_code, b"", response_headers, stream)
|
||||
|
||||
@ -1,64 +0,0 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .._config import Timeout
|
||||
from .._models import (
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
Request,
|
||||
RequestData,
|
||||
Response,
|
||||
URLTypes,
|
||||
)
|
||||
|
||||
|
||||
class SyncDispatcher:
|
||||
"""
|
||||
Base class for Dispatcher classes, that handle sending the request.
|
||||
"""
|
||||
|
||||
def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def close(self) -> None:
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
class AsyncDispatcher:
|
||||
"""
|
||||
Base class for AsyncDispatcher classes, that handle sending the request.
|
||||
|
||||
Stubs out the interface, as well as providing a `.request()` convenience
|
||||
implementation, to make it easy to use or test stand-alone AsyncDispatchers,
|
||||
without requiring a complete `AsyncClient` instance.
|
||||
"""
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
request = Request(method, url, data=data, params=params, headers=headers)
|
||||
return await self.send(request, timeout=timeout)
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
pass # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "AsyncDispatcher":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
@ -1,149 +0,0 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import h11
|
||||
|
||||
from .._backends.base import ConcurrencyBackend, lookup_backend
|
||||
from .._config import SSLConfig, Timeout
|
||||
from .._models import URL, Origin, Request, Response
|
||||
from .._utils import get_logger
|
||||
from .base import AsyncDispatcher
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
# Callback signature: async def callback(conn: HTTPConnection) -> None
|
||||
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HTTPConnection(AsyncDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
release_func: typing.Optional[ReleaseCallback] = None,
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = SSLConfig() if ssl is None else ssl
|
||||
self.backend = lookup_backend(backend)
|
||||
self.release_func = release_func
|
||||
self.uds = uds
|
||||
self.connection: typing.Union[None, HTTP11Connection, HTTP2Connection] = None
|
||||
self.expires_at: typing.Optional[float] = None
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
|
||||
if self.connection is None:
|
||||
self.connection = await self.connect(timeout=timeout)
|
||||
|
||||
return await self.connection.send(request, timeout=timeout)
|
||||
|
||||
async def connect(
|
||||
self, timeout: Timeout
|
||||
) -> typing.Union[HTTP11Connection, HTTP2Connection]:
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = None if not self.origin.is_ssl else self.ssl.ssl_context
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
else:
|
||||
on_release = functools.partial(self.release_func, self)
|
||||
|
||||
if self.uds is None:
|
||||
logger.trace(
|
||||
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
|
||||
)
|
||||
socket = await self.backend.open_tcp_stream(
|
||||
host, port, ssl_context, timeout
|
||||
)
|
||||
else:
|
||||
logger.trace(
|
||||
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
|
||||
)
|
||||
socket = await self.backend.open_uds_stream(
|
||||
self.uds, host, ssl_context, timeout
|
||||
)
|
||||
|
||||
http_version = socket.get_http_version()
|
||||
logger.trace(f"connected http_version={http_version!r}")
|
||||
|
||||
if http_version == "HTTP/2":
|
||||
return HTTP2Connection(socket, self.backend, on_release=on_release)
|
||||
return HTTP11Connection(socket, on_release=on_release)
|
||||
|
||||
async def tunnel_start_tls(
|
||||
self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
|
||||
"""
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
|
||||
# First, check that we are in the correct state to start TLS, i.e. we've
|
||||
# just agreed to switch protocols with the server via HTTP/1.1.
|
||||
assert isinstance(self.connection, HTTP11Connection)
|
||||
h11_connection = self.connection
|
||||
assert h11_connection is not None
|
||||
assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
|
||||
|
||||
# Store this information here so that we can transfer
|
||||
# it to the new internal connection object after
|
||||
# the old one goes to 'SWITCHED_PROTOCOL'.
|
||||
# Note that the negotiated 'http_version' may change after the TLS upgrade.
|
||||
http_version = "HTTP/1.1"
|
||||
socket = h11_connection.socket
|
||||
on_release = h11_connection.on_release
|
||||
|
||||
if origin.is_ssl:
|
||||
# Pull the socket stream off the internal HTTP connection object,
|
||||
# and run start_tls().
|
||||
ssl_context = self.ssl.ssl_context
|
||||
|
||||
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
|
||||
socket = await socket.start_tls(
|
||||
hostname=origin.host, ssl_context=ssl_context, timeout=timeout
|
||||
)
|
||||
http_version = socket.get_http_version()
|
||||
logger.trace(
|
||||
f"tunnel_tls_complete "
|
||||
f"proxy_url={proxy_url!r} "
|
||||
f"origin={origin!r} "
|
||||
f"http_version={http_version!r}"
|
||||
)
|
||||
else:
|
||||
# User requested the use of a tunnel, but they're performing a plain-text
|
||||
# HTTP request. Don't try to upgrade to TLS in this case.
|
||||
pass
|
||||
|
||||
if http_version == "HTTP/2":
|
||||
self.connection = HTTP2Connection(
|
||||
socket, self.backend, on_release=on_release
|
||||
)
|
||||
else:
|
||||
self.connection = HTTP11Connection(socket, on_release=on_release)
|
||||
|
||||
async def close(self) -> None:
|
||||
logger.trace("close_connection")
|
||||
if self.connection is not None:
|
||||
await self.connection.close()
|
||||
|
||||
@property
|
||||
def is_http2(self) -> bool:
|
||||
return self.connection is not None and self.connection.is_http2
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.connection is not None and self.connection.is_closed
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.connection is not None and self.connection.is_connection_dropped()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(origin={self.origin!r})"
|
||||
@ -1,221 +0,0 @@
|
||||
import typing
|
||||
|
||||
from .._backends.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
|
||||
from .._config import (
|
||||
DEFAULT_POOL_LIMITS,
|
||||
CertTypes,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
VerifyTypes,
|
||||
)
|
||||
from .._exceptions import PoolTimeout
|
||||
from .._models import Origin, Request, Response
|
||||
from .._utils import get_logger
|
||||
from .base import AsyncDispatcher
|
||||
from .connection import HTTPConnection
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class NullSemaphore(BaseSemaphore):
|
||||
async def acquire(self, timeout: float = None) -> None:
|
||||
return
|
||||
|
||||
def release(self) -> None:
|
||||
return
|
||||
|
||||
|
||||
class ConnectionStore:
|
||||
"""
|
||||
We need to maintain collections of connections in a way that allows us to:
|
||||
|
||||
* Lookup connections by origin.
|
||||
* Iterate over connections by insertion time.
|
||||
* Return the total number of connections.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all: typing.Dict[HTTPConnection, float] = {}
|
||||
self.by_origin: typing.Dict[Origin, typing.Dict[HTTPConnection, float]] = {}
|
||||
|
||||
def pop_by_origin(
|
||||
self, origin: Origin, http2_only: bool = False
|
||||
) -> typing.Optional[HTTPConnection]:
|
||||
try:
|
||||
connections = self.by_origin[origin]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
connection = next(reversed(list(connections.keys())))
|
||||
if http2_only and not connection.is_http2:
|
||||
return None
|
||||
|
||||
del connections[connection]
|
||||
if not connections:
|
||||
del self.by_origin[origin]
|
||||
del self.all[connection]
|
||||
|
||||
return connection
|
||||
|
||||
def add(self, connection: HTTPConnection) -> None:
|
||||
self.all[connection] = 0.0
|
||||
try:
|
||||
self.by_origin[connection.origin][connection] = 0.0
|
||||
except KeyError:
|
||||
self.by_origin[connection.origin] = {connection: 0.0}
|
||||
|
||||
def remove(self, connection: HTTPConnection) -> None:
|
||||
del self.all[connection]
|
||||
del self.by_origin[connection.origin][connection]
|
||||
if not self.by_origin[connection.origin]:
|
||||
del self.by_origin[connection.origin]
|
||||
|
||||
def clear(self) -> None:
|
||||
self.all.clear()
|
||||
self.by_origin.clear()
|
||||
|
||||
def __iter__(self) -> typing.Iterator[HTTPConnection]:
|
||||
return iter(self.all.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.all)
|
||||
|
||||
|
||||
class ConnectionPool(AsyncDispatcher):
|
||||
KEEP_ALIVE_EXPIRY = 5.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
verify: VerifyTypes = True,
|
||||
cert: CertTypes = None,
|
||||
trust_env: bool = None,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
http2: bool = False,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
self.ssl = SSLConfig(verify=verify, cert=cert, trust_env=trust_env, http2=http2)
|
||||
self.pool_limits = pool_limits
|
||||
self.is_closed = False
|
||||
self.uds = uds
|
||||
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
self.active_connections = ConnectionStore()
|
||||
|
||||
self.backend = lookup_backend(backend)
|
||||
self.next_keepalive_check = 0.0
|
||||
|
||||
@property
|
||||
def max_connections(self) -> BaseSemaphore:
|
||||
# We do this lazily, to make sure backend autodetection always
|
||||
# runs within an async context.
|
||||
if not hasattr(self, "_max_connections"):
|
||||
limit = self.pool_limits.hard_limit
|
||||
if limit:
|
||||
self._max_connections = self.backend.create_semaphore(
|
||||
limit, exc_class=PoolTimeout
|
||||
)
|
||||
else:
|
||||
self._max_connections = NullSemaphore()
|
||||
|
||||
return self._max_connections
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
return len(self.keepalive_connections) + len(self.active_connections)
|
||||
|
||||
async def check_keepalive_expiry(self) -> None:
|
||||
now = self.backend.time()
|
||||
if now < self.next_keepalive_check:
|
||||
return
|
||||
self.next_keepalive_check = now + 1.0
|
||||
|
||||
# Iterate through all the keep alive connections.
|
||||
# We create a list here to avoid any 'changed during iteration' errors.
|
||||
keepalives = list(self.keepalive_connections.all.keys())
|
||||
for connection in keepalives:
|
||||
if connection.expires_at is not None and now > connection.expires_at:
|
||||
self.keepalive_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
await self.check_keepalive_expiry()
|
||||
connection = await self.acquire_connection(
|
||||
origin=Origin(request.url), timeout=timeout
|
||||
)
|
||||
try:
|
||||
response = await connection.send(request, timeout=timeout)
|
||||
except BaseException as exc:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
raise exc
|
||||
|
||||
return response
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: Timeout = None
|
||||
) -> HTTPConnection:
|
||||
logger.trace(f"acquire_connection origin={origin!r}")
|
||||
connection = self.pop_connection(origin)
|
||||
|
||||
if connection is None:
|
||||
pool_timeout = None if timeout is None else timeout.pool_timeout
|
||||
|
||||
await self.max_connections.acquire(timeout=pool_timeout)
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
backend=self.backend,
|
||||
release_func=self.release_connection,
|
||||
uds=self.uds,
|
||||
)
|
||||
logger.trace(f"new_connection connection={connection!r}")
|
||||
else:
|
||||
logger.trace(f"reuse_connection connection={connection!r}")
|
||||
|
||||
self.active_connections.add(connection)
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, connection: HTTPConnection) -> None:
|
||||
logger.trace(f"release_connection connection={connection!r}")
|
||||
if connection.is_closed:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
elif (
|
||||
self.pool_limits.soft_limit is not None
|
||||
and self.num_connections > self.pool_limits.soft_limit
|
||||
):
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
else:
|
||||
now = self.backend.time()
|
||||
connection.expires_at = now + self.KEEP_ALIVE_EXPIRY
|
||||
self.active_connections.remove(connection)
|
||||
self.keepalive_connections.add(connection)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
connections = list(self.keepalive_connections)
|
||||
self.keepalive_connections.clear()
|
||||
for connection in connections:
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
|
||||
def pop_connection(self, origin: Origin) -> typing.Optional[HTTPConnection]:
|
||||
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
|
||||
if connection is None:
|
||||
connection = self.keepalive_connections.pop_by_origin(origin)
|
||||
|
||||
if connection is not None and connection.is_connection_dropped():
|
||||
self.max_connections.release()
|
||||
connection = None
|
||||
|
||||
return connection
|
||||
@ -1,206 +0,0 @@
|
||||
import typing
|
||||
|
||||
import h11
|
||||
|
||||
from .._backends.base import BaseSocketStream
|
||||
from .._config import Timeout
|
||||
from .._content_streams import AsyncIteratorStream
|
||||
from .._exceptions import ConnectionClosed, ProtocolError
|
||||
from .._models import Request, Response
|
||||
from .._utils import get_logger
|
||||
|
||||
H11Event = typing.Union[
|
||||
h11.Request,
|
||||
h11.Response,
|
||||
h11.InformationalResponse,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
h11.ConnectionClosed,
|
||||
]
|
||||
|
||||
|
||||
# Callback signature: async def callback() -> None
|
||||
# In practice the callback will be a functools partial, which binds
|
||||
# the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
|
||||
OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HTTP11Connection:
|
||||
READ_NUM_BYTES = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: BaseSocketStream,
|
||||
on_release: typing.Optional[OnReleaseCallback] = None,
|
||||
):
|
||||
self.socket = socket
|
||||
self.on_release = on_release
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
@property
|
||||
def is_http2(self) -> bool:
|
||||
return False
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
|
||||
await self._send_request(request, timeout)
|
||||
await self._send_request_body(request, timeout)
|
||||
http_version, status_code, headers = await self._receive_response(timeout)
|
||||
stream = AsyncIteratorStream(
|
||||
aiterator=self._receive_response_data(timeout),
|
||||
close_func=self.response_closed,
|
||||
)
|
||||
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
http_version=http_version,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
logger.trace(f"send_event event={event!r}")
|
||||
self.h11_state.send(event)
|
||||
except h11.LocalProtocolError: # pragma: no cover
|
||||
# Premature client disconnect
|
||||
pass
|
||||
await self.socket.close()
|
||||
|
||||
async def _send_request(self, request: Request, timeout: Timeout) -> None:
|
||||
"""
|
||||
Send the request method, URL, and headers to the network.
|
||||
"""
|
||||
logger.trace(
|
||||
f"send_headers method={request.method!r} "
|
||||
f"target={request.url.full_path!r} "
|
||||
f"headers={request.headers!r}"
|
||||
)
|
||||
|
||||
method = request.method.encode("ascii")
|
||||
target = request.url.full_path.encode("ascii")
|
||||
headers = request.headers.raw
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
async def _send_request_body(self, request: Request, timeout: Timeout) -> None:
|
||||
"""
|
||||
Send the request body to the network.
|
||||
"""
|
||||
try:
|
||||
# Send the request body.
|
||||
async for chunk in request.stream:
|
||||
logger.trace(f"send_data data=Data(<{len(chunk)} bytes>)")
|
||||
event = h11.Data(data=chunk)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
# Finalize sending the request.
|
||||
event = h11.EndOfMessage()
|
||||
await self._send_event(event, timeout)
|
||||
except OSError: # pragma: nocover
|
||||
# Once we've sent the initial part of the request we don't actually
|
||||
# care about connection errors that occur when sending the body.
|
||||
# Ignore these, and defer to any exceptions on reading the response.
|
||||
self.h11_state.send_failed()
|
||||
|
||||
async def _send_event(self, event: H11Event, timeout: Timeout) -> None:
|
||||
"""
|
||||
Send a single `h11` event to the network, waiting for the data to
|
||||
drain before returning.
|
||||
"""
|
||||
bytes_to_send = self.h11_state.send(event)
|
||||
await self.socket.write(bytes_to_send, timeout)
|
||||
|
||||
async def _receive_response(
|
||||
self, timeout: Timeout
|
||||
) -> typing.Tuple[str, int, typing.List[typing.Tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Read the response status and headers from the network.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
continue
|
||||
else:
|
||||
assert isinstance(event, h11.Response)
|
||||
break # pragma: no cover
|
||||
http_version = "HTTP/%s" % event.http_version.decode("latin-1", errors="ignore")
|
||||
return http_version, event.status_code, event.headers
|
||||
|
||||
async def _receive_response_data(
|
||||
self, timeout: Timeout
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
Read the response data from the network.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h11.Data):
|
||||
yield bytes(event.data)
|
||||
else:
|
||||
assert isinstance(event, h11.EndOfMessage) or event is h11.PAUSED
|
||||
break # pragma: no cover
|
||||
|
||||
async def _receive_event(self, timeout: Timeout) -> H11Event:
|
||||
"""
|
||||
Read a single `h11` event, reading more data from the network if needed.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
event = self.h11_state.next_event()
|
||||
except h11.RemoteProtocolError as e:
|
||||
logger.debug(
|
||||
"h11.RemoteProtocolError exception "
|
||||
+ f"their_state={self.h11_state.their_state} "
|
||||
+ f"error_status_hint={e.error_status_hint}"
|
||||
)
|
||||
if self.socket.is_connection_dropped():
|
||||
raise ConnectionClosed(e)
|
||||
raise ProtocolError(e)
|
||||
|
||||
if isinstance(event, h11.Data):
|
||||
logger.trace(f"receive_event event=Data(<{len(event.data)} bytes>)")
|
||||
else:
|
||||
logger.trace(f"receive_event event={event!r}")
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
try:
|
||||
data = await self.socket.read(self.READ_NUM_BYTES, timeout)
|
||||
except OSError: # pragma: nocover
|
||||
data = b""
|
||||
self.h11_state.receive_data(data)
|
||||
else:
|
||||
assert event is not h11.NEED_DATA
|
||||
break # pragma: no cover
|
||||
return event
|
||||
|
||||
async def response_closed(self) -> None:
|
||||
logger.trace(
|
||||
f"response_closed "
|
||||
f"our_state={self.h11_state.our_state!r} "
|
||||
f"their_state={self.h11_state.their_state}"
|
||||
)
|
||||
if (
|
||||
self.h11_state.our_state is h11.DONE
|
||||
and self.h11_state.their_state is h11.DONE
|
||||
):
|
||||
# Get ready for another request/response cycle.
|
||||
self.h11_state.start_next_cycle()
|
||||
else:
|
||||
await self.close()
|
||||
|
||||
if self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.socket.is_connection_dropped()
|
||||
@ -1,298 +0,0 @@
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
import h2.events
|
||||
from h2.config import H2Configuration
|
||||
from h2.settings import SettingCodes, Settings
|
||||
|
||||
from .._backends.base import (
|
||||
BaseLock,
|
||||
BaseSocketStream,
|
||||
ConcurrencyBackend,
|
||||
lookup_backend,
|
||||
)
|
||||
from .._config import Timeout
|
||||
from .._content_streams import AsyncIteratorStream
|
||||
from .._exceptions import ProtocolError
|
||||
from .._models import Request, Response
|
||||
from .._utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HTTP2Connection:
|
||||
READ_NUM_BYTES = 4096
|
||||
CONFIG = H2Configuration(validate_inbound_headers=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: BaseSocketStream,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.socket = socket
|
||||
self.backend = lookup_backend(backend)
|
||||
self.on_release = on_release
|
||||
self.state = h2.connection.H2Connection(config=self.CONFIG)
|
||||
|
||||
self.streams = {} # type: typing.Dict[int, HTTP2Stream]
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
|
||||
self.sent_connection_init = False
|
||||
|
||||
@property
|
||||
def is_http2(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def init_lock(self) -> BaseLock:
|
||||
# We do this lazily, to make sure backend autodetection always
|
||||
# runs within an async context.
|
||||
if not hasattr(self, "_initialization_lock"):
|
||||
self._initialization_lock = self.backend.create_lock()
|
||||
return self._initialization_lock
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
|
||||
async with self.init_lock:
|
||||
if not self.sent_connection_init:
|
||||
# The very first stream is responsible for initiating the connection.
|
||||
await self.send_connection_init(timeout)
|
||||
self.sent_connection_init = True
|
||||
stream_id = self.state.get_next_available_stream_id()
|
||||
|
||||
stream = HTTP2Stream(stream_id=stream_id, connection=self)
|
||||
self.streams[stream_id] = stream
|
||||
self.events[stream_id] = []
|
||||
return await stream.send(request, timeout)
|
||||
|
||||
async def send_connection_init(self, timeout: Timeout) -> None:
|
||||
"""
|
||||
The HTTP/2 connection requires some initial setup before we can start
|
||||
using individual request/response streams on it.
|
||||
"""
|
||||
|
||||
# Need to set these manually here instead of manipulating via
|
||||
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
|
||||
# frames in addition to sending the undesired defaults.
|
||||
self.state.local_settings = Settings(
|
||||
client=True,
|
||||
initial_values={
|
||||
# Disable PUSH_PROMISE frames from the server since we don't do anything
|
||||
# with them for now. Maybe when we support caching?
|
||||
SettingCodes.ENABLE_PUSH: 0,
|
||||
# These two are taken from h2 for safe defaults
|
||||
SettingCodes.MAX_CONCURRENT_STREAMS: 100,
|
||||
SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
|
||||
},
|
||||
)
|
||||
|
||||
# Some websites (*cough* Yahoo *cough*) balk at this setting being
|
||||
# present in the initial handshake since it's not defined in the original
|
||||
# RFC despite the RFC mandating ignoring settings you don't know about.
|
||||
del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
|
||||
|
||||
self.state.initiate_connection()
|
||||
self.state.increment_flow_control_window(2 ** 24)
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.socket.is_connection_dropped()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.socket.close()
|
||||
|
||||
async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
|
||||
"""
|
||||
Returns the maximum allowable outgoing flow for a given stream.
|
||||
|
||||
If the allowable flow is zero, then waits on the network until
|
||||
WindowUpdated frames have increased the flow rate.
|
||||
|
||||
https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
"""
|
||||
local_flow = self.state.local_flow_control_window(stream_id)
|
||||
connection_flow = self.state.max_outbound_frame_size
|
||||
flow = min(local_flow, connection_flow)
|
||||
while flow == 0:
|
||||
await self.receive_events(timeout)
|
||||
local_flow = self.state.local_flow_control_window(stream_id)
|
||||
connection_flow = self.state.max_outbound_frame_size
|
||||
flow = min(local_flow, connection_flow)
|
||||
return flow
|
||||
|
||||
async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
|
||||
"""
|
||||
Returns the next event for a given stream.
|
||||
|
||||
If no events are available yet, then waits on the network until
|
||||
an event is available.
|
||||
"""
|
||||
while not self.events[stream_id]:
|
||||
await self.receive_events(timeout)
|
||||
return self.events[stream_id].pop(0)
|
||||
|
||||
async def receive_events(self, timeout: Timeout) -> None:
|
||||
"""
|
||||
Read some data from the network, and update the H2 state.
|
||||
"""
|
||||
data = await self.socket.read(self.READ_NUM_BYTES, timeout)
|
||||
events = self.state.receive_data(data)
|
||||
for event in events:
|
||||
event_stream_id = getattr(event, "stream_id", 0)
|
||||
logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")
|
||||
|
||||
if hasattr(event, "error_code"):
|
||||
raise ProtocolError(event)
|
||||
|
||||
if event_stream_id in self.events:
|
||||
self.events[event_stream_id].append(event)
|
||||
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def send_headers(
|
||||
self,
|
||||
stream_id: int,
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
end_stream: bool,
|
||||
timeout: Timeout,
|
||||
) -> None:
|
||||
self.state.send_headers(stream_id, headers, end_stream=end_stream)
|
||||
self.state.increment_flow_control_window(2 ** 24, stream_id=stream_id)
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def send_data(self, stream_id: int, chunk: bytes, timeout: Timeout) -> None:
|
||||
self.state.send_data(stream_id, chunk)
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
|
||||
self.state.end_stream(stream_id)
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def acknowledge_received_data(
|
||||
self, stream_id: int, amount: int, timeout: Timeout
|
||||
) -> None:
|
||||
self.state.acknowledge_received_data(amount, stream_id)
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def close_stream(self, stream_id: int) -> None:
|
||||
del self.streams[stream_id]
|
||||
del self.events[stream_id]
|
||||
|
||||
if not self.streams and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
|
||||
class HTTP2Stream:
|
||||
def __init__(self, stream_id: int, connection: HTTP2Connection) -> None:
|
||||
self.stream_id = stream_id
|
||||
self.connection = connection
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout) -> Response:
|
||||
# Send the request.
|
||||
has_body = (
|
||||
"Content-Length" in request.headers
|
||||
or "Transfer-Encoding" in request.headers
|
||||
)
|
||||
|
||||
await self.send_headers(request, has_body, timeout)
|
||||
if has_body:
|
||||
await self.send_body(request, timeout)
|
||||
|
||||
# Receive the response.
|
||||
status_code, headers = await self.receive_response(timeout)
|
||||
stream = AsyncIteratorStream(
|
||||
aiterator=self.body_iter(timeout), close_func=self.close
|
||||
)
|
||||
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
http_version="HTTP/2",
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async def send_headers(
|
||||
self, request: Request, has_body: bool, timeout: Timeout
|
||||
) -> None:
|
||||
headers = [
|
||||
(b":method", request.method.encode("ascii")),
|
||||
(b":authority", request.url.authority.encode("ascii")),
|
||||
(b":scheme", request.url.scheme.encode("ascii")),
|
||||
(b":path", request.url.full_path.encode("ascii")),
|
||||
] + [
|
||||
(k, v)
|
||||
for k, v in request.headers.raw
|
||||
if k not in (b"host", b"transfer-encoding")
|
||||
]
|
||||
end_stream = not has_body
|
||||
|
||||
logger.trace(
|
||||
f"send_headers "
|
||||
f"stream_id={self.stream_id} "
|
||||
f"method={request.method!r} "
|
||||
f"target={request.url.full_path!r} "
|
||||
f"headers={headers!r}"
|
||||
)
|
||||
await self.connection.send_headers(self.stream_id, headers, end_stream, timeout)
|
||||
|
||||
async def send_body(self, request: Request, timeout: Timeout) -> None:
|
||||
logger.trace(f"send_body stream_id={self.stream_id}")
|
||||
async for data in request.stream:
|
||||
while data:
|
||||
max_flow = await self.connection.wait_for_outgoing_flow(
|
||||
self.stream_id, timeout
|
||||
)
|
||||
chunk_size = min(len(data), max_flow)
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
await self.connection.send_data(self.stream_id, chunk, timeout)
|
||||
|
||||
await self.connection.end_stream(self.stream_id, timeout)
|
||||
|
||||
async def receive_response(
|
||||
self, timeout: Timeout
|
||||
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Read the response status and headers from the network.
|
||||
"""
|
||||
while True:
|
||||
event = await self.connection.wait_for_event(self.stream_id, timeout)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode("ascii", errors="ignore"))
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
|
||||
while True:
|
||||
event = await self.connection.wait_for_event(self.stream_id, timeout)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
amount = event.flow_controlled_length
|
||||
await self.connection.acknowledge_received_data(
|
||||
self.stream_id, amount, timeout
|
||||
)
|
||||
yield event.data
|
||||
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
|
||||
break
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.connection.close_stream(self.stream_id)
|
||||
@ -1,208 +0,0 @@
|
||||
import enum
|
||||
import typing
|
||||
import warnings
|
||||
from base64 import b64encode
|
||||
|
||||
from .._backends.base import ConcurrencyBackend
|
||||
from .._config import (
|
||||
DEFAULT_POOL_LIMITS,
|
||||
CertTypes,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
VerifyTypes,
|
||||
)
|
||||
from .._exceptions import ProxyError
|
||||
from .._models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
|
||||
from .._utils import get_logger
|
||||
from .connection import HTTPConnection
|
||||
from .connection_pool import ConnectionPool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HTTPProxyMode(enum.Enum):
|
||||
# This enum is pending deprecation in order to reduce API surface area,
|
||||
# but is currently still around for 0.8 backwards compat.
|
||||
DEFAULT = "DEFAULT"
|
||||
FORWARD_ONLY = "FORWARD_ONLY"
|
||||
TUNNEL_ONLY = "TUNNEL_ONLY"
|
||||
|
||||
|
||||
DEFAULT_MODE = "DEFAULT"
|
||||
FORWARD_ONLY = "FORWARD_ONLY"
|
||||
TUNNEL_ONLY = "TUNNEL_ONLY"
|
||||
|
||||
|
||||
class HTTPProxy(ConnectionPool):
|
||||
"""A proxy that sends requests to the recipient server
|
||||
on behalf of the connecting client.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: URLTypes,
|
||||
*,
|
||||
proxy_headers: HeaderTypes = None,
|
||||
proxy_mode: str = "DEFAULT",
|
||||
verify: VerifyTypes = True,
|
||||
cert: CertTypes = None,
|
||||
trust_env: bool = None,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
http2: bool = False,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
):
|
||||
|
||||
if isinstance(proxy_mode, HTTPProxyMode): # pragma: nocover
|
||||
warnings.warn(
|
||||
"The 'HTTPProxyMode' enum is pending deprecation. "
|
||||
"Use a plain string instead. proxy_mode='FORWARD_ONLY', or "
|
||||
"proxy_mode='TUNNEL_ONLY'."
|
||||
)
|
||||
proxy_mode = proxy_mode.value
|
||||
assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
|
||||
|
||||
self.tunnel_ssl = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
)
|
||||
|
||||
super(HTTPProxy, self).__init__(
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
pool_limits=pool_limits,
|
||||
backend=backend,
|
||||
trust_env=trust_env,
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
self.proxy_url = URL(proxy_url)
|
||||
self.proxy_mode = proxy_mode
|
||||
self.proxy_headers = Headers(proxy_headers)
|
||||
|
||||
url = self.proxy_url
|
||||
if url.username or url.password:
|
||||
self.proxy_headers.setdefault(
|
||||
"Proxy-Authorization",
|
||||
self.build_auth_header(url.username, url.password),
|
||||
)
|
||||
# Remove userinfo from the URL authority, e.g.:
|
||||
# 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
|
||||
credentials, _, authority = url.authority.rpartition("@")
|
||||
self.proxy_url = url.copy_with(authority=authority)
|
||||
|
||||
def build_auth_header(self, username: str, password: str) -> str:
|
||||
userpass = (username.encode("utf-8"), password.encode("utf-8"))
|
||||
token = b64encode(b":".join(userpass)).decode().strip()
|
||||
return f"Basic {token}"
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: Timeout = None
|
||||
) -> HTTPConnection:
|
||||
if self.should_forward_origin(origin):
|
||||
logger.trace(
|
||||
f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}"
|
||||
)
|
||||
return await super().acquire_connection(Origin(self.proxy_url), timeout)
|
||||
else:
|
||||
logger.trace(
|
||||
f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
|
||||
)
|
||||
return await self.tunnel_connection(origin, timeout)
|
||||
|
||||
async def tunnel_connection(
|
||||
self, origin: Origin, timeout: Timeout = None
|
||||
) -> HTTPConnection:
|
||||
"""Creates a new HTTPConnection via the CONNECT method
|
||||
usually reserved for proxying HTTPS connections.
|
||||
"""
|
||||
connection = self.pop_connection(origin)
|
||||
|
||||
if connection is None:
|
||||
connection = await self.request_tunnel_proxy_connection(origin)
|
||||
|
||||
# After we receive the 2XX response from the proxy that our
|
||||
# tunnel is open we switch the connection's origin
|
||||
# to the original so the tunnel can be re-used.
|
||||
self.active_connections.remove(connection)
|
||||
connection.origin = origin
|
||||
self.active_connections.add(connection)
|
||||
|
||||
await connection.tunnel_start_tls(
|
||||
origin=origin, proxy_url=self.proxy_url, timeout=timeout,
|
||||
)
|
||||
else:
|
||||
self.active_connections.add(connection)
|
||||
|
||||
return connection
|
||||
|
||||
async def request_tunnel_proxy_connection(self, origin: Origin) -> HTTPConnection:
|
||||
"""Creates an HTTPConnection by setting up a TCP tunnel"""
|
||||
proxy_headers = self.proxy_headers.copy()
|
||||
proxy_headers.setdefault("Accept", "*/*")
|
||||
proxy_request = Request(
|
||||
method="CONNECT", url=self.proxy_url.copy_with(), headers=proxy_headers
|
||||
)
|
||||
proxy_request.url.full_path = f"{origin.host}:{origin.port}"
|
||||
|
||||
await self.max_connections.acquire()
|
||||
|
||||
connection = HTTPConnection(
|
||||
Origin(self.proxy_url),
|
||||
ssl=self.tunnel_ssl,
|
||||
backend=self.backend,
|
||||
release_func=self.release_connection,
|
||||
)
|
||||
self.active_connections.add(connection)
|
||||
|
||||
# See if our tunnel has been opened successfully
|
||||
proxy_response = await connection.send(proxy_request)
|
||||
logger.trace(
|
||||
f"tunnel_response "
|
||||
f"proxy_url={self.proxy_url!r} "
|
||||
f"origin={origin!r} "
|
||||
f"response={proxy_response!r}"
|
||||
)
|
||||
if not (200 <= proxy_response.status_code <= 299):
|
||||
await proxy_response.aread()
|
||||
raise ProxyError(
|
||||
f"Non-2XX response received from HTTP proxy "
|
||||
f"({proxy_response.status_code})",
|
||||
request=proxy_request,
|
||||
response=proxy_response,
|
||||
)
|
||||
else:
|
||||
# Hack to ingest the response, without closing it.
|
||||
async for chunk in proxy_response._raw_stream:
|
||||
pass
|
||||
|
||||
return connection
|
||||
|
||||
def should_forward_origin(self, origin: Origin) -> bool:
|
||||
"""Determines if the given origin should
|
||||
be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'
|
||||
then the proxy will forward all 'HTTP' requests and
|
||||
tunnel all 'HTTPS' requests.
|
||||
"""
|
||||
return (
|
||||
self.proxy_mode == DEFAULT_MODE and not origin.is_ssl
|
||||
) or self.proxy_mode == FORWARD_ONLY
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
if self.should_forward_origin(Origin(request.url)):
|
||||
# Change the request to have the target URL
|
||||
# as its full_path and switch the proxy URL
|
||||
# for where the request will be sent.
|
||||
target_url = str(request.url)
|
||||
request.url = self.proxy_url.copy_with()
|
||||
request.url.full_path = target_url
|
||||
for name, value in self.proxy_headers.items():
|
||||
request.headers.setdefault(name, value)
|
||||
|
||||
return await super().send(request=request, timeout=timeout)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"HTTPProxy(proxy_url={self.proxy_url!r} "
|
||||
f"proxy_headers={self.proxy_headers!r} "
|
||||
f"proxy_mode={self.proxy_mode!r})"
|
||||
)
|
||||
@ -1,8 +1,9 @@
|
||||
import math
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import httpcore
|
||||
import urllib3
|
||||
from urllib3.exceptions import MaxRetryError, SSLError
|
||||
|
||||
@ -12,16 +13,13 @@ from .._config import (
|
||||
PoolLimits,
|
||||
Proxy,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
VerifyTypes,
|
||||
)
|
||||
from .._content_streams import IteratorStream
|
||||
from .._models import Request, Response
|
||||
from .._content_streams import ByteStream, IteratorStream
|
||||
from .._utils import as_network_error
|
||||
from .base import SyncDispatcher
|
||||
|
||||
|
||||
class URLLib3Dispatcher(SyncDispatcher):
|
||||
class URLLib3Dispatcher(httpcore.SyncHTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -62,12 +60,12 @@ class URLLib3Dispatcher(SyncDispatcher):
|
||||
|
||||
def init_pool_manager(
|
||||
self,
|
||||
proxy: typing.Optional[Proxy],
|
||||
proxy: Optional[Proxy],
|
||||
ssl_context: ssl.SSLContext,
|
||||
num_pools: int,
|
||||
maxsize: int,
|
||||
block: bool,
|
||||
) -> typing.Union[urllib3.PoolManager, urllib3.ProxyManager]:
|
||||
) -> Union[urllib3.PoolManager, urllib3.ProxyManager]:
|
||||
if proxy is None:
|
||||
return urllib3.PoolManager(
|
||||
ssl_context=ssl_context,
|
||||
@ -85,20 +83,58 @@ class URLLib3Dispatcher(SyncDispatcher):
|
||||
block=block,
|
||||
)
|
||||
|
||||
def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
def request(
|
||||
self,
|
||||
method: bytes,
|
||||
url: Tuple[bytes, bytes, int, bytes],
|
||||
headers: List[Tuple[bytes, bytes]] = None,
|
||||
stream: httpcore.SyncByteStream = None,
|
||||
timeout: Dict[str, Optional[float]] = None,
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
|
||||
headers = [] if headers is None else headers
|
||||
stream = ByteStream(b"") if stream is None else stream
|
||||
timeout = {} if timeout is None else timeout
|
||||
|
||||
urllib3_timeout = urllib3.util.Timeout(
|
||||
connect=timeout.connect_timeout, read=timeout.read_timeout
|
||||
connect=timeout.get("connect"), read=timeout.get("read")
|
||||
)
|
||||
chunked = request.headers.get("Transfer-Encoding") == "chunked"
|
||||
content_length = int(request.headers.get("Content-Length", "0"))
|
||||
body = request.stream if chunked or content_length else None
|
||||
|
||||
chunked = False
|
||||
content_length = 0
|
||||
for header_key, header_value in headers:
|
||||
header_key = header_key.lower()
|
||||
if header_key == b"transfer-encoding":
|
||||
chunked = header_value == b"chunked"
|
||||
if header_key == b"content-length":
|
||||
content_length = int(header_value.decode("ascii"))
|
||||
body = stream if chunked or content_length else None
|
||||
|
||||
scheme, host, port, path = url
|
||||
default_scheme = {80: b"http", 443: "https"}.get(port)
|
||||
if scheme == default_scheme:
|
||||
url_str = "%s://%s%s" % (
|
||||
scheme.decode("ascii"),
|
||||
host.decode("ascii"),
|
||||
path.decode("ascii"),
|
||||
)
|
||||
else:
|
||||
url_str = "%s://%s:%d%s" % (
|
||||
scheme.decode("ascii"),
|
||||
host.decode("ascii"),
|
||||
port,
|
||||
path.decode("ascii"),
|
||||
)
|
||||
|
||||
with as_network_error(MaxRetryError, SSLError, socket.error):
|
||||
conn = self.pool.urlopen(
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
headers=dict(request.headers),
|
||||
method=method.decode(),
|
||||
url=url_str,
|
||||
headers=dict(
|
||||
[
|
||||
(key.decode("ascii"), value.decode("ascii"))
|
||||
for key, value in headers
|
||||
]
|
||||
),
|
||||
body=body,
|
||||
redirect=False,
|
||||
assert_same_host=False,
|
||||
@ -106,23 +142,20 @@ class URLLib3Dispatcher(SyncDispatcher):
|
||||
preload_content=False,
|
||||
chunked=chunked,
|
||||
timeout=urllib3_timeout,
|
||||
pool_timeout=timeout.pool_timeout,
|
||||
pool_timeout=timeout.get("pool"),
|
||||
)
|
||||
|
||||
def response_bytes() -> typing.Iterator[bytes]:
|
||||
def response_bytes() -> Iterator[bytes]:
|
||||
with as_network_error(socket.error):
|
||||
for chunk in conn.stream(4096, decode_content=False):
|
||||
yield chunk
|
||||
|
||||
return Response(
|
||||
status_code=conn.status,
|
||||
http_version="HTTP/1.1",
|
||||
headers=list(conn.headers.items()),
|
||||
stream=IteratorStream(
|
||||
iterator=response_bytes(), close_func=conn.release_conn
|
||||
),
|
||||
request=request,
|
||||
status_code = conn.status
|
||||
headers = list(conn.headers.items())
|
||||
response_stream = IteratorStream(
|
||||
iterator=response_bytes(), close_func=conn.release_conn
|
||||
)
|
||||
return (b"HTTP/1.1", status_code, conn.reason, headers, response_stream)
|
||||
|
||||
def close(self) -> None:
|
||||
self.pool.clear()
|
||||
|
||||
@ -2,10 +2,9 @@ import io
|
||||
import itertools
|
||||
import typing
|
||||
|
||||
from .._config import TimeoutTypes
|
||||
from .._content_streams import IteratorStream
|
||||
from .._models import Request, Response
|
||||
from .base import SyncDispatcher
|
||||
import httpcore
|
||||
|
||||
from .._content_streams import ByteStream, IteratorStream
|
||||
|
||||
|
||||
def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
|
||||
@ -16,9 +15,9 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
|
||||
return []
|
||||
|
||||
|
||||
class WSGIDispatch(SyncDispatcher):
|
||||
class WSGIDispatch(httpcore.SyncHTTPTransport):
|
||||
"""
|
||||
A custom SyncDispatcher that handles sending requests directly to an WSGI app.
|
||||
A custom transport that handles sending requests directly to an WSGI app.
|
||||
The simplest way to use this functionality is to use the `app` argument.
|
||||
|
||||
```
|
||||
@ -61,28 +60,46 @@ class WSGIDispatch(SyncDispatcher):
|
||||
self.script_name = script_name
|
||||
self.remote_addr = remote_addr
|
||||
|
||||
def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
|
||||
def request(
|
||||
self,
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = None,
|
||||
stream: httpcore.SyncByteStream = None,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes,
|
||||
int,
|
||||
bytes,
|
||||
typing.List[typing.Tuple[bytes, bytes]],
|
||||
httpcore.SyncByteStream,
|
||||
]:
|
||||
headers = [] if headers is None else headers
|
||||
stream = ByteStream(b"") if stream is None else stream
|
||||
|
||||
scheme, host, port, full_path = url
|
||||
path, _, query = full_path.partition(b"?")
|
||||
environ = {
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": request.url.scheme,
|
||||
"wsgi.input": io.BytesIO(request.read()),
|
||||
"wsgi.url_scheme": scheme.decode("ascii"),
|
||||
"wsgi.input": io.BytesIO(b"".join([chunk for chunk in stream])),
|
||||
"wsgi.errors": io.BytesIO(),
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": False,
|
||||
"wsgi.run_once": False,
|
||||
"REQUEST_METHOD": request.method,
|
||||
"REQUEST_METHOD": method.decode(),
|
||||
"SCRIPT_NAME": self.script_name,
|
||||
"PATH_INFO": request.url.path,
|
||||
"QUERY_STRING": request.url.query,
|
||||
"SERVER_NAME": request.url.host,
|
||||
"SERVER_PORT": str(request.url.port),
|
||||
"PATH_INFO": path.decode("ascii"),
|
||||
"QUERY_STRING": query.decode("ascii"),
|
||||
"SERVER_NAME": host.decode("ascii"),
|
||||
"SERVER_PORT": str(port),
|
||||
"REMOTE_ADDR": self.remote_addr,
|
||||
}
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper().replace("-", "_")
|
||||
for header_key, header_value in headers:
|
||||
key = header_key.decode("ascii").upper().replace("-", "_")
|
||||
if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
|
||||
key = "HTTP_" + key
|
||||
environ[key] = value
|
||||
environ[key] = header_value.decode("ascii")
|
||||
|
||||
seen_status = None
|
||||
seen_response_headers = None
|
||||
@ -106,10 +123,11 @@ class WSGIDispatch(SyncDispatcher):
|
||||
if seen_exc_info and self.raise_app_exceptions:
|
||||
raise seen_exc_info[1]
|
||||
|
||||
return Response(
|
||||
status_code=int(seen_status.split()[0]),
|
||||
http_version="HTTP/1.1",
|
||||
headers=seen_response_headers,
|
||||
stream=IteratorStream(chunk for chunk in result),
|
||||
request=request,
|
||||
)
|
||||
status_code = int(seen_status.split()[0])
|
||||
headers = [
|
||||
(key.encode("ascii"), value.encode("ascii"))
|
||||
for key, value in seen_response_headers
|
||||
]
|
||||
stream = IteratorStream(chunk for chunk in result)
|
||||
|
||||
return (b"HTTP/1.1", status_code, b"", headers, stream)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import typing
|
||||
|
||||
import httpcore
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._models import Request, Response # pragma: nocover
|
||||
|
||||
@ -26,73 +28,36 @@ class HTTPError(Exception):
|
||||
|
||||
# Timeout exceptions...
|
||||
|
||||
|
||||
class TimeoutException(HTTPError):
|
||||
"""
|
||||
A base class for all timeouts.
|
||||
"""
|
||||
ConnectTimeout = httpcore.ConnectTimeout
|
||||
ReadTimeout = httpcore.ReadTimeout
|
||||
WriteTimeout = httpcore.WriteTimeout
|
||||
PoolTimeout = httpcore.PoolTimeout
|
||||
|
||||
|
||||
class ConnectTimeout(TimeoutException):
|
||||
"""
|
||||
Timeout while establishing a connection.
|
||||
"""
|
||||
# Core networking exceptions...
|
||||
|
||||
NetworkError = httpcore.NetworkError
|
||||
ReadError = httpcore.ReadError
|
||||
WriteError = httpcore.WriteError
|
||||
ConnectError = httpcore.ConnectError
|
||||
CloseError = httpcore.CloseError
|
||||
|
||||
|
||||
class ReadTimeout(TimeoutException):
|
||||
"""
|
||||
Timeout while reading response data.
|
||||
"""
|
||||
# Other transport exceptions...
|
||||
|
||||
|
||||
class WriteTimeout(TimeoutException):
|
||||
"""
|
||||
Timeout while writing request data.
|
||||
"""
|
||||
|
||||
|
||||
class PoolTimeout(TimeoutException):
|
||||
"""
|
||||
Timeout while waiting to acquire a connection from the pool.
|
||||
"""
|
||||
|
||||
|
||||
class ProxyError(HTTPError):
|
||||
"""
|
||||
Error from within a proxy
|
||||
"""
|
||||
ProxyError = httpcore.ProxyError
|
||||
ProtocolError = httpcore.ProtocolError
|
||||
|
||||
|
||||
# HTTP exceptions...
|
||||
|
||||
|
||||
class ProtocolError(HTTPError):
|
||||
"""
|
||||
Malformed HTTP.
|
||||
"""
|
||||
|
||||
|
||||
class DecodingError(HTTPError):
|
||||
"""
|
||||
Decoding of the response failed.
|
||||
"""
|
||||
|
||||
|
||||
# Network exceptions...
|
||||
|
||||
|
||||
class NetworkError(HTTPError):
|
||||
"""
|
||||
A failure occurred while trying to access the network.
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosed(NetworkError):
|
||||
"""
|
||||
Expected more data from peer, but connection was closed.
|
||||
"""
|
||||
|
||||
|
||||
# Redirect exceptions...
|
||||
|
||||
|
||||
|
||||
@ -174,6 +174,15 @@ class URL:
|
||||
def fragment(self) -> str:
|
||||
return self._uri_reference.fragment or ""
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.Tuple[bytes, bytes, int, bytes]:
|
||||
return (
|
||||
self.scheme.encode("ascii"),
|
||||
self.host.encode("ascii"),
|
||||
self.port,
|
||||
self.full_path.encode("ascii"),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
return self.scheme == "https"
|
||||
|
||||
@ -11,7 +11,7 @@ combine_as_imports = True
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = True
|
||||
known_first_party = httpx,tests
|
||||
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
|
||||
known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
|
||||
5
setup.py
5
setup.py
@ -58,12 +58,9 @@ setup(
|
||||
"certifi",
|
||||
"hstspreload",
|
||||
"chardet==3.*",
|
||||
"h11>=0.8,<0.10",
|
||||
"h2==3.*",
|
||||
"idna==2.*",
|
||||
"rfc3986>=1.3,<2",
|
||||
"sniffio==1.*",
|
||||
"urllib3==1.*",
|
||||
"httpcore==0.7.*",
|
||||
],
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
|
||||
@ -148,15 +148,3 @@ async def test_100_continue(server):
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_uds(uds_server):
|
||||
url = uds_server.url
|
||||
uds = uds_server.config.uds
|
||||
assert uds is not None
|
||||
async with httpx.AsyncClient(uds=uds) as client:
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
assert response.encoding == "iso-8859-1"
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from httpx import (
|
||||
@ -10,35 +10,47 @@ from httpx import (
|
||||
AsyncClient,
|
||||
Auth,
|
||||
DigestAuth,
|
||||
Headers,
|
||||
ProtocolError,
|
||||
Request,
|
||||
RequestBodyUnavailable,
|
||||
Response,
|
||||
)
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx._content_streams import ContentStream, JSONStream
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
def __init__(self, auth_header: str = "", status_code: int = 200) -> None:
|
||||
def get_header_value(headers, key, default=None):
|
||||
lookup = key.encode("ascii").lower()
|
||||
for header_key, header_value in headers:
|
||||
if header_key.lower() == lookup:
|
||||
return header_value.decode("ascii")
|
||||
return default
|
||||
|
||||
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
def __init__(self, auth_header: bytes = b"", status_code: int = 200) -> None:
|
||||
self.auth_header = auth_header
|
||||
self.status_code = status_code
|
||||
|
||||
async def send(
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
headers = [("www-authenticate", self.auth_header)] if self.auth_header else []
|
||||
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
|
||||
return Response(
|
||||
self.status_code, headers=headers, content=body, request=request
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
authorization = get_header_value(headers, "Authorization")
|
||||
response_headers = (
|
||||
[(b"www-authenticate", self.auth_header)] if self.auth_header else []
|
||||
)
|
||||
response_stream = JSONStream({"auth": authorization})
|
||||
return b"HTTP/1.1", self.status_code, b"", response_headers, response_stream
|
||||
|
||||
|
||||
class MockDigestAuthDispatch(AsyncDispatcher):
|
||||
class MockDigestAuthDispatch(httpcore.AsyncHTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: str = "SHA-256",
|
||||
@ -52,20 +64,26 @@ class MockDigestAuthDispatch(AsyncDispatcher):
|
||||
self._regenerate_nonce = regenerate_nonce
|
||||
self._response_count = 0
|
||||
|
||||
async def send(
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
if self._response_count < self.send_response_after_attempt:
|
||||
return self.challenge_send(request)
|
||||
return self.challenge_send(method, url, headers, stream)
|
||||
|
||||
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
authorization = get_header_value(headers, "Authorization")
|
||||
body = JSONStream({"auth": authorization})
|
||||
return b"HTTP/1.1", 200, b"", [], body
|
||||
|
||||
def challenge_send(self, request: Request) -> Response:
|
||||
def challenge_send(
|
||||
self, method: bytes, url: URL, headers: Headers, stream: ContentStream,
|
||||
) -> typing.Tuple[int, bytes, Headers, ContentStream]:
|
||||
self._response_count += 1
|
||||
nonce = (
|
||||
hashlib.sha256(os.urandom(8)).hexdigest()
|
||||
@ -88,9 +106,12 @@ class MockDigestAuthDispatch(AsyncDispatcher):
|
||||
)
|
||||
|
||||
headers = [
|
||||
("www-authenticate", 'Digest realm="httpx@example.org", ' + challenge_str)
|
||||
(
|
||||
b"www-authenticate",
|
||||
b'Digest realm="httpx@example.org", ' + challenge_str.encode("ascii"),
|
||||
)
|
||||
]
|
||||
return Response(401, headers=headers, content=b"", request=request)
|
||||
return b"HTTP/1.1", 401, b"", headers, ContentStream()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -234,7 +255,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
|
||||
async def test_digest_auth_200_response_including_digest_auth_header() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
|
||||
auth_header = b'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
|
||||
|
||||
client = AsyncClient(
|
||||
dispatch=MockDispatch(auth_header=auth_header, status_code=200)
|
||||
@ -251,7 +272,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
|
||||
client = AsyncClient(dispatch=MockDispatch(auth_header="", status_code=401))
|
||||
client = AsyncClient(dispatch=MockDispatch(auth_header=b"", status_code=401))
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 401
|
||||
@ -382,16 +403,16 @@ async def test_digest_auth_incorrect_credentials() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"auth_header",
|
||||
[
|
||||
'Digest realm="httpx@example.org", qop="auth"', # missing fields
|
||||
'realm="httpx@example.org", qop="auth"', # not starting with Digest
|
||||
'DigestZ realm="httpx@example.org", qop="auth"'
|
||||
'qop="auth,auth-int",nonce="abc",opaque="xyz"',
|
||||
'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
|
||||
b'Digest realm="httpx@example.org", qop="auth"', # missing fields
|
||||
b'realm="httpx@example.org", qop="auth"', # not starting with Digest
|
||||
b'DigestZ realm="httpx@example.org", qop="auth"'
|
||||
b'qop="auth,auth-int",nonce="abc",opaque="xyz"',
|
||||
b'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_raises_protocol_error_on_malformed_header(
|
||||
auth_header: str,
|
||||
auth_header: bytes,
|
||||
) -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
@ -439,7 +460,7 @@ async def test_auth_history() -> None:
|
||||
|
||||
url = "https://example.org/"
|
||||
auth = RepeatAuth(repeat=2)
|
||||
client = AsyncClient(dispatch=MockDispatch(auth_header="abc"))
|
||||
client = AsyncClient(dispatch=MockDispatch(auth_header=b"abc"))
|
||||
|
||||
response = await client.get(url, auth=auth)
|
||||
assert response.status_code == 200
|
||||
|
||||
@ -1,29 +1,43 @@
|
||||
import json
|
||||
import typing
|
||||
from http.cookiejar import Cookie, CookieJar
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from httpx import AsyncClient, Cookies, Request, Response
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx import AsyncClient, Cookies
|
||||
from httpx._content_streams import ByteStream, ContentStream, JSONStream
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
def get_header_value(headers, key, default=None):
|
||||
lookup = key.encode("ascii").lower()
|
||||
for header_key, header_value in headers:
|
||||
if header_key.lower() == lookup:
|
||||
return header_value.decode("ascii")
|
||||
return default
|
||||
|
||||
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path.startswith("/echo_cookies"):
|
||||
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
elif request.url.path.startswith("/set_cookie"):
|
||||
headers = {"set-cookie": "example-name=example-value"}
|
||||
return Response(200, headers=headers, request=request)
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
host, scheme, port, path = url
|
||||
if path.startswith(b"/echo_cookies"):
|
||||
cookie = get_header_value(headers, "cookie")
|
||||
body = JSONStream({"cookies": cookie})
|
||||
return b"HTTP/1.1", 200, b"OK", [], body
|
||||
elif path.startswith(b"/set_cookie"):
|
||||
headers = [(b"set-cookie", b"example-name=example-value")]
|
||||
body = ByteStream(b"")
|
||||
return b"HTTP/1.1", 200, b"OK", headers, body
|
||||
else:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -1,26 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from httpx import AsyncClient, Headers, Request, Response, __version__
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx import AsyncClient, Headers, __version__
|
||||
from httpx._content_streams import ContentStream, JSONStream
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path.startswith("/echo_headers"):
|
||||
request_headers = dict(request.headers.items())
|
||||
body = json.dumps({"headers": request_headers}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
headers_dict = dict(
|
||||
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
|
||||
)
|
||||
body = JSONStream({"headers": headers_dict})
|
||||
return b"HTTP/1.1", 200, b"OK", [], body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -24,7 +24,7 @@ def test_proxies_parameter(proxies, expected_proxies):
|
||||
|
||||
for proxy_key, url in expected_proxies:
|
||||
assert proxy_key in client.proxies
|
||||
assert client.proxies[proxy_key].proxy_url == url
|
||||
assert client.proxies[proxy_key].proxy_origin == httpx.URL(url).raw[:3]
|
||||
|
||||
assert len(expected_proxies) == len(client.proxies)
|
||||
|
||||
@ -81,7 +81,7 @@ def test_dispatcher_for_request(url, proxies, expected):
|
||||
if expected is None:
|
||||
assert dispatcher is client.dispatch
|
||||
else:
|
||||
assert dispatcher.proxy_url == expected
|
||||
assert dispatcher.proxy_origin == httpx.URL(expected).raw[:3]
|
||||
|
||||
|
||||
def test_unsupported_proxy_scheme():
|
||||
@ -115,4 +115,4 @@ def test_proxies_environ(monkeypatch, url, env, expected):
|
||||
if expected is None:
|
||||
assert dispatcher == client.dispatch
|
||||
else:
|
||||
assert dispatcher.proxy_url == expected
|
||||
assert dispatcher.proxy_origin == httpx.URL(expected).raw[:3]
|
||||
|
||||
@ -1,23 +1,26 @@
|
||||
import json
|
||||
import typing
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from httpx import URL, AsyncClient, QueryParams, Request, Response
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx import URL, AsyncClient, Headers, QueryParams
|
||||
from httpx._content_streams import ContentStream, JSONStream
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path.startswith("/echo_queryparams"):
|
||||
body = json.dumps({"ok": "ok"}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
headers = Headers()
|
||||
body = JSONStream({"ok": "ok"})
|
||||
return b"HTTP/1.1", 200, b"OK", headers, body
|
||||
|
||||
|
||||
def test_client_queryparams():
|
||||
|
||||
@ -1,118 +1,146 @@
|
||||
import json
|
||||
import typing
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from httpx import (
|
||||
URL,
|
||||
AsyncClient,
|
||||
NotRedirectResponse,
|
||||
Request,
|
||||
RequestBodyUnavailable,
|
||||
Response,
|
||||
TooManyRedirects,
|
||||
codes,
|
||||
)
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._content_streams import AsyncIteratorStream
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx._content_streams import AsyncIteratorStream, ByteStream, ContentStream
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
def get_header_value(headers, key, default=None):
|
||||
lookup = key.encode("ascii").lower()
|
||||
for header_key, header_value in headers:
|
||||
if header_key.lower() == lookup:
|
||||
return header_value.decode("ascii")
|
||||
return default
|
||||
|
||||
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path == "/no_redirect":
|
||||
return Response(codes.OK, request=request)
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
scheme, host, port, path = url
|
||||
path, _, query = path.partition(b"?")
|
||||
if path == b"/no_redirect":
|
||||
return b"HTTP/1.1", codes.OK, b"OK", [], ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/redirect_301":
|
||||
elif path == b"/redirect_301":
|
||||
|
||||
async def body():
|
||||
yield b"<a href='https://example.org/'>here</a>"
|
||||
|
||||
status_code = codes.MOVED_PERMANENTLY
|
||||
headers = {"location": "https://example.org/"}
|
||||
headers = [(b"location", b"https://example.org/")]
|
||||
stream = AsyncIteratorStream(aiterator=body())
|
||||
return Response(
|
||||
status_code, stream=stream, headers=headers, request=request
|
||||
)
|
||||
return b"HTTP/1.1", status_code, b"Moved Permanently", headers, stream
|
||||
|
||||
elif request.url.path == "/redirect_302":
|
||||
elif path == b"/redirect_302":
|
||||
status_code = codes.FOUND
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
headers = [(b"location", b"https://example.org/")]
|
||||
return b"HTTP/1.1", status_code, b"Found", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/redirect_303":
|
||||
elif path == b"/redirect_303":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
headers = [(b"location", b"https://example.org/")]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/relative_redirect":
|
||||
headers = {"location": "/"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
elif path == b"/relative_redirect":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"/")]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/malformed_redirect":
|
||||
headers = {"location": "https://:443/"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
elif path == b"/malformed_redirect":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"https://:443/")]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/no_scheme_redirect":
|
||||
headers = {"location": "//example.org/"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
elif path == b"/no_scheme_redirect":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"//example.org/")]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/multiple_redirects":
|
||||
params = parse_qs(request.url.query)
|
||||
elif path == b"/multiple_redirects":
|
||||
params = parse_qs(query.decode("ascii"))
|
||||
count = int(params.get("count", "0")[0])
|
||||
redirect_count = count - 1
|
||||
code = codes.SEE_OTHER if count else codes.OK
|
||||
location = "/multiple_redirects"
|
||||
phrase = b"See Other" if count else b"OK"
|
||||
location = b"/multiple_redirects"
|
||||
if redirect_count:
|
||||
location += "?count=" + str(redirect_count)
|
||||
headers = {"location": location} if count else {}
|
||||
return Response(code, headers=headers, request=request)
|
||||
location += b"?count=" + str(redirect_count).encode("ascii")
|
||||
headers = [(b"location", location)] if count else []
|
||||
return b"HTTP/1.1", code, phrase, headers, ByteStream(b"")
|
||||
|
||||
if request.url.path == "/redirect_loop":
|
||||
headers = {"location": "/redirect_loop"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
if path == b"/redirect_loop":
|
||||
code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"/redirect_loop")]
|
||||
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/cross_domain":
|
||||
headers = {"location": "https://example.org/cross_domain_target"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
elif path == b"/cross_domain":
|
||||
code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"https://example.org/cross_domain_target")]
|
||||
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/cross_domain_target":
|
||||
headers = dict(request.headers.items())
|
||||
content = json.dumps({"headers": headers}).encode()
|
||||
return Response(codes.OK, content=content, request=request)
|
||||
elif path == b"/cross_domain_target":
|
||||
headers_dict = dict(
|
||||
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
|
||||
)
|
||||
content = ByteStream(json.dumps({"headers": headers_dict}).encode())
|
||||
return b"HTTP/1.1", 200, b"OK", [], content
|
||||
|
||||
elif request.url.path == "/redirect_body":
|
||||
body = b"".join([part async for part in request.stream])
|
||||
headers = {"location": "/redirect_body_target"}
|
||||
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
|
||||
elif path == b"/redirect_body":
|
||||
_ = b"".join([part async for part in stream])
|
||||
code = codes.PERMANENT_REDIRECT
|
||||
headers = [(b"location", b"/redirect_body_target")]
|
||||
return b"HTTP/1.1", code, b"Permanent Redirect", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/redirect_no_body":
|
||||
content = b"".join([part async for part in request.stream])
|
||||
headers = {"location": "/redirect_body_target"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
elif path == b"/redirect_no_body":
|
||||
_ = b"".join([part async for part in stream])
|
||||
code = codes.SEE_OTHER
|
||||
headers = [(b"location", b"/redirect_body_target")]
|
||||
return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
elif request.url.path == "/redirect_body_target":
|
||||
content = b"".join([part async for part in request.stream])
|
||||
headers = dict(request.headers.items())
|
||||
body = json.dumps({"body": content.decode(), "headers": headers}).encode()
|
||||
return Response(codes.OK, content=body, request=request)
|
||||
elif path == b"/redirect_body_target":
|
||||
content = b"".join([part async for part in stream])
|
||||
headers_dict = dict(
|
||||
[(key.decode("ascii"), value.decode("ascii")) for key, value in headers]
|
||||
)
|
||||
body = ByteStream(
|
||||
json.dumps({"body": content.decode(), "headers": headers_dict}).encode()
|
||||
)
|
||||
return b"HTTP/1.1", 200, b"OK", [], body
|
||||
|
||||
elif request.url.path == "/cross_subdomain":
|
||||
if request.headers["host"] != "www.example.org":
|
||||
headers = {"location": "https://www.example.org/cross_subdomain"}
|
||||
return Response(
|
||||
codes.PERMANENT_REDIRECT, headers=headers, request=request
|
||||
elif path == b"/cross_subdomain":
|
||||
host = get_header_value(headers, "host")
|
||||
if host != "www.example.org":
|
||||
headers = [(b"location", b"https://www.example.org/cross_subdomain")]
|
||||
return (
|
||||
b"HTTP/1.1",
|
||||
codes.PERMANENT_REDIRECT,
|
||||
b"Permanent Redirect",
|
||||
headers,
|
||||
ByteStream(b""),
|
||||
)
|
||||
else:
|
||||
return Response(codes.OK, content=b"Hello, world!", request=request)
|
||||
return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!")
|
||||
|
||||
return Response(codes.OK, content=b"Hello, world!", request=request)
|
||||
return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
@ -326,43 +354,53 @@ async def test_cross_subdomain_redirect():
|
||||
assert response.url == URL("https://www.example.org/cross_subdomain")
|
||||
|
||||
|
||||
class MockCookieDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
class MockCookieDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path == "/":
|
||||
if "cookie" in request.headers:
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]],
|
||||
stream: ContentStream,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
|
||||
]:
|
||||
scheme, host, port, path = url
|
||||
if path == b"/":
|
||||
cookie = get_header_value(headers, "Cookie")
|
||||
if cookie is not None:
|
||||
content = b"Logged in"
|
||||
else:
|
||||
content = b"Not logged in"
|
||||
return Response(codes.OK, content=content, request=request)
|
||||
return b"HTTP/1.1", 200, b"OK", [], ByteStream(content)
|
||||
|
||||
elif request.url.path == "/login":
|
||||
elif path == b"/login":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = {
|
||||
"location": "/",
|
||||
"set-cookie": (
|
||||
"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
|
||||
"httponly; samesite=lax"
|
||||
headers = [
|
||||
(b"location", b"/"),
|
||||
(
|
||||
b"set-cookie",
|
||||
(
|
||||
b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
|
||||
b"httponly; samesite=lax"
|
||||
),
|
||||
),
|
||||
}
|
||||
]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/logout":
|
||||
elif path == b"/logout":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = {
|
||||
"location": "/",
|
||||
"set-cookie": (
|
||||
"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
|
||||
"httponly; samesite=lax"
|
||||
headers = [
|
||||
(b"location", b"/"),
|
||||
(
|
||||
b"set-cookie",
|
||||
(
|
||||
b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
|
||||
b"httponly; samesite=lax"
|
||||
),
|
||||
),
|
||||
}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
]
|
||||
return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
|
||||
@ -295,15 +295,6 @@ def server():
|
||||
yield from serve_in_thread(server)
|
||||
|
||||
|
||||
@pytest.fixture(scope=SERVER_SCOPE)
|
||||
def uds_server():
|
||||
uds = "test_server.sock"
|
||||
config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
|
||||
server = TestServer(config=config)
|
||||
yield from serve_in_thread(server)
|
||||
os.remove(uds)
|
||||
|
||||
|
||||
@pytest.fixture(scope=SERVER_SCOPE)
|
||||
def https_server(cert_pem_file, cert_private_key_file):
|
||||
config = Config(
|
||||
@ -317,19 +308,3 @@ def https_server(cert_pem_file, cert_private_key_file):
|
||||
)
|
||||
server = TestServer(config=config)
|
||||
yield from serve_in_thread(server)
|
||||
|
||||
|
||||
@pytest.fixture(scope=SERVER_SCOPE)
|
||||
def https_uds_server(cert_pem_file, cert_private_key_file):
|
||||
uds = "https_test_server.sock"
|
||||
config = Config(
|
||||
app=app,
|
||||
lifespan="off",
|
||||
ssl_certfile=cert_pem_file,
|
||||
ssl_keyfile=cert_private_key_file,
|
||||
uds=uds,
|
||||
loop="asyncio",
|
||||
)
|
||||
server = TestServer(config=config)
|
||||
yield from serve_in_thread(server)
|
||||
os.remove(uds)
|
||||
|
||||
@ -1,246 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._dispatch.connection_pool import ConnectionPool
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_connections(server):
|
||||
"""
|
||||
Connections should default to staying in a keep-alive state.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_timeout(server):
|
||||
"""
|
||||
Keep-alive connections should timeout.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
http.next_keepalive_check = 0.0
|
||||
await http.check_keepalive_expiry()
|
||||
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
async with ConnectionPool() as http:
|
||||
http.KEEP_ALIVE_EXPIRY = 0.0
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
http.next_keepalive_check = 0.0
|
||||
await http.check_keepalive_expiry()
|
||||
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_differing_connection_keys(server):
|
||||
"""
|
||||
Connections to differing connection keys should result in multiple connections.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_soft_limit(server):
|
||||
"""
|
||||
The soft_limit config should limit the maximum number of keep-alive connections.
|
||||
"""
|
||||
pool_limits = httpx.PoolLimits(soft_limit=1)
|
||||
|
||||
async with ConnectionPool(pool_limits=pool_limits) as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_streaming_response_holds_connection(server):
|
||||
"""
|
||||
A streaming request should hold the connection open until the response is read.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response.aread()
|
||||
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_multiple_concurrent_connections(server):
|
||||
"""
|
||||
Multiple concurrent requests should open multiple concurrent connections.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response_a = await http.request("GET", server.url)
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
response_b = await http.request("GET", server.url)
|
||||
assert len(http.active_connections) == 2
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response_b.aread()
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
await response_a.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_close_connections(server):
|
||||
"""
|
||||
Using a `Connection: close` header should close the connection.
|
||||
"""
|
||||
headers = [(b"connection", b"close")]
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url, headers=headers)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_standard_response_close(server):
|
||||
"""
|
||||
A standard close should keep the connection open.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
await response.aclose()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_premature_response_close(server):
|
||||
"""
|
||||
A premature close should close the connection.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aclose()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
|
||||
"""
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
|
||||
# Shutdown the server to close the keep-alive connection
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
|
||||
"""
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
"""
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
|
||||
# Shutdown the server to close the keep-alive connection
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_connection_closed_free_semaphore_on_acquire(server):
|
||||
"""
|
||||
Verify that max_connections semaphore is released
|
||||
properly on a disconnected connection.
|
||||
"""
|
||||
async with ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
|
||||
# Close the connection so we're forced to recycle it
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_connection_pool_closed_close_keepalive_and_free_semaphore(server):
|
||||
"""
|
||||
Closing the connection pool should close remaining keepalive connections and
|
||||
release the max_connections semaphore.
|
||||
"""
|
||||
http = ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1))
|
||||
|
||||
async with http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert response.status_code == 200
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
# Perform a second round of requests to make sure the max_connections semaphore
|
||||
# was released properly.
|
||||
|
||||
async with http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert response.status_code == 200
|
||||
@ -1,44 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._config import SSLConfig
|
||||
from httpx._dispatch.connection import HTTPConnection
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_get(server):
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request("GET", server.url)
|
||||
await response.aread()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_post(server):
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request("GET", server.url, data=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_premature_close(server):
|
||||
with pytest.raises(httpx.ConnectionClosed):
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request(
|
||||
"GET", server.url.copy_with(path="/premature_close")
|
||||
)
|
||||
await response.aread()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
|
||||
"""
|
||||
An HTTPS request, with SSL configuration set on the client.
|
||||
"""
|
||||
ssl = SSLConfig(verify=ca_cert_pem_file)
|
||||
async with HTTPConnection(origin=https_server.url, ssl=ssl) as conn:
|
||||
response = await conn.request("GET", https_server.url)
|
||||
await response.aread()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
@ -1,158 +0,0 @@
|
||||
import json
|
||||
import socket
|
||||
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import pytest
|
||||
from h2.settings import SettingCodes
|
||||
|
||||
from httpx import AsyncClient, Response, TimeoutException
|
||||
from httpx._dispatch.connection_pool import ConnectionPool
|
||||
|
||||
from .utils import MockHTTP2Backend
|
||||
|
||||
|
||||
async def app(request):
|
||||
method = request.method
|
||||
path = request.url.path
|
||||
body = b"".join([part async for part in request.stream])
|
||||
content = json.dumps(
|
||||
{"method": method, "path": path, "body": body.decode()}
|
||||
).encode()
|
||||
headers = {"Content-Length": str(len(content))}
|
||||
return Response(200, headers=headers, content=content, request=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_get_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
response = await client.get("http://example.org")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {"method": "GET", "path": "/", "body": ""}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_post_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
response = await client.post("http://example.org", data=b"<data>")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": "<data>",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_large_post_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
data = b"a" * 100000
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
response = await client.post("http://example.org", data=data)
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": data.decode(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_multiple_requests():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
response_1 = await client.get("http://example.org/1")
|
||||
response_2 = await client.get("http://example.org/2")
|
||||
response_3 = await client.get("http://example.org/3")
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
|
||||
|
||||
assert response_2.status_code == 200
|
||||
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
|
||||
|
||||
assert response_3.status_code == 200
|
||||
assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_reconnect():
|
||||
"""
|
||||
If a connection has been dropped between requests, then we should
|
||||
be seemlessly reconnected.
|
||||
"""
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
response_1 = await client.get("http://example.org/1")
|
||||
backend.server.close_connection = True
|
||||
response_2 = await client.get("http://example.org/2")
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
|
||||
|
||||
assert response_2.status_code == 200
|
||||
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_settings_in_handshake():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
dispatch = ConnectionPool(backend=backend, http2=True)
|
||||
|
||||
async with AsyncClient(dispatch=dispatch) as client:
|
||||
await client.get("http://example.org")
|
||||
|
||||
h2_conn = backend.server.conn
|
||||
|
||||
assert isinstance(h2_conn, h2.connection.H2Connection)
|
||||
expected_settings = {
|
||||
SettingCodes.HEADER_TABLE_SIZE: 4096,
|
||||
SettingCodes.ENABLE_PUSH: 0,
|
||||
SettingCodes.MAX_CONCURRENT_STREAMS: 100,
|
||||
SettingCodes.INITIAL_WINDOW_SIZE: 65535,
|
||||
SettingCodes.MAX_FRAME_SIZE: 16384,
|
||||
SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
|
||||
# This one's here because h2 helpfully populates remote_settings
|
||||
# with default values even if the peer doesn't send the setting.
|
||||
SettingCodes.ENABLE_CONNECT_PROTOCOL: 0,
|
||||
}
|
||||
assert dict(h2_conn.remote_settings) == expected_settings
|
||||
|
||||
# We don't expect the ENABLE_CONNECT_PROTOCOL to be in the handshake
|
||||
expected_settings.pop(SettingCodes.ENABLE_CONNECT_PROTOCOL)
|
||||
|
||||
assert len(backend.server.settings_changed) == 1
|
||||
settings = backend.server.settings_changed[0]
|
||||
|
||||
assert isinstance(settings, h2.events.RemoteSettingsChanged)
|
||||
assert len(settings.changed_settings) == len(expected_settings)
|
||||
for setting_code, changed_setting in settings.changed_settings.items():
|
||||
assert isinstance(changed_setting, h2.settings.ChangedSetting)
|
||||
assert changed_setting.new_value == expected_settings[setting_code]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_http2_live_request():
|
||||
async with AsyncClient(http2=True) as client:
|
||||
try:
|
||||
resp = await client.get("https://nghttp2.org/httpbin/anything")
|
||||
except TimeoutException: # pragma: nocover
|
||||
pytest.xfail(reason="nghttp2.org appears to be unresponsive")
|
||||
except socket.gaierror: # pragma: nocover
|
||||
pytest.xfail(reason="You appear to be offline")
|
||||
assert resp.status_code == 200
|
||||
assert resp.http_version == "HTTP/2"
|
||||
@ -1,207 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._dispatch.proxy_http import HTTPProxy
|
||||
|
||||
from .utils import MockRawSocketBackend
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_tunnel_success():
|
||||
raw_io = MockRawSocketBackend(
|
||||
data_to_send=(
|
||||
[
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: proxy-server\r\n"
|
||||
b"\r\n",
|
||||
b"HTTP/1.1 404 Not Found\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: origin-server\r\n"
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
)
|
||||
async with HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
|
||||
) as proxy:
|
||||
response = await proxy.request("GET", "http://example.com")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.headers["Server"] == "origin-server"
|
||||
|
||||
assert response.request.method == "GET"
|
||||
assert response.request.url == "http://example.com"
|
||||
assert response.request.headers["Host"] == "example.com"
|
||||
|
||||
recv = raw_io.received_data
|
||||
assert len(recv) == 3
|
||||
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
|
||||
assert recv[1].startswith(
|
||||
b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
|
||||
)
|
||||
assert recv[2].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [300, 304, 308, 401, 500])
|
||||
async def test_proxy_tunnel_non_2xx_response(status_code):
|
||||
raw_io = MockRawSocketBackend(
|
||||
data_to_send=(
|
||||
[
|
||||
b"HTTP/1.1 %d Not Good\r\n" % status_code,
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: proxy-server\r\n"
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.ProxyError) as e:
|
||||
async with HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
|
||||
) as proxy:
|
||||
await proxy.request("GET", "http://example.com")
|
||||
|
||||
# ProxyError.request should be the CONNECT request not the original request
|
||||
assert e.value.request.method == "CONNECT"
|
||||
assert e.value.request.headers["Host"] == "127.0.0.1:8000"
|
||||
assert e.value.request.url.full_path == "example.com:80"
|
||||
|
||||
# ProxyError.response should be the CONNECT response
|
||||
assert e.value.response.status_code == status_code
|
||||
assert e.value.response.headers["Server"] == "proxy-server"
|
||||
|
||||
# Verify that the request wasn't sent after receiving an error from CONNECT
|
||||
recv = raw_io.received_data
|
||||
assert len(recv) == 2
|
||||
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
|
||||
assert recv[1].startswith(
|
||||
b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_tunnel_start_tls():
|
||||
raw_io = MockRawSocketBackend(
|
||||
data_to_send=(
|
||||
[
|
||||
# Tunnel Response
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: proxy-server\r\n"
|
||||
b"\r\n",
|
||||
# Response 1
|
||||
b"HTTP/1.1 404 Not Found\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: origin-server\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"Content-Length: 0\r\n"
|
||||
b"\r\n",
|
||||
# Response 2
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: origin-server\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"Content-Length: 0\r\n"
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
)
|
||||
async with HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
|
||||
) as proxy:
|
||||
resp = await proxy.request("GET", "https://example.com")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.headers["Server"] == "origin-server"
|
||||
|
||||
assert resp.request.method == "GET"
|
||||
assert resp.request.url == "https://example.com"
|
||||
assert resp.request.headers["Host"] == "example.com"
|
||||
|
||||
await resp.aread()
|
||||
|
||||
# Make another request to see that the tunnel is re-used.
|
||||
resp = await proxy.request("GET", "https://example.com/target")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Server"] == "origin-server"
|
||||
|
||||
assert resp.request.method == "GET"
|
||||
assert resp.request.url == "https://example.com/target"
|
||||
assert resp.request.headers["Host"] == "example.com"
|
||||
|
||||
await resp.aread()
|
||||
|
||||
recv = raw_io.received_data
|
||||
assert len(recv) == 5
|
||||
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
|
||||
assert recv[1].startswith(
|
||||
b"CONNECT example.com:443 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
|
||||
)
|
||||
assert recv[2] == b"--- START_TLS(example.com) ---"
|
||||
assert recv[3].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
|
||||
assert recv[4].startswith(b"GET /target HTTP/1.1\r\nhost: example.com\r\n")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("proxy_mode", ["FORWARD_ONLY", "DEFAULT"])
|
||||
async def test_proxy_forwarding(proxy_mode):
|
||||
raw_io = MockRawSocketBackend(
|
||||
data_to_send=(
|
||||
[
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
|
||||
b"Server: origin-server\r\n"
|
||||
b"\r\n"
|
||||
]
|
||||
),
|
||||
)
|
||||
async with HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000",
|
||||
backend=raw_io,
|
||||
proxy_mode=proxy_mode,
|
||||
proxy_headers={"Proxy-Authorization": "test", "Override": "2"},
|
||||
) as proxy:
|
||||
response = await proxy.request(
|
||||
"GET", "http://example.com", headers={"override": "1"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["Server"] == "origin-server"
|
||||
|
||||
assert response.request.method == "GET"
|
||||
assert response.request.url == "http://127.0.0.1:8000"
|
||||
assert response.request.url.full_path == "http://example.com"
|
||||
assert response.request.headers["Host"] == "example.com"
|
||||
|
||||
recv = raw_io.received_data
|
||||
assert len(recv) == 2
|
||||
assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
|
||||
assert recv[1].startswith(
|
||||
b"GET http://example.com HTTP/1.1\r\nhost: example.com\r\n"
|
||||
)
|
||||
assert b"proxy-authorization: test" in recv[1]
|
||||
assert b"override: 1" in recv[1]
|
||||
|
||||
|
||||
def test_proxy_url_with_username_and_password():
|
||||
proxy = HTTPProxy("http://user:password@example.com:1080")
|
||||
|
||||
assert proxy.proxy_url == "http://example.com:1080"
|
||||
assert proxy.proxy_headers["Proxy-Authorization"] == "Basic dXNlcjpwYXNzd29yZA=="
|
||||
|
||||
|
||||
def test_proxy_repr():
|
||||
proxy = HTTPProxy(
|
||||
"http://127.0.0.1:1080",
|
||||
proxy_headers={"Custom": "Header"},
|
||||
proxy_mode="DEFAULT",
|
||||
)
|
||||
|
||||
assert repr(proxy) == (
|
||||
"HTTPProxy(proxy_url=URL('http://127.0.0.1:1080') "
|
||||
"proxy_headers=Headers({'custom': 'Header'}) "
|
||||
"proxy_mode='DEFAULT')"
|
||||
)
|
||||
@ -1,204 +0,0 @@
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from httpx import Request, Timeout
|
||||
from httpx._backends.base import BaseSocketStream, lookup_backend
|
||||
|
||||
|
||||
class MockHTTP2Backend:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.backend = lookup_backend()
|
||||
self.server = None
|
||||
|
||||
async def open_tcp_stream(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: Timeout,
|
||||
) -> BaseSocketStream:
|
||||
self.server = MockHTTP2Server(self.app, backend=self.backend)
|
||||
return self.server
|
||||
|
||||
# Defer all other attributes and methods to the underlying backend.
|
||||
def __getattr__(self, name: str) -> typing.Any:
|
||||
return getattr(self.backend, name)
|
||||
|
||||
|
||||
class MockHTTP2Server(BaseSocketStream):
|
||||
def __init__(self, app, backend: MockHTTP2Backend):
|
||||
config = h2.config.H2Configuration(client_side=False)
|
||||
self.conn = h2.connection.H2Connection(config=config)
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.buffer = b""
|
||||
self.requests = {}
|
||||
self.close_connection = False
|
||||
self.return_data = {}
|
||||
self.settings_changed = []
|
||||
|
||||
# Socket stream interface
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
return "HTTP/2"
|
||||
|
||||
async def read(self, n, timeout, flag=None) -> bytes:
|
||||
send, self.buffer = self.buffer[:n], self.buffer[n:]
|
||||
return send
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
if not data:
|
||||
return
|
||||
events = self.conn.receive_data(data)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
# This should send an UPDATE_WINDOW for both the stream and the
|
||||
# connection increasing it by the amount
|
||||
# consumed keeping the flow control window constant
|
||||
flow_control_consumed = event.flow_controlled_length
|
||||
if flow_control_consumed > 0:
|
||||
self.conn.increment_flow_control_window(flow_control_consumed)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.conn.increment_flow_control_window(
|
||||
flow_control_consumed, event.stream_id
|
||||
)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
await self.stream_complete(event.stream_id)
|
||||
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
self.settings_changed.append(event)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.close_connection
|
||||
|
||||
# Server implementation
|
||||
|
||||
def request_received(self, headers, stream_id):
|
||||
"""
|
||||
Handler for when the initial part of the HTTP request is received.
|
||||
"""
|
||||
if stream_id not in self.requests:
|
||||
self.requests[stream_id] = []
|
||||
self.requests[stream_id].append({"headers": headers, "data": b""})
|
||||
|
||||
def receive_data(self, data, stream_id):
|
||||
"""
|
||||
Handler for when a data part of the HTTP request is received.
|
||||
"""
|
||||
self.requests[stream_id][-1]["data"] += data
|
||||
|
||||
async def stream_complete(self, stream_id):
|
||||
"""
|
||||
Handler for when the HTTP request is completed.
|
||||
"""
|
||||
request = self.requests[stream_id].pop(0)
|
||||
if not self.requests[stream_id]:
|
||||
del self.requests[stream_id]
|
||||
|
||||
headers_dict = dict(request["headers"])
|
||||
|
||||
method = headers_dict[b":method"].decode("ascii")
|
||||
url = "%s://%s%s" % (
|
||||
headers_dict[b":scheme"].decode("ascii"),
|
||||
headers_dict[b":authority"].decode("ascii"),
|
||||
headers_dict[b":path"].decode("ascii"),
|
||||
)
|
||||
headers = [(k, v) for k, v in request["headers"] if not k.startswith(b":")]
|
||||
data = request["data"]
|
||||
|
||||
# Call out to the app.
|
||||
request = Request(method, url, headers=headers, data=data)
|
||||
response = await self.app(request)
|
||||
|
||||
# Write the response to the buffer.
|
||||
status_code_bytes = str(response.status_code).encode("ascii")
|
||||
response_headers = [(b":status", status_code_bytes)] + response.headers.raw
|
||||
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.return_data[stream_id] = response.content
|
||||
self.send_return_data(stream_id)
|
||||
|
||||
def send_return_data(self, stream_id):
|
||||
while self.return_data[stream_id]:
|
||||
flow_control = self.conn.local_flow_control_window(stream_id)
|
||||
chunk_size = min(
|
||||
len(self.return_data[stream_id]),
|
||||
flow_control,
|
||||
self.conn.max_outbound_frame_size,
|
||||
)
|
||||
if chunk_size > 0:
|
||||
chunk, self.return_data[stream_id] = (
|
||||
self.return_data[stream_id][:chunk_size],
|
||||
self.return_data[stream_id][chunk_size:],
|
||||
)
|
||||
self.conn.send_data(stream_id, chunk)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.conn.end_stream(stream_id)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
|
||||
|
||||
class MockRawSocketBackend:
|
||||
def __init__(self, data_to_send=b""):
|
||||
self.backend = lookup_backend()
|
||||
self.data_to_send = data_to_send
|
||||
self.received_data = []
|
||||
self.stream = MockRawSocketStream(self)
|
||||
|
||||
async def open_tcp_stream(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: Timeout,
|
||||
) -> BaseSocketStream:
|
||||
self.received_data.append(
|
||||
b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
|
||||
)
|
||||
return self.stream
|
||||
|
||||
# Defer all other attributes and methods to the underlying backend.
|
||||
def __getattr__(self, name: str) -> typing.Any:
|
||||
return getattr(self.backend, name)
|
||||
|
||||
|
||||
class MockRawSocketStream(BaseSocketStream):
|
||||
def __init__(self, backend: MockRawSocketBackend):
|
||||
self.backend = backend
|
||||
|
||||
async def start_tls(
|
||||
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
|
||||
) -> BaseSocketStream:
|
||||
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
|
||||
return MockRawSocketStream(self.backend)
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
return "HTTP/1.1"
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
if not data:
|
||||
return
|
||||
self.backend.received_data.append(data)
|
||||
|
||||
async def read(self, n, timeout, flag=None) -> bytes:
|
||||
if not self.backend.data_to_send:
|
||||
return b""
|
||||
return self.backend.data_to_send.pop(0)
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
@ -56,33 +56,6 @@ async def test_start_tls_on_tcp_socket_stream(https_server):
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_start_tls_on_uds_socket_stream(https_uds_server):
|
||||
backend = lookup_backend()
|
||||
ctx = SSLConfig().load_ssl_context_no_verify()
|
||||
timeout = Timeout(5)
|
||||
|
||||
stream = await backend.open_uds_stream(
|
||||
https_uds_server.config.uds, https_uds_server.url.host, None, timeout
|
||||
)
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is None
|
||||
|
||||
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
|
||||
response = await read_response(stream, timeout, should_contain=b"Hello, world")
|
||||
assert response.startswith(b"HTTP/1.1 200 OK\r\n")
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_concurrent_read(server):
|
||||
"""
|
||||
|
||||
@ -2,27 +2,34 @@ import binascii
|
||||
import cgi
|
||||
import io
|
||||
import os
|
||||
import typing
|
||||
from unittest import mock
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx._content_streams import encode
|
||||
from httpx._dispatch.base import AsyncDispatcher
|
||||
from httpx._content_streams import AsyncIteratorStream, encode
|
||||
from httpx._utils import format_form_param
|
||||
|
||||
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
class MockDispatch(httpcore.AsyncHTTPTransport):
|
||||
async def request(
|
||||
self,
|
||||
request: httpx.Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> httpx.Response:
|
||||
content = b"".join([part async for part in request.stream])
|
||||
return httpx.Response(200, content=content, request=request)
|
||||
method: bytes,
|
||||
url: typing.Tuple[bytes, bytes, int, bytes],
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = None,
|
||||
stream: httpcore.AsyncByteStream = None,
|
||||
timeout: typing.Dict[str, typing.Optional[float]] = None,
|
||||
) -> typing.Tuple[
|
||||
bytes,
|
||||
int,
|
||||
bytes,
|
||||
typing.List[typing.Tuple[bytes, bytes]],
|
||||
httpcore.AsyncByteStream,
|
||||
]:
|
||||
content = AsyncIteratorStream(aiterator=(part async for part in stream))
|
||||
return b"HTTP/1.1", 200, b"OK", [], content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("value,output"), (("abc", b"abc"), (b"abc", b"abc")))
|
||||
|
||||
@ -98,7 +98,6 @@ async def test_logs_debug(server, capsys):
|
||||
assert response.status_code == 200
|
||||
stderr = capsys.readouterr().err
|
||||
assert 'HTTP Request: GET http://127.0.0.1:8000/ "HTTP/1.1 200 OK"' in stderr
|
||||
assert "httpx._dispatch.connection_pool" not in stderr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -109,7 +108,6 @@ async def test_logs_trace(server, capsys):
|
||||
assert response.status_code == 200
|
||||
stderr = capsys.readouterr().err
|
||||
assert 'HTTP Request: GET http://127.0.0.1:8000/ "HTTP/1.1 200 OK"' in stderr
|
||||
assert "httpx._dispatch.connection_pool" in stderr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
Reference in New Issue
Block a user