Add NetworkOptions

This commit is contained in:
Tom Christie 2024-01-12 12:14:15 +00:00
parent 99cba6ac64
commit 83b5e4bf13
4 changed files with 81 additions and 26 deletions

View File

@ -2,7 +2,7 @@ from .__version__ import __description__, __title__, __version__
from ._api import delete, get, head, options, patch, post, put, request, stream
from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth
from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client
from ._config import Limits, Proxy, Timeout, create_ssl_context
from ._config import Limits, NetworkOptions, Proxy, Timeout, create_ssl_context
from ._content import ByteStream
from ._exceptions import (
CloseError,
@ -96,6 +96,7 @@ __all__ = [
"MockTransport",
"NetRCAuth",
"NetworkError",
"NetworkOptions",
"options",
"patch",
"PoolTimeout",

View File

@ -12,6 +12,13 @@ from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
from ._urls import URL
from ._utils import get_ca_bundle_from_env
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
@ -363,6 +370,37 @@ class Proxy:
return f"Proxy({url_str}{auth_str}{headers_str})"
class NetworkOptions:
def __init__(
self,
connection_retries: int = 0,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
uds: typing.Optional[str] = None,
) -> None:
self.connection_retries = connection_retries
self.local_address = local_address
self.socket_options = socket_options
self.uds = uds
def __repr__(self) -> str:
defaults = {
"connection_retries": 0,
"local_address": None,
"socket_options": None,
"uds": None,
}
params = ", ".join(
[
f"{attr}={getattr(self, attr)!r}"
for attr, default in defaults.items()
if getattr(self, attr) != default
]
)
return f"NetworkOptions({params})"
DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0)
DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)
DEFAULT_NETWORK_OPTIONS = NetworkOptions(connection_retries=0)
DEFAULT_MAX_REDIRECTS = 20

View File

@ -29,7 +29,14 @@ from types import TracebackType
import httpcore
from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
from .._config import (
DEFAULT_LIMITS,
DEFAULT_NETWORK_OPTIONS,
Proxy,
Limits,
NetworkOptions,
create_ssl_context,
)
from .._exceptions import (
ConnectError,
ConnectTimeout,
@ -54,12 +61,6 @@ from .base import AsyncBaseTransport, BaseTransport
T = typing.TypeVar("T", bound="HTTPTransport")
A = typing.TypeVar("A", bound="AsyncHTTPTransport")
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
@contextlib.contextmanager
def map_httpcore_exceptions() -> typing.Iterator[None]:
@ -126,10 +127,7 @@ class HTTPTransport(BaseTransport):
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[ProxyTypes] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
retries: int = 0,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
network_options: NetworkOptions = DEFAULT_NETWORK_OPTIONS,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
@ -142,10 +140,10 @@ class HTTPTransport(BaseTransport):
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
uds=network_options.uds,
local_address=network_options.local_address,
retries=network_options.connection_retries,
socket_options=network_options.socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.HTTPProxy(
@ -164,7 +162,10 @@ class HTTPTransport(BaseTransport):
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
uds=network_options.uds,
local_address=network_options.local_address,
retries=network_options.connection_retries,
socket_options=network_options.socket_options,
)
elif proxy.url.scheme == "socks5":
try:
@ -267,10 +268,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
limits: Limits = DEFAULT_LIMITS,
trust_env: bool = True,
proxy: typing.Optional[ProxyTypes] = None,
uds: typing.Optional[str] = None,
local_address: typing.Optional[str] = None,
retries: int = 0,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
network_options: NetworkOptions = DEFAULT_NETWORK_OPTIONS,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
@ -283,10 +281,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
uds=uds,
local_address=local_address,
retries=retries,
socket_options=socket_options,
uds=network_options.uds,
local_address=network_options.local_address,
retries=network_options.connection_retries,
socket_options=network_options.socket_options,
)
elif proxy.url.scheme in ("http", "https"):
self._pool = httpcore.AsyncHTTPProxy(
@ -304,7 +302,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
keepalive_expiry=limits.keepalive_expiry,
http1=http1,
http2=http2,
socket_options=socket_options,
uds=network_options.uds,
local_address=network_options.local_address,
retries=network_options.connection_retries,
socket_options=network_options.socket_options,
)
elif proxy.url.scheme == "socks5":
try:

View File

@ -221,3 +221,18 @@ def test_proxy_with_auth_from_url():
def test_invalid_proxy_scheme():
with pytest.raises(ValueError):
httpx.Proxy("invalid://example.com")
def test_network_options():
network_options = httpx.NetworkOptions()
assert repr(network_options) == "NetworkOptions()"
network_options = httpx.NetworkOptions(connection_retries=1)
assert repr(network_options) == "NetworkOptions(connection_retries=1)"
network_options = httpx.NetworkOptions(
connection_retries=1, local_address="0.0.0.0"
)
assert repr(network_options) == (
"NetworkOptions(connection_retries=1, local_address='0.0.0.0')"
)