Sync or Async dispatch (#83)
* Support thread-pooled dispatch * Add ConcurrencyBackend.run * Initial work towards support byte-iterators on sync request data * Test case for byte iterator content * byte iterator support for RequestData * Add BaseResponse * Bridge sync/async data in SyncResponse * Add BaseClient * SyncResponse -> Response * Tweaking type annotation * Distinct classes for Request, AsyncRequest * Tweak is_streaming, content in BaseRequest * Stream handling moves to client * Handle mediating to AsyncResponse from a standard sync Dispatcher class * Working on thread-pooled dispatcher * Support threaded dispatch, inc. streaming requests/responses * Increase test coverage * Coverage and tweaks * Include Accept and User-Agent headers by default
This commit is contained in:
parent
ba033c549f
commit
0cbf3c7581
@ -28,8 +28,25 @@ from .exceptions import (
|
||||
TooManyRedirects,
|
||||
WriteTimeout,
|
||||
)
|
||||
from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
|
||||
from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response
|
||||
from .interfaces import (
|
||||
AsyncDispatcher,
|
||||
BaseReader,
|
||||
BaseWriter,
|
||||
ConcurrencyBackend,
|
||||
Dispatcher,
|
||||
Protocol,
|
||||
)
|
||||
from .models import (
|
||||
URL,
|
||||
AsyncRequest,
|
||||
AsyncResponse,
|
||||
Cookies,
|
||||
Headers,
|
||||
Origin,
|
||||
QueryParams,
|
||||
Request,
|
||||
Response,
|
||||
)
|
||||
from .status_codes import StatusCode, codes
|
||||
|
||||
__version__ = "0.4.0"
|
||||
|
||||
@ -8,7 +8,7 @@ from .models import (
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
RequestData,
|
||||
SyncResponse,
|
||||
Response,
|
||||
URLTypes,
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ def request(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
stream: bool = False,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
with Client() as client:
|
||||
return client.request(
|
||||
method=method,
|
||||
@ -61,7 +61,7 @@ def get(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"GET",
|
||||
url,
|
||||
@ -88,7 +88,7 @@ def options(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"OPTIONS",
|
||||
url,
|
||||
@ -115,7 +115,7 @@ def head(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"HEAD",
|
||||
url,
|
||||
@ -144,7 +144,7 @@ def post(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"POST",
|
||||
url,
|
||||
@ -175,7 +175,7 @@ def put(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"PUT",
|
||||
url,
|
||||
@ -206,7 +206,7 @@ def patch(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"PATCH",
|
||||
url,
|
||||
@ -237,7 +237,7 @@ def delete(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return request(
|
||||
"DELETE",
|
||||
url,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from .models import Request
|
||||
from .models import AsyncRequest
|
||||
|
||||
|
||||
class AuthBase:
|
||||
@ -9,7 +9,7 @@ class AuthBase:
|
||||
Base class that all auth implementations derive from.
|
||||
"""
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
def __call__(self, request: AsyncRequest) -> AsyncRequest:
|
||||
raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ class HTTPBasicAuth(AuthBase):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
def __call__(self, request: AsyncRequest) -> AsyncRequest:
|
||||
request.headers["Authorization"] = self.build_auth_header()
|
||||
return request
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .auth import HTTPBasicAuth
|
||||
from .concurrency import AsyncioBackend
|
||||
from .config import (
|
||||
DEFAULT_MAX_REDIRECTS,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
@ -13,10 +13,15 @@ from .config import (
|
||||
VerifyTypes,
|
||||
)
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .dispatch.threaded import ThreadedDispatcher
|
||||
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from .interfaces import ConcurrencyBackend, Dispatcher
|
||||
from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
|
||||
from .models import (
|
||||
URL,
|
||||
AsyncRequest,
|
||||
AsyncRequestData,
|
||||
AsyncResponse,
|
||||
AsyncResponseContent,
|
||||
AuthTypes,
|
||||
Cookies,
|
||||
CookieTypes,
|
||||
@ -26,13 +31,13 @@ from .models import (
|
||||
Request,
|
||||
RequestData,
|
||||
Response,
|
||||
SyncResponse,
|
||||
ResponseContent,
|
||||
URLTypes,
|
||||
)
|
||||
from .status_codes import codes
|
||||
|
||||
|
||||
class AsyncClient:
|
||||
class BaseClient:
|
||||
def __init__(
|
||||
self,
|
||||
auth: AuthTypes = None,
|
||||
@ -42,23 +47,208 @@ class AsyncClient:
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
dispatch: Dispatcher = None,
|
||||
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
|
||||
backend: ConcurrencyBackend = None,
|
||||
):
|
||||
if backend is None:
|
||||
backend = AsyncioBackend()
|
||||
|
||||
if dispatch is None:
|
||||
dispatch = ConnectionPool(
|
||||
async_dispatch = ConnectionPool(
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
pool_limits=pool_limits,
|
||||
backend=backend,
|
||||
)
|
||||
) # type: AsyncDispatcher
|
||||
elif isinstance(dispatch, Dispatcher):
|
||||
async_dispatch = ThreadedDispatcher(dispatch, backend)
|
||||
else:
|
||||
async_dispatch = dispatch
|
||||
|
||||
self.auth = auth
|
||||
self.cookies = Cookies(cookies)
|
||||
self.max_redirects = max_redirects
|
||||
self.dispatch = dispatch
|
||||
self.dispatch = async_dispatch
|
||||
self.concurrency_backend = backend
|
||||
|
||||
def merge_cookies(
|
||||
self, cookies: CookieTypes = None
|
||||
) -> typing.Optional[CookieTypes]:
|
||||
if cookies or self.cookies:
|
||||
merged_cookies = Cookies(self.cookies)
|
||||
merged_cookies.update(cookies)
|
||||
return merged_cookies
|
||||
return cookies
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: AsyncRequest,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> AsyncResponse:
|
||||
if auth is None:
|
||||
auth = self.auth
|
||||
|
||||
url = request.url
|
||||
if auth is None and (url.username or url.password):
|
||||
auth = HTTPBasicAuth(username=url.username, password=url.password)
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
|
||||
request = auth(request)
|
||||
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def send_handling_redirects(
|
||||
self,
|
||||
request: AsyncRequest,
|
||||
*,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[AsyncResponse] = None,
|
||||
) -> AsyncResponse:
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
while True:
|
||||
# We perform these checks here, so that calls to `response.next()`
|
||||
# will raise redirect errors if appropriate.
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
if request.url in [response.url for response in history]:
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await self.dispatch.send(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
assert isinstance(response, AsyncResponse)
|
||||
response.history = list(history)
|
||||
self.cookies.extract_cookies(response)
|
||||
history = [response] + history
|
||||
if not response.is_redirect:
|
||||
break
|
||||
|
||||
if allow_redirects:
|
||||
request = self.build_redirect_request(request, response)
|
||||
else:
|
||||
|
||||
async def send_next() -> AsyncResponse:
|
||||
nonlocal request, response, verify, cert, allow_redirects, timeout, history
|
||||
request = self.build_redirect_request(request, response)
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
history=history,
|
||||
)
|
||||
return response
|
||||
|
||||
response.next = send_next # type: ignore
|
||||
break
|
||||
|
||||
return response
|
||||
|
||||
def build_redirect_request(
|
||||
self, request: AsyncRequest, response: AsyncResponse
|
||||
) -> AsyncRequest:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url)
|
||||
content = self.redirect_content(request, method)
|
||||
cookies = self.merge_cookies(request.cookies)
|
||||
return AsyncRequest(
|
||||
method=method, url=url, headers=headers, data=content, cookies=cookies
|
||||
)
|
||||
|
||||
def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes.SEE_OTHER and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# Turn 302s into GETs.
|
||||
if response.status_code == codes.FOUND and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# If a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in 'requests' issue 1704.
|
||||
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
url = URL(location, allow_relative=True)
|
||||
|
||||
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
if url.is_relative_url:
|
||||
url = url.resolve_with(request.url)
|
||||
|
||||
# Attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
if request.url.fragment and not url.fragment:
|
||||
url = url.copy_with(fragment=request.url.fragment)
|
||||
|
||||
return url
|
||||
|
||||
def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
|
||||
"""
|
||||
Strip Authorization headers when responses are redirected away from
|
||||
the origin.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
if url.origin != request.url.origin:
|
||||
del headers["Authorization"]
|
||||
return headers
|
||||
|
||||
def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
raise RedirectBodyUnavailable()
|
||||
return request.content
|
||||
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
async def get(
|
||||
self,
|
||||
url: URLTypes,
|
||||
@ -72,7 +262,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"GET",
|
||||
url,
|
||||
@ -100,7 +290,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"OPTIONS",
|
||||
url,
|
||||
@ -128,7 +318,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"HEAD",
|
||||
url,
|
||||
@ -147,7 +337,7 @@ class AsyncClient:
|
||||
self,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
@ -158,7 +348,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"POST",
|
||||
url,
|
||||
@ -179,7 +369,7 @@ class AsyncClient:
|
||||
self,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
@ -190,7 +380,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"PUT",
|
||||
url,
|
||||
@ -211,7 +401,7 @@ class AsyncClient:
|
||||
self,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
@ -222,7 +412,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"PATCH",
|
||||
url,
|
||||
@ -243,7 +433,7 @@ class AsyncClient:
|
||||
self,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
@ -254,7 +444,7 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
return await self.request(
|
||||
"DELETE",
|
||||
url,
|
||||
@ -276,7 +466,7 @@ class AsyncClient:
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
@ -287,8 +477,8 @@ class AsyncClient:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
request = Request(
|
||||
) -> AsyncResponse:
|
||||
request = AsyncRequest(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
@ -308,174 +498,6 @@ class AsyncClient:
|
||||
)
|
||||
return response
|
||||
|
||||
def merge_cookies(
|
||||
self, cookies: CookieTypes = None
|
||||
) -> typing.Optional[CookieTypes]:
|
||||
if cookies or self.cookies:
|
||||
merged_cookies = Cookies(self.cookies)
|
||||
merged_cookies.update(cookies)
|
||||
return merged_cookies
|
||||
return cookies
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if auth is None:
|
||||
auth = self.auth
|
||||
|
||||
url = request.url
|
||||
if auth is None and (url.username or url.password):
|
||||
auth = HTTPBasicAuth(username=url.username, password=url.password)
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
|
||||
request = auth(request)
|
||||
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
stream=stream,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
return response
|
||||
|
||||
async def send_handling_redirects(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[Response] = None,
|
||||
) -> Response:
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
while True:
|
||||
# We perform these checks here, so that calls to `response.next()`
|
||||
# will raise redirect errors if appropriate.
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
if request.url in [response.url for response in history]:
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await self.dispatch.send(
|
||||
request, stream=stream, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response.history = list(history)
|
||||
self.cookies.extract_cookies(response)
|
||||
history = [response] + history
|
||||
if not response.is_redirect:
|
||||
break
|
||||
|
||||
if allow_redirects:
|
||||
request = self.build_redirect_request(request, response)
|
||||
else:
|
||||
|
||||
async def send_next() -> Response:
|
||||
nonlocal request, response, verify, cert, allow_redirects, timeout, history
|
||||
request = self.build_redirect_request(request, response)
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
stream=stream,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
history=history,
|
||||
)
|
||||
return response
|
||||
|
||||
response.next = send_next # type: ignore
|
||||
break
|
||||
|
||||
return response
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url)
|
||||
content = self.redirect_content(request, method)
|
||||
cookies = self.merge_cookies(request.cookies)
|
||||
return Request(
|
||||
method=method, url=url, headers=headers, data=content, cookies=cookies
|
||||
)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes.SEE_OTHER and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# Turn 302s into GETs.
|
||||
if response.status_code == codes.FOUND and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# If a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in 'requests' issue 1704.
|
||||
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
url = URL(location, allow_relative=True)
|
||||
|
||||
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
if url.is_relative_url:
|
||||
url = url.resolve_with(request.url)
|
||||
|
||||
# Attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
if request.url.fragment and not url.fragment:
|
||||
url = url.copy_with(fragment=request.url.fragment)
|
||||
|
||||
return url
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL) -> Headers:
|
||||
"""
|
||||
Strip Authorization headers when responses are redirected away from
|
||||
the origin.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
if url.origin != request.url.origin:
|
||||
del headers["Authorization"]
|
||||
return headers
|
||||
|
||||
def redirect_content(self, request: Request, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
raise RedirectBodyUnavailable()
|
||||
return request.content
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
|
||||
@ -491,33 +513,28 @@ class AsyncClient:
|
||||
await self.close()
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
auth: AuthTypes = None,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
dispatch: Dispatcher = None,
|
||||
backend: ConcurrencyBackend = None,
|
||||
) -> None:
|
||||
self._client = AsyncClient(
|
||||
auth=auth,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
pool_limits=pool_limits,
|
||||
max_redirects=max_redirects,
|
||||
dispatch=dispatch,
|
||||
backend=backend,
|
||||
)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
class Client(BaseClient):
|
||||
def _async_request_data(self, data: RequestData) -> AsyncRequestData:
|
||||
"""
|
||||
If the request data is an bytes iterator then return an async bytes
|
||||
iterator onto the request data.
|
||||
"""
|
||||
if isinstance(data, (bytes, dict)):
|
||||
return data
|
||||
|
||||
@property
|
||||
def cookies(self) -> Cookies:
|
||||
return self._client.cookies
|
||||
# Coerce an iterator into an async iterator, with each item in the
|
||||
# iteration running as a thread-pooled operation.
|
||||
assert hasattr(data, "__iter__")
|
||||
return self.concurrency_backend.iterate_in_threadpool(data)
|
||||
|
||||
def _sync_data(self, data: AsyncResponseContent) -> ResponseContent:
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
|
||||
# Coerce an async iterator into an iterator, with each item in the
|
||||
# iteration run within the event loop.
|
||||
assert hasattr(data, "__aiter__")
|
||||
return self.concurrency_backend.iterate(data)
|
||||
|
||||
def request(
|
||||
self,
|
||||
@ -535,25 +552,55 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
request = Request(
|
||||
) -> Response:
|
||||
request = AsyncRequest(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
data=self._async_request_data(data),
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=self._client.merge_cookies(cookies),
|
||||
cookies=self.merge_cookies(cookies),
|
||||
)
|
||||
response = self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
concurrency_backend = self.concurrency_backend
|
||||
|
||||
coroutine = self.send
|
||||
args = [request]
|
||||
kwargs = dict(
|
||||
stream=True,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
)
|
||||
async_response = concurrency_backend.run(coroutine, *args, **kwargs)
|
||||
|
||||
content = getattr(
|
||||
async_response, "_raw_content", getattr(async_response, "_raw_stream", None)
|
||||
)
|
||||
|
||||
sync_content = self._sync_data(content)
|
||||
|
||||
def sync_on_close() -> None:
|
||||
nonlocal concurrency_backend, async_response
|
||||
concurrency_backend.run(async_response.on_close)
|
||||
|
||||
response = Response(
|
||||
status_code=async_response.status_code,
|
||||
reason_phrase=async_response.reason_phrase,
|
||||
protocol=async_response.protocol,
|
||||
headers=async_response.headers,
|
||||
content=sync_content,
|
||||
on_close=sync_on_close,
|
||||
request=async_response.request,
|
||||
history=async_response.history,
|
||||
)
|
||||
if not stream:
|
||||
try:
|
||||
response.read()
|
||||
finally:
|
||||
response.close()
|
||||
return response
|
||||
|
||||
def get(
|
||||
@ -569,7 +616,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"GET",
|
||||
url,
|
||||
@ -596,7 +643,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"OPTIONS",
|
||||
url,
|
||||
@ -623,7 +670,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"HEAD",
|
||||
url,
|
||||
@ -652,7 +699,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"POST",
|
||||
url,
|
||||
@ -683,7 +730,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"PUT",
|
||||
url,
|
||||
@ -714,7 +761,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"PATCH",
|
||||
url,
|
||||
@ -745,7 +792,7 @@ class Client:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
) -> Response:
|
||||
return self.request(
|
||||
"DELETE",
|
||||
url,
|
||||
@ -761,32 +808,9 @@ class Client:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> SyncResponse:
|
||||
response = self._loop.run_until_complete(
|
||||
self._client.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
)
|
||||
)
|
||||
return SyncResponse(response, self._loop)
|
||||
|
||||
def close(self) -> None:
|
||||
self._loop.run_until_complete(self._client.close())
|
||||
coroutine = self.dispatch.close
|
||||
self.concurrency_backend.run(coroutine)
|
||||
|
||||
def __enter__(self) -> "Client":
|
||||
return self
|
||||
|
||||
@ -9,6 +9,7 @@ protocols, and help keep the rest of the package more `async`/`await`
|
||||
based, and less strictly `asyncio`-specific.
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
@ -133,6 +134,15 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
ssl_monkey_patch()
|
||||
SSL_MONKEY_PATCH_APPLIED = True
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
if not hasattr(self, "_loop"):
|
||||
try:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
return self._loop
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
hostname: str,
|
||||
@ -162,5 +172,24 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
|
||||
return (reader, writer, protocol)
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
if kwargs:
|
||||
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await self.loop.run_in_executor(None, func, *args)
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
loop = self.loop
|
||||
if loop.is_running():
|
||||
self._loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return self.loop.run_until_complete(coroutine(*args, **kwargs))
|
||||
finally:
|
||||
self._loop = loop
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
return PoolSemaphore(limits)
|
||||
|
||||
@ -15,8 +15,8 @@ from ..config import (
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..exceptions import ConnectTimeout
|
||||
from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
|
||||
from ..models import Origin, Request, Response
|
||||
from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol
|
||||
from ..models import AsyncRequest, AsyncResponse, Origin
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
@ -24,7 +24,7 @@ from .http11 import HTTP11Connection
|
||||
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class HTTPConnection(Dispatcher):
|
||||
class HTTPConnection(AsyncDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
@ -44,24 +44,19 @@ class HTTPConnection(Dispatcher):
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
await self.connect(verify=verify, cert=cert, timeout=timeout)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
response = await self.h2_connection.send(
|
||||
request, stream=stream, timeout=timeout
|
||||
)
|
||||
response = await self.h2_connection.send(request, timeout=timeout)
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
response = await self.h11_connection.send(
|
||||
request, stream=stream, timeout=timeout
|
||||
)
|
||||
response = await self.h11_connection.send(request, timeout=timeout)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -12,8 +12,8 @@ from ..config import (
|
||||
)
|
||||
from ..decoders import ACCEPT_ENCODING
|
||||
from ..exceptions import PoolTimeout
|
||||
from ..interfaces import ConcurrencyBackend, Dispatcher
|
||||
from ..models import Origin, Request, Response
|
||||
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
|
||||
from ..models import AsyncRequest, AsyncResponse, Origin
|
||||
from .connection import HTTPConnection
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
@ -77,7 +77,7 @@ class ConnectionStore:
|
||||
return len(self.all)
|
||||
|
||||
|
||||
class ConnectionPool(Dispatcher):
|
||||
class ConnectionPool(AsyncDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -105,16 +105,15 @@ class ConnectionPool(Dispatcher):
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
connection = await self.acquire_connection(request.url.origin)
|
||||
try:
|
||||
response = await connection.send(
|
||||
request, stream=stream, verify=verify, cert=cert, timeout=timeout
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
except BaseException as exc:
|
||||
self.active_connections.remove(connection)
|
||||
|
||||
@ -4,8 +4,8 @@ import h11
|
||||
|
||||
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
|
||||
from ..exceptions import ConnectTimeout, ReadTimeout
|
||||
from ..interfaces import BaseReader, BaseWriter, Dispatcher
|
||||
from ..models import Request, Response
|
||||
from ..interfaces import BaseReader, BaseWriter
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
|
||||
H11Event = typing.Union[
|
||||
h11.Request,
|
||||
@ -38,15 +38,15 @@ class HTTP11Connection:
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
async def send(
|
||||
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
|
||||
) -> Response:
|
||||
self, request: AsyncRequest, timeout: TimeoutTypes = None
|
||||
) -> AsyncResponse:
|
||||
timeout = None if timeout is None else TimeoutConfig(timeout)
|
||||
|
||||
# Start sending the request.
|
||||
method = request.method.encode("ascii")
|
||||
target = request.url.full_path.encode("ascii")
|
||||
headers = request.headers.raw
|
||||
if 'Host' not in request.headers:
|
||||
if "Host" not in request.headers:
|
||||
host = request.url.authority.encode("ascii")
|
||||
headers = [(b"host", host)] + headers
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
@ -72,7 +72,7 @@ class HTTP11Connection:
|
||||
headers = event.headers
|
||||
content = self._body_iter(timeout)
|
||||
|
||||
response = Response(
|
||||
return AsyncResponse(
|
||||
status_code=status_code,
|
||||
reason_phrase=reason_phrase,
|
||||
protocol="HTTP/1.1",
|
||||
@ -82,14 +82,6 @@ class HTTP11Connection:
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
self.h11_state.send(event)
|
||||
|
||||
@ -6,8 +6,8 @@ import h2.events
|
||||
|
||||
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
|
||||
from ..exceptions import ConnectTimeout, ReadTimeout
|
||||
from ..interfaces import BaseReader, BaseWriter, Dispatcher
|
||||
from ..models import Request, Response
|
||||
from ..interfaces import BaseReader, BaseWriter
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
|
||||
|
||||
class HTTP2Connection:
|
||||
@ -24,8 +24,8 @@ class HTTP2Connection:
|
||||
self.initialized = False
|
||||
|
||||
async def send(
|
||||
self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
|
||||
) -> Response:
|
||||
self, request: AsyncRequest, timeout: TimeoutTypes = None
|
||||
) -> AsyncResponse:
|
||||
timeout = None if timeout is None else TimeoutConfig(timeout)
|
||||
|
||||
# Start sending the request.
|
||||
@ -59,7 +59,7 @@ class HTTP2Connection:
|
||||
content = self.body_iter(stream_id, timeout)
|
||||
on_close = functools.partial(self.response_closed, stream_id=stream_id)
|
||||
|
||||
response = Response(
|
||||
return AsyncResponse(
|
||||
status_code=status_code,
|
||||
protocol="HTTP/2",
|
||||
headers=headers,
|
||||
@ -68,14 +68,6 @@ class HTTP2Connection:
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
|
||||
@ -86,7 +78,7 @@ class HTTP2Connection:
|
||||
self.initialized = True
|
||||
|
||||
async def send_headers(
|
||||
self, request: Request, timeout: TimeoutConfig = None
|
||||
self, request: AsyncRequest, timeout: TimeoutConfig = None
|
||||
) -> int:
|
||||
stream_id = self.h2_state.get_next_available_stream_id()
|
||||
headers = [
|
||||
|
||||
97
httpcore/dispatch/threaded.py
Normal file
97
httpcore/dispatch/threaded.py
Normal file
@ -0,0 +1,97 @@
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
|
||||
from ..models import (
|
||||
AsyncRequest,
|
||||
AsyncRequestData,
|
||||
AsyncResponse,
|
||||
AsyncResponseContent,
|
||||
Request,
|
||||
RequestData,
|
||||
Response,
|
||||
ResponseContent,
|
||||
)
|
||||
|
||||
|
||||
class ThreadedDispatcher(AsyncDispatcher):
|
||||
"""
|
||||
The ThreadedDispatcher class is used to mediate between the Client
|
||||
(which always uses async under the hood), and a synchronous `Dispatch`
|
||||
class.
|
||||
"""
|
||||
|
||||
def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None:
|
||||
self.sync_dispatcher = dispatch
|
||||
self.backend = backend
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> AsyncResponse:
|
||||
concurrency_backend = self.backend
|
||||
|
||||
data = getattr(request, "content", getattr(request, "content_aiter", None))
|
||||
sync_data = self._sync_request_data(data)
|
||||
|
||||
sync_request = Request(
|
||||
method=request.method,
|
||||
url=request.url,
|
||||
headers=request.headers,
|
||||
data=sync_data,
|
||||
)
|
||||
|
||||
func = self.sync_dispatcher.send
|
||||
kwargs = {
|
||||
"request": sync_request,
|
||||
"verify": verify,
|
||||
"cert": cert,
|
||||
"timeout": timeout,
|
||||
}
|
||||
sync_response = await self.backend.run_in_threadpool(func, **kwargs)
|
||||
assert isinstance(sync_response, Response)
|
||||
|
||||
content = getattr(
|
||||
sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None)
|
||||
)
|
||||
|
||||
async_content = self._async_response_content(content)
|
||||
|
||||
async def async_on_close() -> None:
|
||||
nonlocal concurrency_backend, sync_response
|
||||
await concurrency_backend.run_in_threadpool(sync_response.close)
|
||||
|
||||
return AsyncResponse(
|
||||
status_code=sync_response.status_code,
|
||||
reason_phrase=sync_response.reason_phrase,
|
||||
protocol=sync_response.protocol,
|
||||
headers=sync_response.headers,
|
||||
content=async_content,
|
||||
on_close=async_on_close,
|
||||
request=request,
|
||||
history=sync_response.history,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
The `.close()` method runs the `Dispatcher.close()` within a threadpool,
|
||||
so as not to block the async event loop.
|
||||
"""
|
||||
func = self.sync_dispatcher.close
|
||||
await self.backend.run_in_threadpool(func)
|
||||
|
||||
def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent:
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
|
||||
# Coerce an async iterator into an iterator, with each item in the
|
||||
# iteration run within the event loop.
|
||||
assert hasattr(content, "__iter__")
|
||||
return self.backend.iterate_in_threadpool(content)
|
||||
|
||||
def _sync_request_data(self, data: AsyncRequestData) -> RequestData:
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
|
||||
return self.backend.iterate(data)
|
||||
@ -6,6 +6,9 @@ from types import TracebackType
|
||||
from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
|
||||
from .models import (
|
||||
URL,
|
||||
AsyncRequest,
|
||||
AsyncRequestData,
|
||||
AsyncResponse,
|
||||
Headers,
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
@ -21,9 +24,9 @@ class Protocol(str, enum.Enum):
|
||||
HTTP_2 = "HTTP/2"
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
class AsyncDispatcher:
|
||||
"""
|
||||
Base class for dispatcher classes, that handle sending the request.
|
||||
Base class for async dispatcher classes, that handle sending the request.
|
||||
|
||||
Stubs out the interface, as well as providing a `.request()` convienence
|
||||
implementation, to make it easy to use or test stand-alone dispatchers,
|
||||
@ -35,34 +38,29 @@ class Dispatcher:
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
data: AsyncRequestData = b"",
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None
|
||||
) -> Response:
|
||||
request = Request(method, url, data=data, params=params, headers=headers)
|
||||
response = await self.send(
|
||||
request, stream=stream, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
return response
|
||||
) -> AsyncResponse:
|
||||
request = AsyncRequest(method, url, data=data, params=params, headers=headers)
|
||||
return await self.send(request, verify=verify, cert=cert, timeout=timeout)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
pass # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "Dispatcher":
|
||||
async def __aenter__(self) -> "AsyncDispatcher":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
@ -74,6 +72,54 @@ class Dispatcher:
|
||||
await self.close()
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""
|
||||
Base class for syncronous dispatcher classes, that handle sending the request.
|
||||
|
||||
Stubs out the interface, as well as providing a `.request()` convienence
|
||||
implementation, to make it easy to use or test stand-alone dispatchers,
|
||||
without requiring a complete `Client` instance.
|
||||
"""
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None
|
||||
) -> Response:
|
||||
request = Request(method, url, data=data, params=params, headers=headers)
|
||||
return self.send(request, verify=verify, cert=cert, timeout=timeout)
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def close(self) -> None:
|
||||
pass # pragma: nocover
|
||||
|
||||
def __enter__(self) -> "Dispatcher":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class BaseReader:
|
||||
"""
|
||||
A stream reader. Abstracts away any asyncio-specfic interfaces
|
||||
@ -128,3 +174,36 @@ class ConcurrencyBackend:
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def iterate_in_threadpool(self, iterator): # type: ignore
|
||||
class IterationComplete(Exception):
|
||||
pass
|
||||
|
||||
def next_wrapper(iterator): # type: ignore
|
||||
try:
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
raise IterationComplete()
|
||||
|
||||
while True:
|
||||
try:
|
||||
yield await self.run_in_threadpool(next_wrapper, iterator)
|
||||
except IterationComplete:
|
||||
break
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def iterate(self, async_iterator): # type: ignore
|
||||
while True:
|
||||
try:
|
||||
yield self.run(async_iterator.__anext__)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import cgi
|
||||
import email.message
|
||||
import json as jsonlib
|
||||
@ -48,12 +47,16 @@ CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
|
||||
|
||||
AuthTypes = typing.Union[
|
||||
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
|
||||
typing.Callable[["Request"], "Request"],
|
||||
typing.Callable[["AsyncRequest"], "AsyncRequest"],
|
||||
]
|
||||
|
||||
RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
|
||||
AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
|
||||
|
||||
ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
|
||||
RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
|
||||
|
||||
AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
|
||||
|
||||
ResponseContent = typing.Union[bytes, typing.Iterator[bytes]]
|
||||
|
||||
|
||||
class URL:
|
||||
@ -469,14 +472,12 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
return f"{class_name}({as_list!r}{encoding_str})"
|
||||
|
||||
|
||||
class Request:
|
||||
class BaseRequest:
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
data: RequestData = b"",
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
@ -488,18 +489,82 @@ class Request:
|
||||
self._cookies = Cookies(cookies)
|
||||
self._cookies.set_cookie_header(self)
|
||||
|
||||
if json is not None:
|
||||
data = jsonlib.dumps(json).encode("utf-8")
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
def encode_json(self, json: typing.Any) -> bytes:
|
||||
return jsonlib.dumps(json).encode("utf-8")
|
||||
|
||||
if isinstance(data, bytes):
|
||||
def urlencode_data(self, data: dict) -> bytes:
|
||||
return urlencode(data, doseq=True).encode("utf-8")
|
||||
|
||||
def prepare(self) -> None:
|
||||
content = getattr(self, "content", None) # type: bytes
|
||||
is_streaming = getattr(self, "is_streaming", False)
|
||||
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
has_user_agent = "user-agent" in self.headers
|
||||
has_accept = "accept" in self.headers
|
||||
has_content_length = (
|
||||
"content-length" in self.headers or "transfer-encoding" in self.headers
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_user_agent:
|
||||
auto_headers.append((b"user-agent", b"httpcore"))
|
||||
if not has_accept:
|
||||
auto_headers.append((b"accept", b"*/*"))
|
||||
if not has_content_length:
|
||||
if is_streaming:
|
||||
auto_headers.append((b"transfer-encoding", b"chunked"))
|
||||
elif content:
|
||||
content_length = str(len(content)).encode()
|
||||
auto_headers.append((b"content-length", content_length))
|
||||
if not has_accept_encoding:
|
||||
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
|
||||
|
||||
for item in reversed(auto_headers):
|
||||
self.headers.raw.insert(0, item)
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
if not hasattr(self, "_cookies"):
|
||||
self._cookies = Cookies()
|
||||
return self._cookies
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
url = str(self.url)
|
||||
return f"<{class_name}({self.method!r}, {url!r})>"
|
||||
|
||||
|
||||
class AsyncRequest(BaseRequest):
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
data: AsyncRequestData = b"",
|
||||
json: typing.Any = None,
|
||||
):
|
||||
super().__init__(
|
||||
method=method, url=url, params=params, headers=headers, cookies=cookies
|
||||
)
|
||||
|
||||
if json is not None:
|
||||
self.is_streaming = False
|
||||
self.content = self.encode_json(json)
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
elif isinstance(data, bytes):
|
||||
self.is_streaming = False
|
||||
self.content = data
|
||||
elif isinstance(data, dict):
|
||||
self.is_streaming = False
|
||||
self.content = urlencode(data, doseq=True).encode("utf-8")
|
||||
self.content = self.urlencode_data(data)
|
||||
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
else:
|
||||
assert hasattr(data, "__aiter__")
|
||||
self.is_streaming = True
|
||||
self.content_aiter = data
|
||||
|
||||
@ -520,39 +585,55 @@ class Request:
|
||||
elif self.content:
|
||||
yield self.content
|
||||
|
||||
def prepare(self) -> None:
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
has_content_length = (
|
||||
"content-length" in self.headers or "transfer-encoding" in self.headers
|
||||
class Request(BaseRequest):
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
data: RequestData = b"",
|
||||
json: typing.Any = None,
|
||||
):
|
||||
super().__init__(
|
||||
method=method, url=url, params=params, headers=headers, cookies=cookies
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_content_length:
|
||||
if self.is_streaming:
|
||||
auto_headers.append((b"transfer-encoding", b"chunked"))
|
||||
elif self.content:
|
||||
content_length = str(len(self.content)).encode()
|
||||
auto_headers.append((b"content-length", content_length))
|
||||
if not has_accept_encoding:
|
||||
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
|
||||
if json is not None:
|
||||
self.is_streaming = False
|
||||
self.content = self.encode_json(json)
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
elif isinstance(data, bytes):
|
||||
self.is_streaming = False
|
||||
self.content = data
|
||||
elif isinstance(data, dict):
|
||||
self.is_streaming = False
|
||||
self.content = self.urlencode_data(data)
|
||||
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
else:
|
||||
assert hasattr(data, "__iter__")
|
||||
self.is_streaming = True
|
||||
self.content_iter = data
|
||||
|
||||
for item in reversed(auto_headers):
|
||||
self.headers.raw.insert(0, item)
|
||||
self.prepare()
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
if not hasattr(self, "_cookies"):
|
||||
self._cookies = Cookies()
|
||||
return self._cookies
|
||||
def read(self) -> bytes:
|
||||
if not hasattr(self, "content"):
|
||||
self.content = b"".join([part for part in self.stream()])
|
||||
return self.content
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
url = str(self.url)
|
||||
return f"<{class_name}({self.method!r}, {url!r})>"
|
||||
def stream(self) -> typing.Iterator[bytes]:
|
||||
if self.is_streaming:
|
||||
for part in self.content_iter:
|
||||
yield part
|
||||
elif self.content:
|
||||
yield self.content
|
||||
|
||||
|
||||
class Response:
|
||||
class BaseResponse:
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
@ -560,28 +641,16 @@ class Response:
|
||||
reason_phrase: str = None,
|
||||
protocol: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
content: ResponseContent = b"",
|
||||
request: BaseRequest = None,
|
||||
on_close: typing.Callable = None,
|
||||
request: Request = None,
|
||||
history: typing.List["Response"] = None,
|
||||
):
|
||||
self.status_code = StatusCode.enum_or_int(status_code)
|
||||
self.reason_phrase = StatusCode.get_reason_phrase(status_code)
|
||||
self.protocol = protocol
|
||||
self.headers = Headers(headers)
|
||||
|
||||
if isinstance(content, bytes):
|
||||
self.is_closed = True
|
||||
self.is_stream_consumed = True
|
||||
self._raw_content = content
|
||||
else:
|
||||
self.is_closed = False
|
||||
self.is_stream_consumed = False
|
||||
self._raw_stream = content
|
||||
|
||||
self.on_close = on_close
|
||||
self.request = request
|
||||
self.history = [] if history is None else list(history)
|
||||
self.on_close = on_close
|
||||
self.next = None # typing.Optional[typing.Callable]
|
||||
|
||||
@property
|
||||
@ -597,7 +666,8 @@ class Response:
|
||||
def content(self) -> bytes:
|
||||
if not hasattr(self, "_content"):
|
||||
if hasattr(self, "_raw_content"):
|
||||
content = self.decoder.decode(self._raw_content)
|
||||
raw_content = getattr(self, "_raw_content") # type: bytes
|
||||
content = self.decoder.decode(raw_content)
|
||||
content += self.decoder.flush()
|
||||
self._content = content
|
||||
else:
|
||||
@ -682,6 +752,77 @@ class Response:
|
||||
|
||||
return self._decoder
|
||||
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return StatusCode.is_redirect(self.status_code) and "location" in self.headers
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
Raise the `HttpError` if one occurred.
|
||||
"""
|
||||
message = (
|
||||
"{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
|
||||
"For more information check: https://httpstatuses.com/{0.status_code}"
|
||||
)
|
||||
|
||||
if StatusCode.is_client_error(self.status_code):
|
||||
message = message.format(self, error_type="Client Error")
|
||||
elif StatusCode.is_server_error(self.status_code):
|
||||
message = message.format(self, error_type="Server Error")
|
||||
else:
|
||||
message = ""
|
||||
|
||||
if message:
|
||||
raise HttpError(message)
|
||||
|
||||
def json(self) -> typing.Any:
|
||||
return jsonlib.loads(self.content.decode("utf-8"))
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
if not hasattr(self, "_cookies"):
|
||||
assert self.request is not None
|
||||
self._cookies = Cookies()
|
||||
self._cookies.extract_cookies(self)
|
||||
return self._cookies
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
|
||||
|
||||
|
||||
class AsyncResponse(BaseResponse):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
reason_phrase: str = None,
|
||||
protocol: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
content: AsyncResponseContent = b"",
|
||||
on_close: typing.Callable = None,
|
||||
request: AsyncRequest = None,
|
||||
history: typing.List["BaseResponse"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
reason_phrase=reason_phrase,
|
||||
protocol=protocol,
|
||||
headers=headers,
|
||||
request=request,
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
self.history = [] if history is None else list(history)
|
||||
|
||||
if isinstance(content, bytes):
|
||||
self.is_closed = True
|
||||
self.is_stream_consumed = True
|
||||
self._raw_content = content
|
||||
else:
|
||||
self.is_closed = False
|
||||
self.is_stream_consumed = False
|
||||
self._raw_stream = content
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""
|
||||
Read and return the response content.
|
||||
@ -729,128 +870,86 @@ class Response:
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return StatusCode.is_redirect(self.status_code) and "location" in self.headers
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
Raise the `HttpError` if one occurred.
|
||||
"""
|
||||
message = (
|
||||
"{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
|
||||
"For more information check: https://httpstatuses.com/{0.status_code}"
|
||||
class Response(BaseResponse):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
reason_phrase: str = None,
|
||||
protocol: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
content: ResponseContent = b"",
|
||||
on_close: typing.Callable = None,
|
||||
request: Request = None,
|
||||
history: typing.List["BaseResponse"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
reason_phrase=reason_phrase,
|
||||
protocol=protocol,
|
||||
headers=headers,
|
||||
request=request,
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
if StatusCode.is_client_error(self.status_code):
|
||||
message = message.format(self, error_type="Client Error")
|
||||
elif StatusCode.is_server_error(self.status_code):
|
||||
message = message.format(self, error_type="Server Error")
|
||||
self.history = [] if history is None else list(history)
|
||||
|
||||
if isinstance(content, bytes):
|
||||
self.is_closed = True
|
||||
self.is_stream_consumed = True
|
||||
self._raw_content = content
|
||||
else:
|
||||
message = ""
|
||||
|
||||
if message:
|
||||
raise HttpError(message)
|
||||
|
||||
def json(self) -> typing.Any:
|
||||
return jsonlib.loads(self.content.decode("utf-8"))
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
if not hasattr(self, "_cookies"):
|
||||
assert self.request is not None
|
||||
self._cookies = Cookies()
|
||||
self._cookies.extract_cookies(self)
|
||||
return self._cookies
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
"""
|
||||
A thread-synchronous response. This class proxies onto a `Response`
|
||||
instance, providing standard synchronous interfaces where required.
|
||||
"""
|
||||
|
||||
def __init__(self, response: Response, loop: asyncio.AbstractEventLoop):
|
||||
self._response = response
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self._response.status_code
|
||||
|
||||
@property
|
||||
def reason_phrase(self) -> str:
|
||||
return self._response.reason_phrase
|
||||
|
||||
@property
|
||||
def protocol(self) -> typing.Optional[str]:
|
||||
return self._response.protocol
|
||||
|
||||
@property
|
||||
def url(self) -> typing.Optional[URL]:
|
||||
return self._response.url
|
||||
|
||||
@property
|
||||
def request(self) -> typing.Optional[Request]:
|
||||
return self._response.request
|
||||
|
||||
@property
|
||||
def headers(self) -> Headers:
|
||||
return self._response.headers
|
||||
|
||||
@property
|
||||
def content(self) -> bytes:
|
||||
return self._response.content
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self._response.text
|
||||
|
||||
@property
|
||||
def encoding(self) -> str:
|
||||
return self._response.encoding
|
||||
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return self._response.is_redirect
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return self._response.raise_for_status()
|
||||
|
||||
def json(self) -> typing.Any:
|
||||
return self._response.json()
|
||||
self.is_closed = False
|
||||
self.is_stream_consumed = False
|
||||
self._raw_stream = content
|
||||
|
||||
def read(self) -> bytes:
|
||||
return self._loop.run_until_complete(self._response.read())
|
||||
"""
|
||||
Read and return the response content.
|
||||
"""
|
||||
if not hasattr(self, "_content"):
|
||||
self._content = b"".join([part for part in self.stream()])
|
||||
return self._content
|
||||
|
||||
def stream(self) -> typing.Iterator[bytes]:
|
||||
inner = self._response.stream()
|
||||
while True:
|
||||
try:
|
||||
yield self._loop.run_until_complete(inner.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, and brotli encoded responses.
|
||||
"""
|
||||
if hasattr(self, "_content"):
|
||||
yield self._content
|
||||
else:
|
||||
for chunk in self.raw():
|
||||
yield self.decoder.decode(chunk)
|
||||
yield self.decoder.flush()
|
||||
|
||||
def raw(self) -> typing.Iterator[bytes]:
|
||||
inner = self._response.raw()
|
||||
while True:
|
||||
try:
|
||||
yield self._loop.run_until_complete(inner.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
if hasattr(self, "_raw_content"):
|
||||
yield self._raw_content
|
||||
else:
|
||||
if self.is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
if self.is_closed:
|
||||
raise ResponseClosed()
|
||||
|
||||
self.is_stream_consumed = True
|
||||
for part in self._raw_stream:
|
||||
yield part
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
return self._loop.run_until_complete(self._response.close())
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
return self._response.cookies
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Response({self.status_code}, {self.reason_phrase!r})>"
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
if not self.is_closed:
|
||||
self.is_closed = True
|
||||
if self.on_close is not None:
|
||||
self.on_close()
|
||||
|
||||
|
||||
class Cookies(MutableMapping):
|
||||
@ -871,7 +970,7 @@ class Cookies(MutableMapping):
|
||||
else:
|
||||
self.jar = cookies
|
||||
|
||||
def extract_cookies(self, response: Response) -> None:
|
||||
def extract_cookies(self, response: BaseResponse) -> None:
|
||||
"""
|
||||
Loads any cookies based on the response `Set-Cookie` headers.
|
||||
"""
|
||||
@ -881,7 +980,7 @@ class Cookies(MutableMapping):
|
||||
|
||||
self.jar.extract_cookies(urlib_response, urllib_request) # type: ignore
|
||||
|
||||
def set_cookie_header(self, request: Request) -> None:
|
||||
def set_cookie_header(self, request: BaseRequest) -> None:
|
||||
"""
|
||||
Sets an appropriate 'Cookie:' HTTP header on the `Request`.
|
||||
"""
|
||||
@ -1000,7 +1099,7 @@ class Cookies(MutableMapping):
|
||||
for use with `CookieJar` operations.
|
||||
"""
|
||||
|
||||
def __init__(self, request: Request) -> None:
|
||||
def __init__(self, request: BaseRequest) -> None:
|
||||
super().__init__(
|
||||
url=str(request.url),
|
||||
headers=dict(request.headers),
|
||||
@ -1018,7 +1117,7 @@ class Cookies(MutableMapping):
|
||||
for use with `CookieJar` operations.
|
||||
"""
|
||||
|
||||
def __init__(self, response: Response):
|
||||
def __init__(self, response: BaseResponse):
|
||||
self.response = response
|
||||
|
||||
def info(self) -> email.message.Message:
|
||||
|
||||
@ -4,27 +4,26 @@ import pytest
|
||||
|
||||
from httpcore import (
|
||||
URL,
|
||||
AsyncDispatcher,
|
||||
AsyncRequest,
|
||||
AsyncResponse,
|
||||
CertTypes,
|
||||
Client,
|
||||
Dispatcher,
|
||||
Request,
|
||||
Response,
|
||||
TimeoutTypes,
|
||||
VerifyTypes,
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Dispatcher):
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
return AsyncResponse(200, content=body, request=request)
|
||||
|
||||
|
||||
def test_basic_auth():
|
||||
|
||||
@ -5,32 +5,31 @@ import pytest
|
||||
|
||||
from httpcore import (
|
||||
URL,
|
||||
AsyncDispatcher,
|
||||
AsyncRequest,
|
||||
AsyncResponse,
|
||||
CertTypes,
|
||||
Client,
|
||||
Cookies,
|
||||
Dispatcher,
|
||||
Request,
|
||||
Response,
|
||||
TimeoutTypes,
|
||||
VerifyTypes,
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Dispatcher):
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
if request.url.path.startswith("/echo_cookies"):
|
||||
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
return AsyncResponse(200, content=body, request=request)
|
||||
elif request.url.path.startswith("/set_cookie"):
|
||||
headers = {"set-cookie": "example-name=example-value"}
|
||||
return Response(200, headers=headers, request=request)
|
||||
return AsyncResponse(200, headers=headers, request=request)
|
||||
|
||||
|
||||
def test_set_cookie():
|
||||
|
||||
@ -6,8 +6,10 @@ import pytest
|
||||
from httpcore import (
|
||||
URL,
|
||||
AsyncClient,
|
||||
AsyncDispatcher,
|
||||
AsyncRequest,
|
||||
AsyncResponse,
|
||||
CertTypes,
|
||||
Dispatcher,
|
||||
RedirectBodyUnavailable,
|
||||
RedirectLoop,
|
||||
Request,
|
||||
@ -19,37 +21,36 @@ from httpcore import (
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Dispatcher):
|
||||
class MockDispatch(AsyncDispatcher):
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
request: AsyncRequest,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
) -> AsyncResponse:
|
||||
if request.url.path == "/redirect_301":
|
||||
status_code = codes.MOVED_PERMANENTLY
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
return AsyncResponse(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_302":
|
||||
status_code = codes.FOUND
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
return AsyncResponse(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_303":
|
||||
status_code = codes.SEE_OTHER
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
return AsyncResponse(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/relative_redirect":
|
||||
headers = {"location": "/"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/no_scheme_redirect":
|
||||
headers = {"location": "//example.org/"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/multiple_redirects":
|
||||
params = parse_qs(request.url.query)
|
||||
@ -60,32 +61,34 @@ class MockDispatch(Dispatcher):
|
||||
if redirect_count:
|
||||
location += "?count=" + str(redirect_count)
|
||||
headers = {"location": location} if count else {}
|
||||
return Response(code, headers=headers, request=request)
|
||||
return AsyncResponse(code, headers=headers, request=request)
|
||||
|
||||
if request.url.path == "/redirect_loop":
|
||||
headers = {"location": "/redirect_loop"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/cross_domain":
|
||||
headers = {"location": "https://example.org/cross_domain_target"}
|
||||
return Response(codes.SEE_OTHER, headers=headers, request=request)
|
||||
return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/cross_domain_target":
|
||||
headers = dict(request.headers.items())
|
||||
content = json.dumps({"headers": headers}).encode()
|
||||
return Response(codes.OK, content=content, request=request)
|
||||
return AsyncResponse(codes.OK, content=content, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_body":
|
||||
await request.read()
|
||||
headers = {"location": "/redirect_body_target"}
|
||||
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
|
||||
return AsyncResponse(
|
||||
codes.PERMANENT_REDIRECT, headers=headers, request=request
|
||||
)
|
||||
|
||||
elif request.url.path == "/redirect_body_target":
|
||||
content = await request.read()
|
||||
body = json.dumps({"body": content.decode()}).encode()
|
||||
return Response(codes.OK, content=body, request=request)
|
||||
return AsyncResponse(codes.OK, content=body, request=request)
|
||||
|
||||
return Response(codes.OK, content=b"Hello, world!", request=request)
|
||||
return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -10,10 +10,12 @@ async def test_keepalive_connections(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
@ -25,10 +27,12 @@ async def test_differing_connection_keys(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
@ -42,10 +46,12 @@ async def test_soft_limit(server):
|
||||
|
||||
async with httpcore.ConnectionPool(pool_limits=pool_limits) as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
@ -56,7 +62,7 @@ async def test_streaming_response_holds_connection(server):
|
||||
A streaming request should hold the connection open until the response is read.
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
@ -72,11 +78,11 @@ async def test_multiple_concurrent_connections(server):
|
||||
Multiple conncurrent requests should open multiple conncurrent connections.
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
response_a = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
response_b = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert len(http.active_connections) == 2
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
@ -97,6 +103,7 @@ async def test_close_connections(server):
|
||||
headers = [(b"connection", b"close")]
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
@ -107,7 +114,7 @@ async def test_standard_response_close(server):
|
||||
A standard close should keep the connection open.
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
await response.close()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -120,7 +127,7 @@ async def test_premature_response_close(server):
|
||||
A premature close should close the connection.
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.close()
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
@ -7,6 +7,7 @@ from httpcore import HTTPConnection, Request, SSLConfig
|
||||
async def test_get(server):
|
||||
conn = HTTPConnection(origin="http://127.0.0.1:8000/")
|
||||
response = await conn.request("GET", "http://127.0.0.1:8000/")
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
@ -27,6 +28,7 @@ async def test_https_get_with_ssl_defaults(https_server):
|
||||
"""
|
||||
conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
|
||||
response = await conn.request("GET", "https://127.0.0.1:8001/")
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
@ -38,5 +40,6 @@ async def test_https_get_with_sll_overrides(https_server):
|
||||
"""
|
||||
conn = HTTPConnection(origin="https://127.0.0.1:8001/")
|
||||
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
100
tests/dispatch/test_threaded.py
Normal file
100
tests/dispatch/test_threaded.py
Normal file
@ -0,0 +1,100 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from httpcore import (
|
||||
CertTypes,
|
||||
Client,
|
||||
Dispatcher,
|
||||
Request,
|
||||
Response,
|
||||
TimeoutTypes,
|
||||
VerifyTypes,
|
||||
)
|
||||
|
||||
|
||||
def streaming_body():
|
||||
for part in [b"Hello", b", ", b"world!"]:
|
||||
yield part
|
||||
|
||||
|
||||
class MockDispatch(Dispatcher):
|
||||
def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
if request.url.path == "/streaming_response":
|
||||
return Response(200, content=streaming_body(), request=request)
|
||||
elif request.url.path == "/echo_request_body":
|
||||
content = request.read()
|
||||
return Response(200, content=content, request=request)
|
||||
elif request.url.path == "/echo_request_body_streaming":
|
||||
content = b"".join([part for part in request.stream()])
|
||||
return Response(200, content=content, request=request)
|
||||
else:
|
||||
body = json.dumps({"hello": "world"}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
|
||||
|
||||
def test_threaded_dispatch():
|
||||
"""
|
||||
Use a syncronous 'Dispatcher' class with the client.
|
||||
Calls to the dispatcher will end up running within a thread pool.
|
||||
"""
|
||||
url = "https://example.org/"
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"hello": "world"}
|
||||
|
||||
|
||||
def test_threaded_streaming_response():
|
||||
url = "https://example.org/streaming_response"
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_threaded_streaming_request():
|
||||
url = "https://example.org/echo_request_body"
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.post(url, data=streaming_body())
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_threaded_request_body():
|
||||
url = "https://example.org/echo_request_body"
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.post(url, data=b"Hello, world!")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_threaded_request_body_streaming():
|
||||
url = "https://example.org/echo_request_body_streaming"
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.post(url, data=b"Hello, world!")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_dispatch_class():
|
||||
"""
|
||||
Use a syncronous 'Dispatcher' class directly.
|
||||
"""
|
||||
url = "https://example.org/"
|
||||
with MockDispatch() as dispatcher:
|
||||
response = dispatcher.request("GET", url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"hello": "world"}
|
||||
@ -10,87 +10,62 @@ def test_request_repr():
|
||||
|
||||
def test_no_content():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"deflate, gzip, br")]
|
||||
)
|
||||
assert "Content-Length" not in request.headers
|
||||
|
||||
|
||||
def test_content_length_header():
|
||||
request = httpcore.Request("POST", "http://example.org", data=b"test 123")
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
assert request.headers["Content-Length"] == "8"
|
||||
|
||||
|
||||
def test_url_encoded_data():
|
||||
request = httpcore.Request("POST", "http://example.org", data={"test": "123"})
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-type", b"application/x-www-form-urlencoded"),
|
||||
]
|
||||
)
|
||||
assert request.content == b"test=123"
|
||||
for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
|
||||
request = RequestClass("POST", "http://example.org", data={"test": "123"})
|
||||
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
assert request.content == b"test=123"
|
||||
|
||||
|
||||
def test_json_encoded_data():
|
||||
for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
|
||||
request = RequestClass("POST", "http://example.org", json={"test": 123})
|
||||
assert request.headers["Content-Type"] == "application/json"
|
||||
assert request.content == b'{"test": 123}'
|
||||
|
||||
|
||||
def test_transfer_encoding_header():
|
||||
async def streaming_body(data):
|
||||
def streaming_body(data):
|
||||
yield data # pragma: nocover
|
||||
|
||||
data = streaming_body(b"test 123")
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", data=data)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
assert "Content-Length" not in request.headers
|
||||
assert request.headers["Transfer-Encoding"] == "chunked"
|
||||
|
||||
|
||||
def test_override_host_header():
|
||||
headers = [(b"host", b"1.2.3.4:80")]
|
||||
headers = {"host": "1.2.3.4:80"}
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
|
||||
)
|
||||
assert request.headers["Host"] == "1.2.3.4:80"
|
||||
|
||||
|
||||
def test_override_accept_encoding_header():
|
||||
headers = [(b"accept-encoding", b"identity")]
|
||||
headers = {"Accept-Encoding": "identity"}
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"identity")]
|
||||
)
|
||||
assert request.headers["Accept-Encoding"] == "identity"
|
||||
|
||||
|
||||
def test_override_content_length_header():
|
||||
async def streaming_body(data):
|
||||
def streaming_body(data):
|
||||
yield data # pragma: nocover
|
||||
|
||||
data = streaming_body(b"test 123")
|
||||
headers = [(b"content-length", b"8")]
|
||||
headers = {"Content-Length": "8"}
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", data=data, headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-length", b"8"),
|
||||
]
|
||||
)
|
||||
assert request.headers["Content-Length"] == "8"
|
||||
|
||||
|
||||
def test_url():
|
||||
|
||||
@ -3,7 +3,12 @@ import pytest
|
||||
import httpcore
|
||||
|
||||
|
||||
async def streaming_body():
|
||||
def streaming_body():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
|
||||
async def async_streaming_body():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
@ -105,8 +110,7 @@ def test_response_force_encoding():
|
||||
assert response.encoding == "iso-8859-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_response():
|
||||
def test_read_response():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
assert response.status_code == 200
|
||||
@ -114,37 +118,56 @@ async def test_read_response():
|
||||
assert response.encoding == "ascii"
|
||||
assert response.is_closed
|
||||
|
||||
content = await response.read()
|
||||
content = response.read()
|
||||
|
||||
assert content == b"Hello, world!"
|
||||
assert response.content == b"Hello, world!"
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_interface():
|
||||
def test_raw_interface():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
raw = b""
|
||||
async for part in response.raw():
|
||||
for part in response.raw():
|
||||
raw += part
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_interface():
|
||||
def test_stream_interface():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
content = b""
|
||||
for part in response.stream():
|
||||
content += part
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_interface():
|
||||
response = httpcore.AsyncResponse(200, content=b"Hello, world!")
|
||||
|
||||
content = b""
|
||||
async for part in response.stream():
|
||||
content += part
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_interface_after_read():
|
||||
def test_stream_interface_after_read():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
response.read()
|
||||
|
||||
content = b""
|
||||
for part in response.stream():
|
||||
content += part
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_interface_after_read():
|
||||
response = httpcore.AsyncResponse(200, content=b"Hello, world!")
|
||||
|
||||
await response.read()
|
||||
|
||||
content = b""
|
||||
@ -153,13 +176,26 @@ async def test_stream_interface_after_read():
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_response():
|
||||
def test_streaming_response():
|
||||
response = httpcore.Response(200, content=streaming_body())
|
||||
|
||||
assert response.status_code == 200
|
||||
assert not response.is_closed
|
||||
|
||||
content = response.read()
|
||||
|
||||
assert content == b"Hello, world!"
|
||||
assert response.content == b"Hello, world!"
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_response():
|
||||
response = httpcore.AsyncResponse(200, content=async_streaming_body())
|
||||
|
||||
assert response.status_code == 200
|
||||
assert not response.is_closed
|
||||
|
||||
content = await response.read()
|
||||
|
||||
assert content == b"Hello, world!"
|
||||
@ -167,10 +203,21 @@ async def test_streaming_response():
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_read_after_stream_consumed():
|
||||
def test_cannot_read_after_stream_consumed():
|
||||
response = httpcore.Response(200, content=streaming_body())
|
||||
|
||||
content = b""
|
||||
for part in response.stream():
|
||||
content += part
|
||||
|
||||
with pytest.raises(httpcore.StreamConsumed):
|
||||
response.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cannot_read_after_stream_consumed():
|
||||
response = httpcore.AsyncResponse(200, content=async_streaming_body())
|
||||
|
||||
content = b""
|
||||
async for part in response.stream():
|
||||
content += part
|
||||
@ -179,10 +226,19 @@ async def test_cannot_read_after_stream_consumed():
|
||||
await response.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_read_after_response_closed():
|
||||
def test_cannot_read_after_response_closed():
|
||||
response = httpcore.Response(200, content=streaming_body())
|
||||
|
||||
response.close()
|
||||
|
||||
with pytest.raises(httpcore.ResponseClosed):
|
||||
response.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cannot_read_after_response_closed():
|
||||
response = httpcore.AsyncResponse(200, content=async_streaming_body())
|
||||
|
||||
await response.close()
|
||||
|
||||
with pytest.raises(httpcore.ResponseClosed):
|
||||
|
||||
@ -38,6 +38,18 @@ def test_post(server):
|
||||
assert response.reason_phrase == "OK"
|
||||
|
||||
|
||||
@threadpool
|
||||
def test_post_byte_iterator(server):
|
||||
def data():
|
||||
yield b"Hello"
|
||||
yield b", "
|
||||
yield b"world!"
|
||||
|
||||
response = httpcore.post("http://127.0.0.1:8000/", data=data())
|
||||
assert response.status_code == 200
|
||||
assert response.reason_phrase == "OK"
|
||||
|
||||
|
||||
@threadpool
|
||||
def test_options(server):
|
||||
response = httpcore.options("http://127.0.0.1:8000/")
|
||||
|
||||
@ -64,19 +64,18 @@ def test_multi_with_identity():
|
||||
assert response.content == body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming():
|
||||
def test_streaming():
|
||||
body = b"test 123"
|
||||
compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
|
||||
|
||||
async def compress(body):
|
||||
def compress(body):
|
||||
yield compressor.compress(body)
|
||||
yield compressor.flush()
|
||||
|
||||
headers = [(b"Content-Encoding", b"gzip")]
|
||||
response = httpcore.Response(200, headers=headers, content=compress(body))
|
||||
assert not hasattr(response, "body")
|
||||
assert await response.read() == body
|
||||
assert response.read() == body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user