Sync or Async dispatch (#83)

* Support thread-pooled dispatch

* Add ConcurrencyBackend.run

* Initial work towards support byte-iterators on sync request data

* Test case for byte iterator content

* byte iterator support for RequestData

* Add BaseResponse

* Bridge sync/async data in SyncResponse

* Add BaseClient

* SyncResponse -> Response

* Tweaking type annotation

* Distinct classes for Request, AsyncRequest

* Tweak is_streaming, content in BaseRequest

* Stream handling moves to client

* Handle mediating to AsyncResponse from a standard sync Dispatcher class

* Working on thread-pooled dispatcher

* Support threaded dispatch, inc. streaming requests/responses

* Increase test coverage

* Coverage and tweaks

* Include Accept and User-Agent headers by default
This commit is contained in:
Tom Christie 2019-06-10 12:26:03 +01:00 committed by GitHub
parent ba033c549f
commit 0cbf3c7581
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1080 additions and 604 deletions

View File

@ -28,8 +28,25 @@ from .exceptions import (
TooManyRedirects,
WriteTimeout,
)
from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response
from .interfaces import (
AsyncDispatcher,
BaseReader,
BaseWriter,
ConcurrencyBackend,
Dispatcher,
Protocol,
)
from .models import (
URL,
AsyncRequest,
AsyncResponse,
Cookies,
Headers,
Origin,
QueryParams,
Request,
Response,
)
from .status_codes import StatusCode, codes
__version__ = "0.4.0"

View File

@ -8,7 +8,7 @@ from .models import (
HeaderTypes,
QueryParamTypes,
RequestData,
SyncResponse,
Response,
URLTypes,
)
@ -30,7 +30,7 @@ def request(
cert: CertTypes = None,
verify: VerifyTypes = True,
stream: bool = False,
) -> SyncResponse:
) -> Response:
with Client() as client:
return client.request(
method=method,
@ -61,7 +61,7 @@ def get(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"GET",
url,
@ -88,7 +88,7 @@ def options(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"OPTIONS",
url,
@ -115,7 +115,7 @@ def head(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"HEAD",
url,
@ -144,7 +144,7 @@ def post(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"POST",
url,
@ -175,7 +175,7 @@ def put(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"PUT",
url,
@ -206,7 +206,7 @@ def patch(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"PATCH",
url,
@ -237,7 +237,7 @@ def delete(
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return request(
"DELETE",
url,

View File

@ -1,7 +1,7 @@
import typing
from base64 import b64encode
from .models import Request
from .models import AsyncRequest
class AuthBase:
@ -9,7 +9,7 @@ class AuthBase:
Base class that all auth implementations derive from.
"""
def __call__(self, request: Request) -> Request:
def __call__(self, request: AsyncRequest) -> AsyncRequest:
raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover
@ -20,7 +20,7 @@ class HTTPBasicAuth(AuthBase):
self.username = username
self.password = password
def __call__(self, request: Request) -> Request:
def __call__(self, request: AsyncRequest) -> AsyncRequest:
request.headers["Authorization"] = self.build_auth_header()
return request

View File

@ -1,8 +1,8 @@
import asyncio
import typing
from types import TracebackType
from .auth import HTTPBasicAuth
from .concurrency import AsyncioBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
@ -13,10 +13,15 @@ from .config import (
VerifyTypes,
)
from .dispatch.connection_pool import ConnectionPool
from .dispatch.threaded import ThreadedDispatcher
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .interfaces import ConcurrencyBackend, Dispatcher
from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
from .models import (
URL,
AsyncRequest,
AsyncRequestData,
AsyncResponse,
AsyncResponseContent,
AuthTypes,
Cookies,
CookieTypes,
@ -26,13 +31,13 @@ from .models import (
Request,
RequestData,
Response,
SyncResponse,
ResponseContent,
URLTypes,
)
from .status_codes import codes
class AsyncClient:
class BaseClient:
def __init__(
self,
auth: AuthTypes = None,
@ -42,23 +47,208 @@ class AsyncClient:
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
backend: ConcurrencyBackend = None,
):
if backend is None:
backend = AsyncioBackend()
if dispatch is None:
dispatch = ConnectionPool(
async_dispatch = ConnectionPool(
verify=verify,
cert=cert,
timeout=timeout,
pool_limits=pool_limits,
backend=backend,
)
) # type: AsyncDispatcher
elif isinstance(dispatch, Dispatcher):
async_dispatch = ThreadedDispatcher(dispatch, backend)
else:
async_dispatch = dispatch
self.auth = auth
self.cookies = Cookies(cookies)
self.max_redirects = max_redirects
self.dispatch = dispatch
self.dispatch = async_dispatch
self.concurrency_backend = backend
def merge_cookies(
self, cookies: CookieTypes = None
) -> typing.Optional[CookieTypes]:
if cookies or self.cookies:
merged_cookies = Cookies(self.cookies)
merged_cookies.update(cookies)
return merged_cookies
return cookies
async def send(
self,
request: AsyncRequest,
*,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
if auth is None:
auth = self.auth
url = request.url
if auth is None and (url.username or url.password):
auth = HTTPBasicAuth(username=url.username, password=url.password)
if auth is not None:
if isinstance(auth, tuple):
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
request = auth(request)
response = await self.send_handling_redirects(
request,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def send_handling_redirects(
self,
request: AsyncRequest,
*,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
history: typing.List[AsyncResponse] = None,
) -> AsyncResponse:
if history is None:
history = []
while True:
# We perform these checks here, so that calls to `response.next()`
# will raise redirect errors if appropriate.
if len(history) > self.max_redirects:
raise TooManyRedirects()
if request.url in [response.url for response in history]:
raise RedirectLoop()
response = await self.dispatch.send(
request, verify=verify, cert=cert, timeout=timeout
)
assert isinstance(response, AsyncResponse)
response.history = list(history)
self.cookies.extract_cookies(response)
history = [response] + history
if not response.is_redirect:
break
if allow_redirects:
request = self.build_redirect_request(request, response)
else:
async def send_next() -> AsyncResponse:
nonlocal request, response, verify, cert, allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
response = await self.send_handling_redirects(
request,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
history=history,
)
return response
response.next = send_next # type: ignore
break
return response
def build_redirect_request(
self, request: AsyncRequest, response: AsyncResponse
) -> AsyncRequest:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url)
content = self.redirect_content(request, method)
cookies = self.merge_cookies(request.cookies)
return AsyncRequest(
method=method, url=url, headers=headers, data=content, cookies=cookies
)
def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
"""
When being redirected we may want to change the method of the request
based on certain specs or browser behavior.
"""
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes.SEE_OTHER and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
# Turn 302s into GETs.
if response.status_code == codes.FOUND and method != "HEAD":
method = "GET"
# If a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in 'requests' issue 1704.
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
method = "GET"
return method
def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
"""
Return the URL for the redirect to follow.
"""
location = response.headers["Location"]
url = URL(location, allow_relative=True)
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
if url.is_relative_url:
url = url.resolve_with(request.url)
# Attach previous fragment if needed (RFC 7231 7.1.2)
if request.url.fragment and not url.fragment:
url = url.copy_with(fragment=request.url.fragment)
return url
def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
"""
Strip Authorization headers when responses are redirected away from
the origin.
"""
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
return headers
def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return b""
if request.is_streaming:
raise RedirectBodyUnavailable()
return request.content
class AsyncClient(BaseClient):
async def get(
self,
url: URLTypes,
@ -72,7 +262,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"GET",
url,
@ -100,7 +290,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"OPTIONS",
url,
@ -128,7 +318,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"HEAD",
url,
@ -147,7 +337,7 @@ class AsyncClient:
self,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -158,7 +348,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"POST",
url,
@ -179,7 +369,7 @@ class AsyncClient:
self,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -190,7 +380,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"PUT",
url,
@ -211,7 +401,7 @@ class AsyncClient:
self,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -222,7 +412,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"PATCH",
url,
@ -243,7 +433,7 @@ class AsyncClient:
self,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -254,7 +444,7 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
return await self.request(
"DELETE",
url,
@ -276,7 +466,7 @@ class AsyncClient:
method: str,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -287,8 +477,8 @@ class AsyncClient:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
request = Request(
) -> AsyncResponse:
request = AsyncRequest(
method,
url,
data=data,
@ -308,174 +498,6 @@ class AsyncClient:
)
return response
def merge_cookies(
self, cookies: CookieTypes = None
) -> typing.Optional[CookieTypes]:
if cookies or self.cookies:
merged_cookies = Cookies(self.cookies)
merged_cookies.update(cookies)
return merged_cookies
return cookies
async def send(
self,
request: Request,
*,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if auth is None:
auth = self.auth
url = request.url
if auth is None and (url.username or url.password):
auth = HTTPBasicAuth(username=url.username, password=url.password)
if auth is not None:
if isinstance(auth, tuple):
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
request = auth(request)
response = await self.send_handling_redirects(
request,
stream=stream,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
)
return response
async def send_handling_redirects(
self,
request: Request,
*,
stream: bool = False,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
history: typing.List[Response] = None,
) -> Response:
if history is None:
history = []
while True:
# We perform these checks here, so that calls to `response.next()`
# will raise redirect errors if appropriate.
if len(history) > self.max_redirects:
raise TooManyRedirects()
if request.url in [response.url for response in history]:
raise RedirectLoop()
response = await self.dispatch.send(
request, stream=stream, verify=verify, cert=cert, timeout=timeout
)
response.history = list(history)
self.cookies.extract_cookies(response)
history = [response] + history
if not response.is_redirect:
break
if allow_redirects:
request = self.build_redirect_request(request, response)
else:
async def send_next() -> Response:
nonlocal request, response, verify, cert, allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
response = await self.send_handling_redirects(
request,
stream=stream,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
history=history,
)
return response
response.next = send_next # type: ignore
break
return response
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url)
content = self.redirect_content(request, method)
cookies = self.merge_cookies(request.cookies)
return Request(
method=method, url=url, headers=headers, data=content, cookies=cookies
)
def redirect_method(self, request: Request, response: Response) -> str:
"""
When being redirected we may want to change the method of the request
based on certain specs or browser behavior.
"""
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes.SEE_OTHER and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
# Turn 302s into GETs.
if response.status_code == codes.FOUND and method != "HEAD":
method = "GET"
# If a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in 'requests' issue 1704.
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
method = "GET"
return method
def redirect_url(self, request: Request, response: Response) -> URL:
"""
Return the URL for the redirect to follow.
"""
location = response.headers["Location"]
url = URL(location, allow_relative=True)
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
if url.is_relative_url:
url = url.resolve_with(request.url)
# Attach previous fragment if needed (RFC 7231 7.1.2)
if request.url.fragment and not url.fragment:
url = url.copy_with(fragment=request.url.fragment)
return url
def redirect_headers(self, request: Request, url: URL) -> Headers:
"""
Strip Authorization headers when responses are redirected away from
the origin.
"""
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
return headers
def redirect_content(self, request: Request, method: str) -> bytes:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return b""
if request.is_streaming:
raise RedirectBodyUnavailable()
return request.content
async def close(self) -> None:
await self.dispatch.close()
@ -491,33 +513,28 @@ class AsyncClient:
await self.close()
class Client:
def __init__(
self,
auth: AuthTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
backend: ConcurrencyBackend = None,
) -> None:
self._client = AsyncClient(
auth=auth,
verify=verify,
cert=cert,
timeout=timeout,
pool_limits=pool_limits,
max_redirects=max_redirects,
dispatch=dispatch,
backend=backend,
)
self._loop = asyncio.new_event_loop()
class Client(BaseClient):
def _async_request_data(self, data: RequestData) -> AsyncRequestData:
"""
If the request data is an bytes iterator then return an async bytes
iterator onto the request data.
"""
if isinstance(data, (bytes, dict)):
return data
@property
def cookies(self) -> Cookies:
return self._client.cookies
# Coerce an iterator into an async iterator, with each item in the
# iteration running as a thread-pooled operation.
assert hasattr(data, "__iter__")
return self.concurrency_backend.iterate_in_threadpool(data)
def _sync_data(self, data: AsyncResponseContent) -> ResponseContent:
if isinstance(data, bytes):
return data
# Coerce an async iterator into an iterator, with each item in the
# iteration run within the event loop.
assert hasattr(data, "__aiter__")
return self.concurrency_backend.iterate(data)
def request(
self,
@ -535,25 +552,55 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
request = Request(
) -> Response:
request = AsyncRequest(
method,
url,
data=data,
data=self._async_request_data(data),
json=json,
params=params,
headers=headers,
cookies=self._client.merge_cookies(cookies),
cookies=self.merge_cookies(cookies),
)
response = self.send(
request,
stream=stream,
concurrency_backend = self.concurrency_backend
coroutine = self.send
args = [request]
kwargs = dict(
stream=True,
auth=auth,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
)
async_response = concurrency_backend.run(coroutine, *args, **kwargs)
content = getattr(
async_response, "_raw_content", getattr(async_response, "_raw_stream", None)
)
sync_content = self._sync_data(content)
def sync_on_close() -> None:
nonlocal concurrency_backend, async_response
concurrency_backend.run(async_response.on_close)
response = Response(
status_code=async_response.status_code,
reason_phrase=async_response.reason_phrase,
protocol=async_response.protocol,
headers=async_response.headers,
content=sync_content,
on_close=sync_on_close,
request=async_response.request,
history=async_response.history,
)
if not stream:
try:
response.read()
finally:
response.close()
return response
def get(
@ -569,7 +616,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"GET",
url,
@ -596,7 +643,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"OPTIONS",
url,
@ -623,7 +670,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"HEAD",
url,
@ -652,7 +699,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"POST",
url,
@ -683,7 +730,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"PUT",
url,
@ -714,7 +761,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"PATCH",
url,
@ -745,7 +792,7 @@ class Client:
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
) -> Response:
return self.request(
"DELETE",
url,
@ -761,32 +808,9 @@ class Client:
timeout=timeout,
)
def send(
self,
request: Request,
*,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> SyncResponse:
response = self._loop.run_until_complete(
self._client.send(
request,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
)
)
return SyncResponse(response, self._loop)
def close(self) -> None:
self._loop.run_until_complete(self._client.close())
coroutine = self.dispatch.close
self.concurrency_backend.run(coroutine)
def __enter__(self) -> "Client":
return self

View File

@ -9,6 +9,7 @@ protocols, and help keep the rest of the package more `async`/`await`
based, and less strictly `asyncio`-specific.
"""
import asyncio
import functools
import ssl
import typing
@ -133,6 +134,15 @@ class AsyncioBackend(ConcurrencyBackend):
ssl_monkey_patch()
SSL_MONKEY_PATCH_APPLIED = True
@property
def loop(self) -> asyncio.AbstractEventLoop:
if not hasattr(self, "_loop"):
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
return self._loop
async def connect(
self,
hostname: str,
@ -162,5 +172,24 @@ class AsyncioBackend(ConcurrencyBackend):
return (reader, writer, protocol)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
if kwargs:
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await self.loop.run_in_executor(None, func, *args)
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
loop = self.loop
if loop.is_running():
self._loop = asyncio.new_event_loop()
try:
return self.loop.run_until_complete(coroutine(*args, **kwargs))
finally:
self._loop = loop
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)

View File

@ -15,8 +15,8 @@ from ..config import (
VerifyTypes,
)
from ..exceptions import ConnectTimeout
from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
from ..models import Origin, Request, Response
from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol
from ..models import AsyncRequest, AsyncResponse, Origin
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
@ -24,7 +24,7 @@ from .http11 import HTTP11Connection
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
class HTTPConnection(Dispatcher):
class HTTPConnection(AsyncDispatcher):
def __init__(
self,
origin: typing.Union[str, Origin],
@ -44,24 +44,19 @@ class HTTPConnection(Dispatcher):
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
if self.h11_connection is None and self.h2_connection is None:
await self.connect(verify=verify, cert=cert, timeout=timeout)
if self.h2_connection is not None:
response = await self.h2_connection.send(
request, stream=stream, timeout=timeout
)
response = await self.h2_connection.send(request, timeout=timeout)
else:
assert self.h11_connection is not None
response = await self.h11_connection.send(
request, stream=stream, timeout=timeout
)
response = await self.h11_connection.send(request, timeout=timeout)
return response

View File

@ -12,8 +12,8 @@ from ..config import (
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
from ..interfaces import ConcurrencyBackend, Dispatcher
from ..models import Origin, Request, Response
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
@ -77,7 +77,7 @@ class ConnectionStore:
return len(self.all)
class ConnectionPool(Dispatcher):
class ConnectionPool(AsyncDispatcher):
def __init__(
self,
*,
@ -105,16 +105,15 @@ class ConnectionPool(Dispatcher):
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
connection = await self.acquire_connection(request.url.origin)
try:
response = await connection.send(
request, stream=stream, verify=verify, cert=cert, timeout=timeout
request, verify=verify, cert=cert, timeout=timeout
)
except BaseException as exc:
self.active_connections.remove(connection)

View File

@ -4,8 +4,8 @@ import h11
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, Dispatcher
from ..models import Request, Response
from ..interfaces import BaseReader, BaseWriter
from ..models import AsyncRequest, AsyncResponse
H11Event = typing.Union[
h11.Request,
@ -38,15 +38,15 @@ class HTTP11Connection:
self.h11_state = h11.Connection(our_role=h11.CLIENT)
async def send(
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
) -> Response:
self, request: AsyncRequest, timeout: TimeoutTypes = None
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
#  Start sending the request.
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
headers = request.headers.raw
if 'Host' not in request.headers:
if "Host" not in request.headers:
host = request.url.authority.encode("ascii")
headers = [(b"host", host)] + headers
event = h11.Request(method=method, target=target, headers=headers)
@ -72,7 +72,7 @@ class HTTP11Connection:
headers = event.headers
content = self._body_iter(timeout)
response = Response(
return AsyncResponse(
status_code=status_code,
reason_phrase=reason_phrase,
protocol="HTTP/1.1",
@ -82,14 +82,6 @@ class HTTP11Connection:
request=request,
)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def close(self) -> None:
event = h11.ConnectionClosed()
self.h11_state.send(event)

View File

@ -6,8 +6,8 @@ import h2.events
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import BaseReader, BaseWriter, Dispatcher
from ..models import Request, Response
from ..interfaces import BaseReader, BaseWriter
from ..models import AsyncRequest, AsyncResponse
class HTTP2Connection:
@ -24,8 +24,8 @@ class HTTP2Connection:
self.initialized = False
async def send(
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
) -> Response:
self, request: AsyncRequest, timeout: TimeoutTypes = None
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
#  Start sending the request.
@ -59,7 +59,7 @@ class HTTP2Connection:
content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
response = Response(
return AsyncResponse(
status_code=status_code,
protocol="HTTP/2",
headers=headers,
@ -68,14 +68,6 @@ class HTTP2Connection:
request=request,
)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def close(self) -> None:
await self.writer.close()
@ -86,7 +78,7 @@ class HTTP2Connection:
self.initialized = True
async def send_headers(
self, request: Request, timeout: TimeoutConfig = None
self, request: AsyncRequest, timeout: TimeoutConfig = None
) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [

View File

@ -0,0 +1,97 @@
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
from ..models import (
AsyncRequest,
AsyncRequestData,
AsyncResponse,
AsyncResponseContent,
Request,
RequestData,
Response,
ResponseContent,
)
class ThreadedDispatcher(AsyncDispatcher):
"""
The ThreadedDispatcher class is used to mediate between the Client
(which always uses async under the hood), and a synchronous `Dispatch`
class.
"""
def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None:
self.sync_dispatcher = dispatch
self.backend = backend
async def send(
self,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
concurrency_backend = self.backend
data = getattr(request, "content", getattr(request, "content_aiter", None))
sync_data = self._sync_request_data(data)
sync_request = Request(
method=request.method,
url=request.url,
headers=request.headers,
data=sync_data,
)
func = self.sync_dispatcher.send
kwargs = {
"request": sync_request,
"verify": verify,
"cert": cert,
"timeout": timeout,
}
sync_response = await self.backend.run_in_threadpool(func, **kwargs)
assert isinstance(sync_response, Response)
content = getattr(
sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None)
)
async_content = self._async_response_content(content)
async def async_on_close() -> None:
nonlocal concurrency_backend, sync_response
await concurrency_backend.run_in_threadpool(sync_response.close)
return AsyncResponse(
status_code=sync_response.status_code,
reason_phrase=sync_response.reason_phrase,
protocol=sync_response.protocol,
headers=sync_response.headers,
content=async_content,
on_close=async_on_close,
request=request,
history=sync_response.history,
)
async def close(self) -> None:
"""
The `.close()` method runs the `Dispatcher.close()` within a threadpool,
so as not to block the async event loop.
"""
func = self.sync_dispatcher.close
await self.backend.run_in_threadpool(func)
def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent:
if isinstance(content, bytes):
return content
# Coerce an async iterator into an iterator, with each item in the
# iteration run within the event loop.
assert hasattr(content, "__iter__")
return self.backend.iterate_in_threadpool(content)
def _sync_request_data(self, data: AsyncRequestData) -> RequestData:
if isinstance(data, bytes):
return data
return self.backend.iterate(data)

View File

@ -6,6 +6,9 @@ from types import TracebackType
from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
from .models import (
URL,
AsyncRequest,
AsyncRequestData,
AsyncResponse,
Headers,
HeaderTypes,
QueryParamTypes,
@ -21,9 +24,9 @@ class Protocol(str, enum.Enum):
HTTP_2 = "HTTP/2"
class Dispatcher:
class AsyncDispatcher:
"""
Base class for dispatcher classes, that handle sending the request.
Base class for async dispatcher classes, that handle sending the request.
Stubs out the interface, as well as providing a `.request()` convienence
implementation, to make it easy to use or test stand-alone dispatchers,
@ -35,34 +38,29 @@ class Dispatcher:
method: str,
url: URLTypes,
*,
data: RequestData = b"",
data: AsyncRequestData = b"",
params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
response = await self.send(
request, stream=stream, verify=verify, cert=cert, timeout=timeout
)
return response
) -> AsyncResponse:
request = AsyncRequest(method, url, data=data, params=params, headers=headers)
return await self.send(request, verify=verify, cert=cert, timeout=timeout)
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
raise NotImplementedError() # pragma: nocover
async def close(self) -> None:
pass # pragma: nocover
async def __aenter__(self) -> "Dispatcher":
async def __aenter__(self) -> "AsyncDispatcher":
return self
async def __aexit__(
@ -74,6 +72,54 @@ class Dispatcher:
await self.close()
class Dispatcher:
"""
Base class for syncronous dispatcher classes, that handle sending the request.
Stubs out the interface, as well as providing a `.request()` convienence
implementation, to make it easy to use or test stand-alone dispatchers,
without requiring a complete `Client` instance.
"""
def request(
self,
method: str,
url: URLTypes,
*,
data: RequestData = b"",
params: QueryParamTypes = None,
headers: HeaderTypes = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
return self.send(request, verify=verify, cert=cert, timeout=timeout)
def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
raise NotImplementedError() # pragma: nocover
def close(self) -> None:
pass # pragma: nocover
def __enter__(self) -> "Dispatcher":
return self
def __exit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
self.close()
class BaseReader:
"""
A stream reader. Abstracts away any asyncio-specfic interfaces
@ -128,3 +174,36 @@ class ConcurrencyBackend:
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
raise NotImplementedError() # pragma: no cover
async def iterate_in_threadpool(self, iterator): # type: ignore
class IterationComplete(Exception):
pass
def next_wrapper(iterator): # type: ignore
try:
return next(iterator)
except StopIteration:
raise IterationComplete()
while True:
try:
yield await self.run_in_threadpool(next_wrapper, iterator)
except IterationComplete:
break
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
raise NotImplementedError() # pragma: no cover
def iterate(self, async_iterator): # type: ignore
while True:
try:
yield self.run(async_iterator.__anext__)
except StopAsyncIteration:
break

View File

@ -1,4 +1,3 @@
import asyncio
import cgi
import email.message
import json as jsonlib
@ -48,12 +47,16 @@ CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
AuthTypes = typing.Union[
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
typing.Callable[["Request"], "Request"],
typing.Callable[["AsyncRequest"], "AsyncRequest"],
]
RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
ResponseContent = typing.Union[bytes, typing.Iterator[bytes]]
class URL:
@ -469,14 +472,12 @@ class Headers(typing.MutableMapping[str, str]):
return f"{class_name}({as_list!r}{encoding_str})"
class Request:
class BaseRequest:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
data: RequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
@ -488,18 +489,82 @@ class Request:
self._cookies = Cookies(cookies)
self._cookies.set_cookie_header(self)
if json is not None:
data = jsonlib.dumps(json).encode("utf-8")
self.headers["Content-Type"] = "application/json"
def encode_json(self, json: typing.Any) -> bytes:
return jsonlib.dumps(json).encode("utf-8")
if isinstance(data, bytes):
def urlencode_data(self, data: dict) -> bytes:
return urlencode(data, doseq=True).encode("utf-8")
def prepare(self) -> None:
content = getattr(self, "content", None) # type: bytes
is_streaming = getattr(self, "is_streaming", False)
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
has_user_agent = "user-agent" in self.headers
has_accept = "accept" in self.headers
has_content_length = (
"content-length" in self.headers or "transfer-encoding" in self.headers
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_user_agent:
auto_headers.append((b"user-agent", b"httpcore"))
if not has_accept:
auto_headers.append((b"accept", b"*/*"))
if not has_content_length:
if is_streaming:
auto_headers.append((b"transfer-encoding", b"chunked"))
elif content:
content_length = str(len(content)).encode()
auto_headers.append((b"content-length", content_length))
if not has_accept_encoding:
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
for item in reversed(auto_headers):
self.headers.raw.insert(0, item)
@property
def cookies(self) -> "Cookies":
if not hasattr(self, "_cookies"):
self._cookies = Cookies()
return self._cookies
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
class AsyncRequest(BaseRequest):
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
data: AsyncRequestData = b"",
json: typing.Any = None,
):
super().__init__(
method=method, url=url, params=params, headers=headers, cookies=cookies
)
if json is not None:
self.is_streaming = False
self.content = self.encode_json(json)
self.headers["Content-Type"] = "application/json"
elif isinstance(data, bytes):
self.is_streaming = False
self.content = data
elif isinstance(data, dict):
self.is_streaming = False
self.content = urlencode(data, doseq=True).encode("utf-8")
self.content = self.urlencode_data(data)
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
else:
assert hasattr(data, "__aiter__")
self.is_streaming = True
self.content_aiter = data
@ -520,39 +585,55 @@ class Request:
elif self.content:
yield self.content
def prepare(self) -> None:
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
has_content_length = (
"content-length" in self.headers or "transfer-encoding" in self.headers
class Request(BaseRequest):
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
data: RequestData = b"",
json: typing.Any = None,
):
super().__init__(
method=method, url=url, params=params, headers=headers, cookies=cookies
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_content_length:
if self.is_streaming:
auto_headers.append((b"transfer-encoding", b"chunked"))
elif self.content:
content_length = str(len(self.content)).encode()
auto_headers.append((b"content-length", content_length))
if not has_accept_encoding:
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
if json is not None:
self.is_streaming = False
self.content = self.encode_json(json)
self.headers["Content-Type"] = "application/json"
elif isinstance(data, bytes):
self.is_streaming = False
self.content = data
elif isinstance(data, dict):
self.is_streaming = False
self.content = self.urlencode_data(data)
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
else:
assert hasattr(data, "__iter__")
self.is_streaming = True
self.content_iter = data
for item in reversed(auto_headers):
self.headers.raw.insert(0, item)
self.prepare()
@property
def cookies(self) -> "Cookies":
if not hasattr(self, "_cookies"):
self._cookies = Cookies()
return self._cookies
def read(self) -> bytes:
if not hasattr(self, "content"):
self.content = b"".join([part for part in self.stream()])
return self.content
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
def stream(self) -> typing.Iterator[bytes]:
if self.is_streaming:
for part in self.content_iter:
yield part
elif self.content:
yield self.content
class Response:
class BaseResponse:
def __init__(
self,
status_code: int,
@ -560,28 +641,16 @@ class Response:
reason_phrase: str = None,
protocol: str = None,
headers: HeaderTypes = None,
content: ResponseContent = b"",
request: BaseRequest = None,
on_close: typing.Callable = None,
request: Request = None,
history: typing.List["Response"] = None,
):
self.status_code = StatusCode.enum_or_int(status_code)
self.reason_phrase = StatusCode.get_reason_phrase(status_code)
self.protocol = protocol
self.headers = Headers(headers)
if isinstance(content, bytes):
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content
else:
self.is_closed = False
self.is_stream_consumed = False
self._raw_stream = content
self.on_close = on_close
self.request = request
self.history = [] if history is None else list(history)
self.on_close = on_close
self.next = None # typing.Optional[typing.Callable]
@property
@ -597,7 +666,8 @@ class Response:
def content(self) -> bytes:
if not hasattr(self, "_content"):
if hasattr(self, "_raw_content"):
content = self.decoder.decode(self._raw_content)
raw_content = getattr(self, "_raw_content") # type: bytes
content = self.decoder.decode(raw_content)
content += self.decoder.flush()
self._content = content
else:
@ -682,6 +752,77 @@ class Response:
return self._decoder
@property
def is_redirect(self) -> bool:
return StatusCode.is_redirect(self.status_code) and "location" in self.headers
def raise_for_status(self) -> None:
"""
Raise the `HttpError` if one occurred.
"""
message = (
"{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
"For more information check: https://httpstatuses.com/{0.status_code}"
)
if StatusCode.is_client_error(self.status_code):
message = message.format(self, error_type="Client Error")
elif StatusCode.is_server_error(self.status_code):
message = message.format(self, error_type="Server Error")
else:
message = ""
if message:
raise HttpError(message)
def json(self) -> typing.Any:
return jsonlib.loads(self.content.decode("utf-8"))
@property
def cookies(self) -> "Cookies":
if not hasattr(self, "_cookies"):
assert self.request is not None
self._cookies = Cookies()
self._cookies.extract_cookies(self)
return self._cookies
def __repr__(self) -> str:
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
class AsyncResponse(BaseResponse):
def __init__(
self,
status_code: int,
*,
reason_phrase: str = None,
protocol: str = None,
headers: HeaderTypes = None,
content: AsyncResponseContent = b"",
on_close: typing.Callable = None,
request: AsyncRequest = None,
history: typing.List["BaseResponse"] = None,
):
super().__init__(
status_code=status_code,
reason_phrase=reason_phrase,
protocol=protocol,
headers=headers,
request=request,
on_close=on_close,
)
self.history = [] if history is None else list(history)
if isinstance(content, bytes):
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content
else:
self.is_closed = False
self.is_stream_consumed = False
self._raw_stream = content
async def read(self) -> bytes:
"""
Read and return the response content.
@ -729,128 +870,86 @@ class Response:
if self.on_close is not None:
await self.on_close()
@property
def is_redirect(self) -> bool:
return StatusCode.is_redirect(self.status_code) and "location" in self.headers
def raise_for_status(self) -> None:
"""
Raise the `HttpError` if one occurred.
"""
message = (
"{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
"For more information check: https://httpstatuses.com/{0.status_code}"
class Response(BaseResponse):
def __init__(
self,
status_code: int,
*,
reason_phrase: str = None,
protocol: str = None,
headers: HeaderTypes = None,
content: ResponseContent = b"",
on_close: typing.Callable = None,
request: Request = None,
history: typing.List["BaseResponse"] = None,
):
super().__init__(
status_code=status_code,
reason_phrase=reason_phrase,
protocol=protocol,
headers=headers,
request=request,
on_close=on_close,
)
if StatusCode.is_client_error(self.status_code):
message = message.format(self, error_type="Client Error")
elif StatusCode.is_server_error(self.status_code):
message = message.format(self, error_type="Server Error")
self.history = [] if history is None else list(history)
if isinstance(content, bytes):
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content
else:
message = ""
if message:
raise HttpError(message)
def json(self) -> typing.Any:
return jsonlib.loads(self.content.decode("utf-8"))
@property
def cookies(self) -> "Cookies":
if not hasattr(self, "_cookies"):
assert self.request is not None
self._cookies = Cookies()
self._cookies.extract_cookies(self)
return self._cookies
def __repr__(self) -> str:
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
class SyncResponse:
"""
A thread-synchronous response. This class proxies onto a `Response`
instance, providing standard synchronous interfaces where required.
"""
def __init__(self, response: Response, loop: asyncio.AbstractEventLoop):
self._response = response
self._loop = loop
@property
def status_code(self) -> int:
return self._response.status_code
@property
def reason_phrase(self) -> str:
return self._response.reason_phrase
@property
def protocol(self) -> typing.Optional[str]:
return self._response.protocol
@property
def url(self) -> typing.Optional[URL]:
return self._response.url
@property
def request(self) -> typing.Optional[Request]:
return self._response.request
@property
def headers(self) -> Headers:
return self._response.headers
@property
def content(self) -> bytes:
return self._response.content
@property
def text(self) -> str:
return self._response.text
@property
def encoding(self) -> str:
return self._response.encoding
@property
def is_redirect(self) -> bool:
return self._response.is_redirect
def raise_for_status(self) -> None:
return self._response.raise_for_status()
def json(self) -> typing.Any:
return self._response.json()
self.is_closed = False
self.is_stream_consumed = False
self._raw_stream = content
def read(self) -> bytes:
return self._loop.run_until_complete(self._response.read())
"""
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part for part in self.stream()])
return self._content
def stream(self) -> typing.Iterator[bytes]:
inner = self._response.stream()
while True:
try:
yield self._loop.run_until_complete(inner.__anext__())
except StopAsyncIteration:
break
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
else:
for chunk in self.raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
def raw(self) -> typing.Iterator[bytes]:
inner = self._response.raw()
while True:
try:
yield self._loop.run_until_complete(inner.__anext__())
except StopAsyncIteration:
break
"""
A byte-iterator over the raw response content.
"""
if hasattr(self, "_raw_content"):
yield self._raw_content
else:
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_stream_consumed = True
for part in self._raw_stream:
yield part
self.close()
def close(self) -> None:
return self._loop.run_until_complete(self._response.close())
@property
def cookies(self) -> "Cookies":
return self._response.cookies
def __repr__(self) -> str:
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
self.on_close()
class Cookies(MutableMapping):
@ -871,7 +970,7 @@ class Cookies(MutableMapping):
else:
self.jar = cookies
def extract_cookies(self, response: Response) -> None:
def extract_cookies(self, response: BaseResponse) -> None:
"""
Loads any cookies based on the response `Set-Cookie` headers.
"""
@ -881,7 +980,7 @@ class Cookies(MutableMapping):
self.jar.extract_cookies(urlib_response, urllib_request) # type: ignore
def set_cookie_header(self, request: Request) -> None:
def set_cookie_header(self, request: BaseRequest) -> None:
"""
Sets an appropriate 'Cookie:' HTTP header on the `Request`.
"""
@ -1000,7 +1099,7 @@ class Cookies(MutableMapping):
for use with `CookieJar` operations.
"""
def __init__(self, request: Request) -> None:
def __init__(self, request: BaseRequest) -> None:
super().__init__(
url=str(request.url),
headers=dict(request.headers),
@ -1018,7 +1117,7 @@ class Cookies(MutableMapping):
for use with `CookieJar` operations.
"""
def __init__(self, response: Response):
def __init__(self, response: BaseResponse):
self.response = response
def info(self) -> email.message.Message:

View File

@ -4,27 +4,26 @@ import pytest
from httpcore import (
URL,
AsyncDispatcher,
AsyncRequest,
AsyncResponse,
CertTypes,
Client,
Dispatcher,
Request,
Response,
TimeoutTypes,
VerifyTypes,
)
class MockDispatch(Dispatcher):
class MockDispatch(AsyncDispatcher):
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
return Response(200, content=body, request=request)
return AsyncResponse(200, content=body, request=request)
def test_basic_auth():

View File

@ -5,32 +5,31 @@ import pytest
from httpcore import (
URL,
AsyncDispatcher,
AsyncRequest,
AsyncResponse,
CertTypes,
Client,
Cookies,
Dispatcher,
Request,
Response,
TimeoutTypes,
VerifyTypes,
)
class MockDispatch(Dispatcher):
class MockDispatch(AsyncDispatcher):
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
if request.url.path.startswith("/echo_cookies"):
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
return Response(200, content=body, request=request)
return AsyncResponse(200, content=body, request=request)
elif request.url.path.startswith("/set_cookie"):
headers = {"set-cookie": "example-name=example-value"}
return Response(200, headers=headers, request=request)
return AsyncResponse(200, headers=headers, request=request)
def test_set_cookie():

View File

@ -6,8 +6,10 @@ import pytest
from httpcore import (
URL,
AsyncClient,
AsyncDispatcher,
AsyncRequest,
AsyncResponse,
CertTypes,
Dispatcher,
RedirectBodyUnavailable,
RedirectLoop,
Request,
@ -19,37 +21,36 @@ from httpcore import (
)
class MockDispatch(Dispatcher):
class MockDispatch(AsyncDispatcher):
async def send(
self,
request: Request,
stream: bool = False,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
) -> AsyncResponse:
if request.url.path == "/redirect_301":
status_code = codes.MOVED_PERMANENTLY
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_302":
status_code = codes.FOUND
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_303":
status_code = codes.SEE_OTHER
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/relative_redirect":
headers = {"location": "/"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/no_scheme_redirect":
headers = {"location": "//example.org/"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
@ -60,32 +61,34 @@ class MockDispatch(Dispatcher):
if redirect_count:
location += "?count=" + str(redirect_count)
headers = {"location": location} if count else {}
return Response(code, headers=headers, request=request)
return AsyncResponse(code, headers=headers, request=request)
if request.url.path == "/redirect_loop":
headers = {"location": "/redirect_loop"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain":
headers = {"location": "https://example.org/cross_domain_target"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain_target":
headers = dict(request.headers.items())
content = json.dumps({"headers": headers}).encode()
return Response(codes.OK, content=content, request=request)
return AsyncResponse(codes.OK, content=content, request=request)
elif request.url.path == "/redirect_body":
await request.read()
headers = {"location": "/redirect_body_target"}
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
return AsyncResponse(
codes.PERMANENT_REDIRECT, headers=headers, request=request
)
elif request.url.path == "/redirect_body_target":
content = await request.read()
body = json.dumps({"body": content.decode()}).encode()
return Response(codes.OK, content=body, request=request)
return AsyncResponse(codes.OK, content=body, request=request)
return Response(codes.OK, content=b"Hello, world!", request=request)
return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
@pytest.mark.asyncio

View File

@ -10,10 +10,12 @@ async def test_keepalive_connections(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@ -25,10 +27,12 @@ async def test_differing_connection_keys(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@ -42,10 +46,12 @@ async def test_soft_limit(server):
async with httpcore.ConnectionPool(pool_limits=pool_limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@ -56,7 +62,7 @@ async def test_streaming_response_holds_connection(server):
A streaming request should hold the connection open until the response is read.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
response = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
@ -72,11 +78,11 @@ async def test_multiple_concurrent_connections(server):
Multiple conncurrent requests should open multiple conncurrent connections.
"""
async with httpcore.ConnectionPool() as http:
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
response_a = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
response_b = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 2
assert len(http.keepalive_connections) == 0
@ -97,6 +103,7 @@ async def test_close_connections(server):
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@ -107,7 +114,7 @@ async def test_standard_response_close(server):
A standard close should keep the connection open.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
await response.close()
assert len(http.active_connections) == 0
@ -120,7 +127,7 @@ async def test_premature_response_close(server):
A premature close should close the connection.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.close()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0

View File

@ -7,6 +7,7 @@ from httpcore import HTTPConnection, Request, SSLConfig
async def test_get(server):
conn = HTTPConnection(origin="http://127.0.0.1:8000/")
response = await conn.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
@ -27,6 +28,7 @@ async def test_https_get_with_ssl_defaults(https_server):
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
response = await conn.request("GET", "https://127.0.0.1:8001/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
@ -38,5 +40,6 @@ async def test_https_get_with_sll_overrides(https_server):
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/")
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"

View File

@ -0,0 +1,100 @@
import json
import pytest
from httpcore import (
CertTypes,
Client,
Dispatcher,
Request,
Response,
TimeoutTypes,
VerifyTypes,
)
def streaming_body():
for part in [b"Hello", b", ", b"world!"]:
yield part
class MockDispatch(Dispatcher):
def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
if request.url.path == "/streaming_response":
return Response(200, content=streaming_body(), request=request)
elif request.url.path == "/echo_request_body":
content = request.read()
return Response(200, content=content, request=request)
elif request.url.path == "/echo_request_body_streaming":
content = b"".join([part for part in request.stream()])
return Response(200, content=content, request=request)
else:
body = json.dumps({"hello": "world"}).encode()
return Response(200, content=body, request=request)
def test_threaded_dispatch():
"""
Use a syncronous 'Dispatcher' class with the client.
Calls to the dispatcher will end up running within a thread pool.
"""
url = "https://example.org/"
with Client(dispatch=MockDispatch()) as client:
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"hello": "world"}
def test_threaded_streaming_response():
url = "https://example.org/streaming_response"
with Client(dispatch=MockDispatch()) as client:
response = client.get(url)
assert response.status_code == 200
assert response.text == "Hello, world!"
def test_threaded_streaming_request():
url = "https://example.org/echo_request_body"
with Client(dispatch=MockDispatch()) as client:
response = client.post(url, data=streaming_body())
assert response.status_code == 200
assert response.text == "Hello, world!"
def test_threaded_request_body():
url = "https://example.org/echo_request_body"
with Client(dispatch=MockDispatch()) as client:
response = client.post(url, data=b"Hello, world!")
assert response.status_code == 200
assert response.text == "Hello, world!"
def test_threaded_request_body_streaming():
url = "https://example.org/echo_request_body_streaming"
with Client(dispatch=MockDispatch()) as client:
response = client.post(url, data=b"Hello, world!")
assert response.status_code == 200
assert response.text == "Hello, world!"
def test_dispatch_class():
"""
Use a syncronous 'Dispatcher' class directly.
"""
url = "https://example.org/"
with MockDispatch() as dispatcher:
response = dispatcher.request("GET", url)
assert response.status_code == 200
assert response.json() == {"hello": "world"}

View File

@ -10,87 +10,62 @@ def test_request_repr():
def test_no_content():
request = httpcore.Request("GET", "http://example.org")
request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br")]
)
assert "Content-Length" not in request.headers
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", data=b"test 123")
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
assert request.headers["Content-Length"] == "8"
def test_url_encoded_data():
request = httpcore.Request("POST", "http://example.org", data={"test": "123"})
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-type", b"application/x-www-form-urlencoded"),
]
)
assert request.content == b"test=123"
for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
request = RequestClass("POST", "http://example.org", data={"test": "123"})
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert request.content == b"test=123"
def test_json_encoded_data():
for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
request = RequestClass("POST", "http://example.org", json={"test": 123})
assert request.headers["Content-Type"] == "application/json"
assert request.content == b'{"test": 123}'
def test_transfer_encoding_header():
async def streaming_body(data):
def streaming_body(data):
yield data # pragma: nocover
data = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", data=data)
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
assert "Content-Length" not in request.headers
assert request.headers["Transfer-Encoding"] == "chunked"
def test_override_host_header():
headers = [(b"host", b"1.2.3.4:80")]
headers = {"host": "1.2.3.4:80"}
request = httpcore.Request("GET", "http://example.org", headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
)
assert request.headers["Host"] == "1.2.3.4:80"
def test_override_accept_encoding_header():
headers = [(b"accept-encoding", b"identity")]
headers = {"Accept-Encoding": "identity"}
request = httpcore.Request("GET", "http://example.org", headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"identity")]
)
assert request.headers["Accept-Encoding"] == "identity"
def test_override_content_length_header():
async def streaming_body(data):
def streaming_body(data):
yield data # pragma: nocover
data = streaming_body(b"test 123")
headers = [(b"content-length", b"8")]
headers = {"Content-Length": "8"}
request = httpcore.Request("POST", "http://example.org", data=data, headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]
)
assert request.headers["Content-Length"] == "8"
def test_url():

View File

@ -3,7 +3,12 @@ import pytest
import httpcore
async def streaming_body():
def streaming_body():
yield b"Hello, "
yield b"world!"
async def async_streaming_body():
yield b"Hello, "
yield b"world!"
@ -105,8 +110,7 @@ def test_response_force_encoding():
assert response.encoding == "iso-8859-1"
@pytest.mark.asyncio
async def test_read_response():
def test_read_response():
response = httpcore.Response(200, content=b"Hello, world!")
assert response.status_code == 200
@ -114,37 +118,56 @@ async def test_read_response():
assert response.encoding == "ascii"
assert response.is_closed
content = await response.read()
content = response.read()
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
assert response.is_closed
@pytest.mark.asyncio
async def test_raw_interface():
def test_raw_interface():
response = httpcore.Response(200, content=b"Hello, world!")
raw = b""
async for part in response.raw():
for part in response.raw():
raw += part
assert raw == b"Hello, world!"
@pytest.mark.asyncio
async def test_stream_interface():
def test_stream_interface():
response = httpcore.Response(200, content=b"Hello, world!")
content = b""
for part in response.stream():
content += part
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_async_stream_interface():
response = httpcore.AsyncResponse(200, content=b"Hello, world!")
content = b""
async for part in response.stream():
content += part
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_stream_interface_after_read():
def test_stream_interface_after_read():
response = httpcore.Response(200, content=b"Hello, world!")
response.read()
content = b""
for part in response.stream():
content += part
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_async_stream_interface_after_read():
response = httpcore.AsyncResponse(200, content=b"Hello, world!")
await response.read()
content = b""
@ -153,13 +176,26 @@ async def test_stream_interface_after_read():
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_streaming_response():
def test_streaming_response():
response = httpcore.Response(200, content=streaming_body())
assert response.status_code == 200
assert not response.is_closed
content = response.read()
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
assert response.is_closed
@pytest.mark.asyncio
async def test_async_streaming_response():
response = httpcore.AsyncResponse(200, content=async_streaming_body())
assert response.status_code == 200
assert not response.is_closed
content = await response.read()
assert content == b"Hello, world!"
@ -167,10 +203,21 @@ async def test_streaming_response():
assert response.is_closed
@pytest.mark.asyncio
async def test_cannot_read_after_stream_consumed():
def test_cannot_read_after_stream_consumed():
response = httpcore.Response(200, content=streaming_body())
content = b""
for part in response.stream():
content += part
with pytest.raises(httpcore.StreamConsumed):
response.read()
@pytest.mark.asyncio
async def test_async_cannot_read_after_stream_consumed():
response = httpcore.AsyncResponse(200, content=async_streaming_body())
content = b""
async for part in response.stream():
content += part
@ -179,10 +226,19 @@ async def test_cannot_read_after_stream_consumed():
await response.read()
@pytest.mark.asyncio
async def test_cannot_read_after_response_closed():
def test_cannot_read_after_response_closed():
response = httpcore.Response(200, content=streaming_body())
response.close()
with pytest.raises(httpcore.ResponseClosed):
response.read()
@pytest.mark.asyncio
async def test_async_cannot_read_after_response_closed():
response = httpcore.AsyncResponse(200, content=async_streaming_body())
await response.close()
with pytest.raises(httpcore.ResponseClosed):

View File

@ -38,6 +38,18 @@ def test_post(server):
assert response.reason_phrase == "OK"
@threadpool
def test_post_byte_iterator(server):
def data():
yield b"Hello"
yield b", "
yield b"world!"
response = httpcore.post("http://127.0.0.1:8000/", data=data())
assert response.status_code == 200
assert response.reason_phrase == "OK"
@threadpool
def test_options(server):
response = httpcore.options("http://127.0.0.1:8000/")

View File

@ -64,19 +64,18 @@ def test_multi_with_identity():
assert response.content == body
@pytest.mark.asyncio
async def test_streaming():
def test_streaming():
body = b"test 123"
compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
async def compress(body):
def compress(body):
yield compressor.compress(body)
yield compressor.flush()
headers = [(b"Content-Encoding", b"gzip")]
response = httpcore.Response(200, headers=headers, content=compress(body))
assert not hasattr(response, "body")
assert await response.read() == body
assert response.read() == body
@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))