Work on bringing API into parity with requests. (#76)

* Finesse timeout argument.

* Drop unused imports

* Add 'cert' and 'verify' arguments
This commit is contained in:
Tom Christie 2019-05-23 16:21:00 +01:00 committed by GitHub
parent fc627f3387
commit 95740415db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 318 additions and 203 deletions

View File

@ -1,7 +1,14 @@
from .api import delete, get, head, options, patch, post, put, request
from .client import AsyncClient, Client
from .concurrency import AsyncioBackend
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .config import (
CertTypes,
PoolLimits,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
from .dispatch.connection import HTTPConnection
from .dispatch.connection_pool import ConnectionPool
from .exceptions import (

View File

@ -1,7 +1,7 @@
import typing
from .client import Client
from .config import SSLConfig, TimeoutConfig
from .config import CertTypes, TimeoutTypes, VerifyTypes
from .models import (
AuthTypes,
CookieTypes,
@ -17,16 +17,19 @@ def request(
method: str,
url: URLTypes,
*,
params: QueryParamTypes = None,
data: RequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
stream: bool = False,
# files
auth: AuthTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
# proxies
cert: CertTypes = None,
verify: VerifyTypes = True,
stream: bool = False,
) -> SyncResponse:
with Client() as client:
return client.request(
@ -40,7 +43,8 @@ def request(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -54,8 +58,9 @@ def get(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"GET",
@ -65,7 +70,8 @@ def get(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -79,8 +85,9 @@ def options(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"OPTIONS",
@ -90,7 +97,8 @@ def options(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -104,8 +112,9 @@ def head(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"HEAD",
@ -115,7 +124,8 @@ def head(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -131,8 +141,9 @@ def post(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"POST",
@ -144,7 +155,8 @@ def post(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -160,8 +172,9 @@ def put(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"PUT",
@ -173,7 +186,8 @@ def put(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -189,8 +203,9 @@ def patch(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"PATCH",
@ -202,7 +217,8 @@ def patch(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)
@ -218,8 +234,9 @@ def delete(
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return request(
"DELETE",
@ -231,6 +248,7 @@ def delete(
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
cert=cert,
verify=verify,
timeout=timeout,
)

View File

@ -6,11 +6,11 @@ from .auth import HTTPBasicAuth
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
PoolLimits,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
from .dispatch.connection_pool import ConnectionPool
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
@ -37,8 +37,9 @@ class AsyncClient:
self,
auth: AuthTypes = None,
cookies: CookieTypes = None,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
verify: VerifyTypes = True,
cert: CertTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
@ -46,7 +47,11 @@ class AsyncClient:
):
if dispatch is None:
dispatch = ConnectionPool(
ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
verify=verify,
cert=cert,
timeout=timeout,
pool_limits=pool_limits,
backend=backend,
)
self.auth = auth
@ -64,8 +69,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"GET",
@ -76,7 +82,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -90,8 +97,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"OPTIONS",
@ -102,7 +110,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -116,8 +125,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"HEAD",
@ -128,7 +138,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -144,8 +155,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"POST",
@ -158,7 +170,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -174,8 +187,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"PUT",
@ -188,7 +202,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -204,8 +219,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"PATCH",
@ -218,7 +234,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -234,8 +251,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
return await self.request(
"DELETE",
@ -248,7 +266,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -265,8 +284,9 @@ class AsyncClient:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
request = Request(
method,
@ -283,7 +303,8 @@ class AsyncClient:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
return response
@ -306,9 +327,10 @@ class AsyncClient:
*,
stream: bool = False,
auth: AuthTypes = None,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if auth is None:
auth = self.auth
@ -325,7 +347,8 @@ class AsyncClient:
response = await self.send_handling_redirects(
request,
stream=stream,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
)
@ -336,8 +359,9 @@ class AsyncClient:
request: Request,
*,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
history: typing.List[Response] = None,
) -> Response:
@ -353,7 +377,7 @@ class AsyncClient:
raise RedirectLoop()
response = await self.dispatch.send(
request, stream=stream, ssl=ssl, timeout=timeout
request, stream=stream, verify=verify, cert=cert, timeout=timeout
)
response.history = list(history)
self.cookies.extract_cookies(response)
@ -366,13 +390,14 @@ class AsyncClient:
else:
async def send_next() -> Response:
nonlocal request, response, ssl, allow_redirects, timeout, history
nonlocal request, response, verify, cert, allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
response = await self.send_handling_redirects(
request,
stream=stream,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
history=history,
)
@ -474,8 +499,9 @@ class Client:
def __init__(
self,
auth: AuthTypes = None,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
@ -483,7 +509,8 @@ class Client:
) -> None:
self._client = AsyncClient(
auth=auth,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
pool_limits=pool_limits,
max_redirects=max_redirects,
@ -509,8 +536,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
request = Request(
method,
@ -527,7 +555,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
return response
@ -542,8 +571,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"GET",
@ -553,7 +583,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -567,8 +598,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"OPTIONS",
@ -578,7 +610,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -592,8 +625,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"HEAD",
@ -603,7 +637,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -619,8 +654,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"POST",
@ -632,7 +668,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -648,8 +685,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"PUT",
@ -661,7 +699,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -677,8 +716,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"PATCH",
@ -690,7 +730,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -706,8 +747,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
return self.request(
"DELETE",
@ -719,7 +761,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
@ -733,8 +776,9 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
response = self._loop.run_until_complete(
self._client.send(
@ -742,7 +786,8 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
verify=verify,
cert=cert,
timeout=timeout,
)
)

View File

@ -22,9 +22,6 @@ from .interfaces import (
Protocol,
)
OptionalTimeout = typing.Optional[TimeoutConfig]
SSL_MONKEY_PATCH_APPLIED = False
@ -56,7 +53,7 @@ class Reader(BaseReader):
self.stream_reader = stream_reader
self.timeout = timeout
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
if timeout is None:
timeout = self.timeout
@ -78,7 +75,7 @@ class Writer(BaseWriter):
def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
if not data:
return

View File

@ -5,18 +5,17 @@ import typing
import certifi
CertTypes = typing.Union[str, typing.Tuple[str, str]]
VerifyTypes = typing.Union[str, bool]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
class SSLConfig:
"""
SSL Configuration.
"""
def __init__(
self,
*,
cert: typing.Union[None, str, typing.Tuple[str, str]] = None,
verify: typing.Union[str, bool] = True,
):
def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
self.cert = cert
self.verify = verify
@ -31,6 +30,15 @@ class SSLConfig:
class_name = self.__class__.__name__
return f"{class_name}(cert={self.cert}, verify={self.verify})"
def with_overrides(
self, cert: CertTypes = None, verify: VerifyTypes = None
) -> "SSLConfig":
cert = self.cert if cert is None else cert
verify = self.verify if verify is None else verify
if (cert == self.cert) and (verify == self.verify):
return self
return SSLConfig(cert=cert, verify=verify)
async def load_ssl_context(self) -> ssl.SSLContext:
if not hasattr(self, "ssl_context"):
if not self.verify:
@ -109,25 +117,33 @@ class TimeoutConfig:
def __init__(
self,
timeout: float = None,
timeout: TimeoutTypes = None,
*,
connect_timeout: float = None,
read_timeout: float = None,
write_timeout: float = None,
):
if timeout is not None:
if timeout is None:
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.write_timeout = write_timeout
else:
# Specified as a single timeout value
assert connect_timeout is None
assert read_timeout is None
assert write_timeout is None
connect_timeout = timeout
read_timeout = timeout
write_timeout = timeout
self.timeout = timeout
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.write_timeout = write_timeout
if isinstance(timeout, TimeoutConfig):
self.connect_timeout = timeout.connect_timeout
self.read_timeout = timeout.read_timeout
self.write_timeout = timeout.write_timeout
elif isinstance(timeout, tuple):
self.connect_timeout = timeout[0]
self.read_timeout = timeout[1]
self.write_timeout = timeout[2]
else:
self.connect_timeout = timeout
self.read_timeout = timeout
self.write_timeout = timeout
def __eq__(self, other: typing.Any) -> bool:
return (
@ -139,8 +155,8 @@ class TimeoutConfig:
def __repr__(self) -> str:
class_name = self.__class__.__name__
if self.timeout is not None:
return f"{class_name}(timeout={self.timeout})"
if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
return f"{class_name}(timeout={self.connect_timeout})"
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"

View File

@ -8,8 +8,11 @@ from ..concurrency import AsyncioBackend
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
from ..exceptions import ConnectTimeout
from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
@ -25,14 +28,15 @@ class HTTPConnection(Dispatcher):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
verify: VerifyTypes = True,
cert: CertTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
self.ssl = SSLConfig(cert=cert, verify=verify)
self.timeout = TimeoutConfig(timeout)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
@ -42,11 +46,12 @@ class HTTPConnection(Dispatcher):
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if self.h11_connection is None and self.h2_connection is None:
await self.connect(ssl=ssl, timeout=timeout)
await self.connect(verify=verify, cert=cert, timeout=timeout)
if self.h2_connection is not None:
response = await self.h2_connection.send(
@ -61,12 +66,13 @@ class HTTPConnection(Dispatcher):
return response
async def connect(
self, ssl: SSLConfig = None, timeout: TimeoutConfig = None
self,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> None:
if ssl is None:
ssl = self.ssl
if timeout is None:
timeout = self.timeout
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
host = self.origin.host
port = self.origin.port

View File

@ -4,11 +4,11 @@ from ..concurrency import AsyncioBackend
from ..config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
PoolLimits,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
@ -81,12 +81,14 @@ class ConnectionPool(Dispatcher):
def __init__(
self,
*,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
verify: VerifyTypes = True,
cert: CertTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
backend: ConcurrencyBackend = None,
):
self.ssl = ssl
self.verify = verify
self.cert = cert
self.timeout = timeout
self.pool_limits = pool_limits
self.is_closed = False
@ -105,13 +107,14 @@ class ConnectionPool(Dispatcher):
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
connection = await self.acquire_connection(request.url.origin)
try:
response = await connection.send(
request, stream=stream, ssl=ssl, timeout=timeout
request, stream=stream, verify=verify, cert=cert, timeout=timeout
)
except BaseException as exc:
self.active_connections.remove(connection)
@ -128,7 +131,8 @@ class ConnectionPool(Dispatcher):
await self.max_connections.acquire()
connection = HTTPConnection(
origin,
ssl=self.ssl,
verify=self.verify,
cert=self.cert,
timeout=self.timeout,
backend=self.backend,
release_func=self.release_connection,

View File

@ -2,12 +2,7 @@ import typing
import h11
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
SSLConfig,
TimeoutConfig,
)
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, Dispatcher
from ..models import Request, Response
@ -22,8 +17,6 @@ H11Event = typing.Union[
]
OptionalTimeout = typing.Optional[TimeoutConfig]
# Callback signature: async def callback() -> None
# In practice the callback will be a functools partial, which binds
# the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
@ -45,8 +38,10 @@ class HTTP11Connection:
self.h11_state = h11.Connection(our_role=h11.CLIENT)
async def send(
self, request: Request, stream: bool = False, timeout: TimeoutConfig = None
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
) -> Response:
timeout = None if timeout is None else TimeoutConfig(timeout)
#  Start sending the request.
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
@ -97,18 +92,20 @@ class HTTP11Connection:
self.h11_state.send(event)
await self.writer.close()
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
async def _body_iter(
self, timeout: TimeoutConfig = None
) -> typing.AsyncIterator[bytes]:
event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event(timeout)
assert isinstance(event, h11.EndOfMessage)
async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
data = self.h11_state.send(event)
await self.writer.write(data, timeout)
async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
event = self.h11_state.next_event()
while event is h11.NEED_DATA:

View File

@ -4,18 +4,11 @@ import typing
import h2.connection
import h2.events
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
SSLConfig,
TimeoutConfig,
)
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, Dispatcher
from ..models import Request, Response
OptionalTimeout = typing.Optional[TimeoutConfig]
class HTTP2Connection:
READ_NUM_BYTES = 4096
@ -31,8 +24,10 @@ class HTTP2Connection:
self.initialized = False
async def send(
self, request: Request, stream: bool = False, timeout: TimeoutConfig = None
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
) -> Response:
timeout = None if timeout is None else TimeoutConfig(timeout)
#  Start sending the request.
if not self.initialized:
self.initiate_connection()
@ -89,7 +84,9 @@ class HTTP2Connection:
self.writer.write_no_block(data_to_send)
self.initialized = True
async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
async def send_headers(
self, request: Request, timeout: TimeoutConfig = None
) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
(b":method", request.method.encode("ascii")),
@ -103,19 +100,19 @@ class HTTP2Connection:
return stream_id
async def send_data(
self, stream_id: int, data: bytes, timeout: OptionalTimeout
self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
) -> None:
self.h2_state.send_data(stream_id, data)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
async def body_iter(
self, stream_id: int, timeout: OptionalTimeout
self, stream_id: int, timeout: TimeoutConfig = None
) -> typing.AsyncIterator[bytes]:
while True:
event = await self.receive_event(stream_id, timeout)
@ -125,7 +122,7 @@ class HTTP2Connection:
break
async def receive_event(
self, stream_id: int, timeout: OptionalTimeout
self, stream_id: int, timeout: TimeoutConfig = None
) -> h2.events.Event:
while not self.events[stream_id]:
data = await self.reader.read(self.READ_NUM_BYTES, timeout)

View File

@ -3,7 +3,7 @@ import ssl
import typing
from types import TracebackType
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
from .models import (
URL,
Headers,
@ -15,8 +15,6 @@ from .models import (
URLTypes,
)
OptionalTimeout = typing.Optional[TimeoutConfig]
class Protocol(str, enum.Enum):
HTTP_11 = "HTTP/1.1"
@ -41,12 +39,15 @@ class Dispatcher:
params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
self.prepare_request(request)
response = await self.send(request, stream=stream, ssl=ssl, timeout=timeout)
response = await self.send(
request, stream=stream, verify=verify, cert=cert, timeout=timeout
)
return response
def prepare_request(self, request: Request) -> None:
@ -56,8 +57,9 @@ class Dispatcher:
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
raise NotImplementedError() # pragma: nocover
@ -83,7 +85,7 @@ class BaseReader:
backend, or for stand-alone test cases.
"""
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
raise NotImplementedError() # pragma: no cover
@ -97,7 +99,7 @@ class BaseWriter:
def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
raise NotImplementedError() # pragma: no cover
async def close(self) -> None:

View File

@ -4,12 +4,13 @@ import pytest
from httpcore import (
URL,
CertTypes,
Client,
Dispatcher,
Request,
Response,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
@ -18,8 +19,9 @@ class MockDispatch(Dispatcher):
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
return Response(200, content=body, request=request)

View File

@ -5,13 +5,14 @@ import pytest
from httpcore import (
URL,
CertTypes,
Client,
Cookies,
Dispatcher,
Request,
Response,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
VerifyTypes,
)
@ -20,8 +21,9 @@ class MockDispatch(Dispatcher):
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path.startswith("/echo_cookies"):
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()

View File

@ -6,14 +6,15 @@ import pytest
from httpcore import (
URL,
AsyncClient,
CertTypes,
Dispatcher,
RedirectBodyUnavailable,
RedirectLoop,
Request,
Response,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
TooManyRedirects,
VerifyTypes,
codes,
)
@ -23,8 +24,9 @@ class MockDispatch(Dispatcher):
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path == "/redirect_301":
status_code = codes.MOVED_PERMANENTLY

View File

@ -6,17 +6,7 @@ from httpcore import HTTPConnection, Request, SSLConfig
@pytest.mark.asyncio
async def test_get(server):
conn = HTTPConnection(origin="http://127.0.0.1:8000/")
request = Request("GET", "http://127.0.0.1:8000/")
request.prepare()
response = await conn.send(request)
assert response.status_code == 200
assert response.content == b"Hello, world!"
@pytest.mark.asyncio
async def test_https_get(https_server):
http = HTTPConnection(origin="https://127.0.0.1:8001/", ssl=SSLConfig(verify=False))
response = await http.request("GET", "https://127.0.0.1:8001/")
response = await conn.request("GET", "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.content == b"Hello, world!"
@ -24,7 +14,27 @@ async def test_https_get(https_server):
@pytest.mark.asyncio
async def test_post(server):
conn = HTTPConnection(origin="http://127.0.0.1:8000/")
request = Request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!")
request.prepare()
response = await conn.send(request)
response = await conn.request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_https_get_with_ssl_defaults(https_server):
"""
An HTTPS request, with default SSL configuration set on the client.
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
response = await conn.request("GET", "https://127.0.0.1:8001/")
assert response.status_code == 200
assert response.content == b"Hello, world!"
@pytest.mark.asyncio
async def test_https_get_with_sll_overrides(https_server):
"""
An HTTPS request, with SSL configuration set on the request.
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/")
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
assert response.status_code == 200
assert response.content == b"Hello, world!"

View File

@ -94,3 +94,13 @@ def test_timeout_eq():
def test_limits_eq():
limits = httpcore.PoolLimits(hard_limit=100)
assert limits == httpcore.PoolLimits(hard_limit=100)
def test_timeout_from_tuple():
timeout = httpcore.TimeoutConfig(timeout=(5.0, 5.0, 5.0))
assert timeout == httpcore.TimeoutConfig(timeout=5.0)
def test_timeout_from_config_instance():
timeout = httpcore.TimeoutConfig(timeout=(5.0))
assert httpcore.TimeoutConfig(timeout) == httpcore.TimeoutConfig(timeout=5.0)