commit
696c054968
90
API.md
Normal file
90
API.md
Normal file
@ -0,0 +1,90 @@
|
||||
Client(...)
|
||||
|
||||
.request(method, url, ...)
|
||||
|
||||
.get(url, ...)
|
||||
.options(url, ...)
|
||||
.head(url, ...)
|
||||
.post(url, ...)
|
||||
.put(url, ...)
|
||||
.patch(url, ...)
|
||||
.delete(url, ...)
|
||||
|
||||
.prepare_request(request)
|
||||
.send(request, ...)
|
||||
.close()
|
||||
|
||||
|
||||
Adapter()
|
||||
|
||||
.prepare_request(request)
|
||||
.send(request)
|
||||
.close()
|
||||
|
||||
|
||||
+ EnvironmentAdapter
|
||||
+ RedirectAdapter
|
||||
+ CookieAdapter
|
||||
+ AuthAdapter
|
||||
+ ConnectionPool
|
||||
+ HTTPConnection
|
||||
+ HTTP11Connection
|
||||
+ HTTP2Connection
|
||||
|
||||
|
||||
|
||||
Response(...)
|
||||
.status_code - int
|
||||
.reason_phrase - str
|
||||
.protocol - "HTTP/2" or "HTTP/1.1"
|
||||
.url - URL
|
||||
.headers - Headers
|
||||
|
||||
.content - bytes
|
||||
.text - str
|
||||
.encoding - str
|
||||
.json() - Any
|
||||
|
||||
.read() - bytes
|
||||
.stream() - bytes iterator
|
||||
.raw() - bytes iterator
|
||||
.close() - None
|
||||
|
||||
.is_redirect - bool
|
||||
.request - Request
|
||||
.cookies - Cookies
|
||||
.history - List[Response]
|
||||
|
||||
.raise_for_status()
|
||||
.next()
|
||||
|
||||
|
||||
Request(...)
|
||||
.method
|
||||
.url
|
||||
.headers
|
||||
|
||||
...
|
||||
|
||||
|
||||
Headers
|
||||
|
||||
URL
|
||||
|
||||
Origin
|
||||
|
||||
Cookies
|
||||
|
||||
|
||||
# Sync
|
||||
|
||||
SyncClient
|
||||
SyncResponse
|
||||
SyncRequest
|
||||
SyncAdapter
|
||||
|
||||
|
||||
|
||||
SSE
|
||||
HTTP/2 server push support
|
||||
Concurrency
|
||||
@ -1,16 +1,26 @@
|
||||
from .adapters.redirects import RedirectAdapter
|
||||
from .client import Client
|
||||
from .config import PoolLimits, SSLConfig, TimeoutConfig
|
||||
from .connectionpool import ConnectionPool
|
||||
from .datastructures import URL, Origin, Request, Response
|
||||
from .dispatch.connection import HTTPConnection
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .dispatch.http2 import HTTP2Connection
|
||||
from .dispatch.http11 import HTTP11Connection
|
||||
from .exceptions import (
|
||||
ConnectTimeout,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ReadTimeout,
|
||||
RedirectBodyUnavailable,
|
||||
RedirectLoop,
|
||||
ResponseClosed,
|
||||
StreamConsumed,
|
||||
Timeout,
|
||||
TooManyRedirects,
|
||||
)
|
||||
from .http11 import HTTP11Connection
|
||||
from .interfaces import Adapter
|
||||
from .models import URL, Headers, Origin, Request, Response
|
||||
from .status_codes import codes
|
||||
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
|
||||
from .sync import SyncClient, SyncConnectionPool
|
||||
|
||||
__version__ = "0.2.1"
|
||||
|
||||
4
httpcore/adapters/__init__.py
Normal file
4
httpcore/adapters/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
"""
|
||||
Adapter classes layer additional behavior over the raw dispatching of the
|
||||
HTTP request/response.
|
||||
"""
|
||||
18
httpcore/adapters/authentication.py
Normal file
18
httpcore/adapters/authentication.py
Normal file
@ -0,0 +1,18 @@
|
||||
import typing
|
||||
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Request, Response
|
||||
|
||||
|
||||
class AuthenticationAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter):
|
||||
self.dispatch = dispatch
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
18
httpcore/adapters/cookies.py
Normal file
18
httpcore/adapters/cookies.py
Normal file
@ -0,0 +1,18 @@
|
||||
import typing
|
||||
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Request, Response
|
||||
|
||||
|
||||
class CookieAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter):
|
||||
self.dispatch = dispatch
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
27
httpcore/adapters/environment.py
Normal file
27
httpcore/adapters/environment.py
Normal file
@ -0,0 +1,27 @@
|
||||
import typing
|
||||
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Request, Response
|
||||
|
||||
|
||||
class EnvironmentAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter, trust_env: bool = True):
|
||||
self.dispatch = dispatch
|
||||
self.trust_env = trust_env
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
if self.trust_env:
|
||||
self.merge_environment_options(options)
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
|
||||
def merge_environment_options(self, options: dict) -> None:
|
||||
"""
|
||||
Add environment options.
|
||||
"""
|
||||
# TODO
|
||||
125
httpcore/adapters/redirects.py
Normal file
125
httpcore/adapters/redirects.py
Normal file
@ -0,0 +1,125 @@
|
||||
import functools
|
||||
import typing
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from ..interfaces import Adapter
|
||||
from ..models import URL, Headers, Request, Response
|
||||
from ..status_codes import codes
|
||||
from ..utils import requote_uri
|
||||
|
||||
|
||||
class RedirectAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter, max_redirects: int = DEFAULT_MAX_REDIRECTS):
|
||||
self.dispatch = dispatch
|
||||
self.max_redirects = max_redirects
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
allow_redirects = options.pop("allow_redirects", True)
|
||||
history = options.pop("history", []) # type: typing.List[Response]
|
||||
seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL]
|
||||
seen_urls.add(request.url)
|
||||
|
||||
while True:
|
||||
response = await self.dispatch.send(request, **options)
|
||||
response.history = list(history)
|
||||
if not response.is_redirect:
|
||||
break
|
||||
history.append(response)
|
||||
request = self.build_redirect_request(request, response)
|
||||
if not allow_redirects:
|
||||
next_options = dict(options)
|
||||
next_options["seen_urls"] = seen_urls
|
||||
next_options["history"] = history
|
||||
response.next = functools.partial(self.send, request=request, **next_options)
|
||||
break
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
if request.url in seen_urls:
|
||||
raise RedirectLoop()
|
||||
seen_urls.add(request.url)
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url)
|
||||
body = self.redirect_body(request, method)
|
||||
return Request(method=method, url=url, headers=headers, body=body)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes.see_other and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# Turn 302s into GETs.
|
||||
if response.status_code == codes.found and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# If a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in 'requests' issue 1704.
|
||||
if response.status_code == codes.moved_permanently and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
# Handle redirection without scheme (see: RFC 1808 Section 4)
|
||||
if location.startswith("//"):
|
||||
location = f"{request.url.scheme}:{location}"
|
||||
|
||||
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
parsed = urlparse(location)
|
||||
if parsed.fragment == "" and request.url.fragment:
|
||||
parsed = parsed._replace(fragment=request.url.fragment)
|
||||
url = parsed.geturl()
|
||||
|
||||
# Facilitate relative 'location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
# Compliant with RFC3986, we percent encode the url.
|
||||
if not parsed.netloc:
|
||||
url = urljoin(str(request.url), requote_uri(url))
|
||||
else:
|
||||
url = requote_uri(url)
|
||||
|
||||
return URL(url)
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL) -> Headers:
|
||||
"""
|
||||
Strip Authorization headers when responses are redirected away from
|
||||
the origin.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
if url.origin != request.url.origin:
|
||||
del headers["Authorization"]
|
||||
return headers
|
||||
|
||||
def redirect_body(self, request: Request, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
raise RedirectBodyUnavailable()
|
||||
return request.body
|
||||
124
httpcore/client.py
Normal file
124
httpcore/client.py
Normal file
@ -0,0 +1,124 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .adapters.authentication import AuthenticationAdapter
|
||||
from .adapters.cookies import CookieAdapter
|
||||
from .adapters.environment import EnvironmentAdapter
|
||||
from .adapters.redirects import RedirectAdapter
|
||||
from .config import (
|
||||
DEFAULT_MAX_REDIRECTS,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .models import URL, Request, Response
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
):
|
||||
connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits)
|
||||
cookie_adapter = CookieAdapter(dispatch=connection_pool)
|
||||
auth_adapter = AuthenticationAdapter(dispatch=cookie_adapter)
|
||||
redirect_adapter = RedirectAdapter(
|
||||
dispatch=auth_adapter, max_redirects=max_redirects
|
||||
)
|
||||
self.adapter = EnvironmentAdapter(dispatch=redirect_adapter)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
allow_redirects: bool = True,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
self.prepare_request(request)
|
||||
response = await self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
return response
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
return await self.request(
|
||||
"GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout
|
||||
)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
return await self.request(
|
||||
"POST",
|
||||
url,
|
||||
body=body,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.adapter.prepare_request(request)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
allow_redirects: bool = True,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
options = {"stream": stream} # type: typing.Dict[str, typing.Any]
|
||||
if ssl is not None:
|
||||
options["ssl"] = ssl
|
||||
if timeout is not None:
|
||||
options["timeout"] = timeout
|
||||
return await self.adapter.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.adapter.close()
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
@ -27,10 +27,6 @@ class SSLConfig:
|
||||
and self.verify == other.verify
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
as_tuple = (self.cert, self.verify)
|
||||
return hash(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(cert={self.cert}, verify={self.verify})"
|
||||
@ -73,7 +69,23 @@ class SSLConfig:
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = ssl.create_default_context()
|
||||
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
||||
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
|
||||
# RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST
|
||||
# support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the
|
||||
# blacklist defined in this section allows only the AES GCM and ChaCha20
|
||||
# cipher suites with ephemeral key negotiation.
|
||||
context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20")
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
context.set_alpn_protocols(["h2", "http/1.1"])
|
||||
if ssl.HAS_NPN:
|
||||
context.set_npn_protocols(["h2", "http/1.1"])
|
||||
|
||||
if os.path.isfile(ca_bundle_path):
|
||||
context.load_verify_locations(cafile=ca_bundle_path)
|
||||
elif os.path.isdir(ca_bundle_path):
|
||||
@ -99,39 +111,35 @@ class TimeoutConfig:
|
||||
*,
|
||||
connect_timeout: float = None,
|
||||
read_timeout: float = None,
|
||||
pool_timeout: float = None,
|
||||
write_timeout: float = None,
|
||||
):
|
||||
if timeout is not None:
|
||||
# Specified as a single timeout value
|
||||
assert connect_timeout is None
|
||||
assert read_timeout is None
|
||||
assert pool_timeout is None
|
||||
assert write_timeout is None
|
||||
connect_timeout = timeout
|
||||
read_timeout = timeout
|
||||
pool_timeout = timeout
|
||||
write_timeout = timeout
|
||||
|
||||
self.timeout = timeout
|
||||
self.connect_timeout = connect_timeout
|
||||
self.read_timeout = read_timeout
|
||||
self.pool_timeout = pool_timeout
|
||||
self.write_timeout = write_timeout
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.connect_timeout == other.connect_timeout
|
||||
and self.read_timeout == other.read_timeout
|
||||
and self.pool_timeout == other.pool_timeout
|
||||
and self.write_timeout == other.write_timeout
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
|
||||
return hash(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
if self.timeout is not None:
|
||||
return f"{class_name}(timeout={self.timeout})"
|
||||
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
|
||||
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
|
||||
|
||||
|
||||
class PoolLimits:
|
||||
@ -142,31 +150,29 @@ class PoolLimits:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
soft_limit: typing.Optional[int] = None,
|
||||
hard_limit: typing.Optional[int] = None,
|
||||
soft_limit: int = None,
|
||||
hard_limit: int = None,
|
||||
pool_timeout: float = None,
|
||||
):
|
||||
self.soft_limit = soft_limit
|
||||
self.hard_limit = hard_limit
|
||||
self.pool_timeout = pool_timeout
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.soft_limit == other.soft_limit
|
||||
and self.hard_limit == other.hard_limit
|
||||
and self.pool_timeout == other.pool_timeout
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
as_tuple = (self.soft_limit, self.hard_limit)
|
||||
return hash(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return (
|
||||
f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
|
||||
)
|
||||
return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})"
|
||||
|
||||
|
||||
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
|
||||
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
|
||||
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
|
||||
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
|
||||
DEFAULT_CA_BUNDLE_PATH = certifi.where()
|
||||
DEFAULT_MAX_REDIRECTS = 20
|
||||
|
||||
@ -1,132 +0,0 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from .config import (
|
||||
DEFAULT_CA_BUNDLE_PATH,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import PoolTimeout
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
):
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.limits = limits
|
||||
self.is_closed = False
|
||||
self.num_active_connections = 0
|
||||
self.num_keepalive_connections = 0
|
||||
self._keepalive_connections = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.List[HTTP11Connection]]
|
||||
self._max_connections = ConnectionSemaphore(
|
||||
max_connections=self.limits.hard_limit
|
||||
)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
connection = await self.acquire_connection(request.url.origin, timeout=timeout)
|
||||
response = await connection.send(request, ssl=ssl, timeout=timeout)
|
||||
return response
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
return self.num_active_connections + self.num_keepalive_connections
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> HTTP11Connection:
|
||||
try:
|
||||
connection = self._keepalive_connections[origin].pop()
|
||||
if not self._keepalive_connections[origin]:
|
||||
del self._keepalive_connections[origin]
|
||||
self.num_keepalive_connections -= 1
|
||||
self.num_active_connections += 1
|
||||
|
||||
except (KeyError, IndexError):
|
||||
if timeout is None:
|
||||
pool_timeout = self.timeout.pool_timeout
|
||||
else:
|
||||
pool_timeout = timeout.pool_timeout
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
connection = HTTP11Connection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
timeout=self.timeout,
|
||||
on_release=self.release_connection,
|
||||
)
|
||||
self.num_active_connections += 1
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, connection: HTTP11Connection) -> None:
|
||||
if connection.is_closed:
|
||||
self._max_connections.release()
|
||||
self.num_active_connections -= 1
|
||||
elif (
|
||||
self.limits.soft_limit is not None
|
||||
and self.num_connections > self.limits.soft_limit
|
||||
):
|
||||
self._max_connections.release()
|
||||
self.num_active_connections -= 1
|
||||
await connection.close()
|
||||
else:
|
||||
self.num_active_connections -= 1
|
||||
self.num_keepalive_connections += 1
|
||||
try:
|
||||
self._keepalive_connections[connection.origin].append(connection)
|
||||
except KeyError:
|
||||
self._keepalive_connections[connection.origin] = [connection]
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
all_connections = []
|
||||
for connections in self._keepalive_connections.values():
|
||||
all_connections.extend(list(connections))
|
||||
self._keepalive_connections.clear()
|
||||
for connection in all_connections:
|
||||
await connection.close()
|
||||
|
||||
|
||||
class ConnectionSemaphore:
|
||||
def __init__(self, max_connections: int = None):
|
||||
self.max_connections = max_connections
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
if self.max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.semaphore is not None:
|
||||
await self.semaphore.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
if self.semaphore is not None:
|
||||
self.semaphore.release()
|
||||
@ -1,280 +0,0 @@
|
||||
import http
|
||||
import typing
|
||||
from types import TracebackType
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .decoders import (
|
||||
ACCEPT_ENCODING,
|
||||
SUPPORTED_DECODERS,
|
||||
Decoder,
|
||||
IdentityDecoder,
|
||||
MultiDecoder,
|
||||
)
|
||||
from .exceptions import ResponseClosed, StreamConsumed
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(self, url: str = "") -> None:
|
||||
self.components = urlsplit(url)
|
||||
if not self.components.scheme:
|
||||
raise ValueError("No scheme included in URL.")
|
||||
if self.components.scheme not in ("http", "https"):
|
||||
raise ValueError('URL scheme must be "http" or "https".')
|
||||
if not self.components.hostname:
|
||||
raise ValueError("No hostname included in URL.")
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
|
||||
@property
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
port = self.components.port
|
||||
if port is None:
|
||||
return {"https": 443, "http": 80}[self.scheme]
|
||||
return port
|
||||
|
||||
@property
|
||||
def target(self) -> str:
|
||||
path = self.path or "/"
|
||||
query = self.query
|
||||
if query:
|
||||
return path + "?" + query
|
||||
return path
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.components.scheme == "https"
|
||||
|
||||
@property
|
||||
def origin(self) -> "Origin":
|
||||
return Origin(self)
|
||||
|
||||
|
||||
class Origin:
|
||||
def __init__(self, url: typing.Union[str, URL]) -> None:
|
||||
if isinstance(url, str):
|
||||
url = URL(url)
|
||||
self.is_ssl = url.scheme == "https"
|
||||
self.hostname = url.hostname.lower()
|
||||
self.port = url.port
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.is_ssl == other.is_ssl
|
||||
and self.hostname == other.hostname
|
||||
and self.port == other.port
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.is_ssl, self.hostname, self.port))
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
):
|
||||
self.method = method.upper()
|
||||
self.url = URL(url) if isinstance(url, str) else url
|
||||
self.headers = list(headers)
|
||||
if isinstance(body, bytes):
|
||||
self.is_streaming = False
|
||||
self.body = body
|
||||
else:
|
||||
self.is_streaming = True
|
||||
self.body_aiter = body
|
||||
self.headers = self._auto_headers() + self.headers
|
||||
|
||||
def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
has_host = False
|
||||
has_content_length = False
|
||||
has_accept_encoding = False
|
||||
|
||||
for header, value in self.headers:
|
||||
header = header.strip().lower()
|
||||
if header == b"host":
|
||||
has_host = True
|
||||
elif header in (b"content-length", b"transfer-encoding"):
|
||||
has_content_length = True
|
||||
elif header == b"accept-encoding":
|
||||
has_accept_encoding = True
|
||||
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
if not has_host:
|
||||
headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
if not has_content_length:
|
||||
if self.is_streaming:
|
||||
headers.append((b"transfer-encoding", b"chunked"))
|
||||
elif self.body:
|
||||
content_length = str(len(self.body)).encode()
|
||||
headers.append((b"content-length", content_length))
|
||||
if not has_accept_encoding:
|
||||
headers.append((b"accept-encoding", ACCEPT_ENCODING))
|
||||
|
||||
return headers
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
assert self.is_streaming
|
||||
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
|
||||
|
||||
class Response:
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
reason: typing.Optional[str] = None,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
on_close: typing.Callable = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
if not reason:
|
||||
try:
|
||||
self.reason = http.HTTPStatus(status_code).phrase
|
||||
except ValueError as exc:
|
||||
self.reason = ""
|
||||
else:
|
||||
self.reason = reason
|
||||
self.headers = list(headers)
|
||||
self.on_close = on_close
|
||||
self.is_closed = False
|
||||
self.is_streamed = False
|
||||
|
||||
decoders = [] # type: typing.List[Decoder]
|
||||
for header, value in self.headers:
|
||||
if header.strip().lower() == b"content-encoding":
|
||||
for part in value.split(b","):
|
||||
part = part.strip().lower()
|
||||
decoder_cls = SUPPORTED_DECODERS[part]
|
||||
decoders.append(decoder_cls())
|
||||
|
||||
if len(decoders) == 0:
|
||||
self.decoder = IdentityDecoder() # type: Decoder
|
||||
elif len(decoders) == 1:
|
||||
self.decoder = decoders[0]
|
||||
else:
|
||||
self.decoder = MultiDecoder(decoders)
|
||||
|
||||
if isinstance(body, bytes):
|
||||
self.is_closed = True
|
||||
self.body = self.decoder.decode(body) + self.decoder.flush()
|
||||
else:
|
||||
self.body_aiter = body
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""
|
||||
Read and return the response content.
|
||||
"""
|
||||
if not hasattr(self, "body"):
|
||||
body = b""
|
||||
async for part in self.stream():
|
||||
body += part
|
||||
self.body = body
|
||||
return self.body
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, and brotli encoded responses.
|
||||
"""
|
||||
if hasattr(self, "body"):
|
||||
yield self.body
|
||||
else:
|
||||
async for chunk in self.raw():
|
||||
yield self.decoder.decode(chunk)
|
||||
yield self.decoder.flush()
|
||||
|
||||
async def raw(self) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
if self.is_streamed:
|
||||
raise StreamConsumed()
|
||||
if self.is_closed:
|
||||
raise ResponseClosed()
|
||||
self.is_streamed = True
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
if not self.is_closed:
|
||||
self.is_closed = True
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
|
||||
|
||||
class Client:
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
response = await self.send(request, ssl=ssl, timeout=timeout)
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
return response
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
@ -110,17 +110,17 @@ class MultiDecoder(Decoder):
|
||||
|
||||
|
||||
SUPPORTED_DECODERS = {
|
||||
b"identity": IdentityDecoder,
|
||||
b"deflate": DeflateDecoder,
|
||||
b"gzip": GZipDecoder,
|
||||
b"br": BrotliDecoder,
|
||||
"identity": IdentityDecoder,
|
||||
"deflate": DeflateDecoder,
|
||||
"gzip": GZipDecoder,
|
||||
"br": BrotliDecoder,
|
||||
}
|
||||
|
||||
|
||||
if brotli is None:
|
||||
SUPPORTED_DECODERS.pop(b"br") # pragma: nocover
|
||||
SUPPORTED_DECODERS.pop("br") # pragma: nocover
|
||||
|
||||
|
||||
ACCEPT_ENCODING = b", ".join(
|
||||
[key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
|
||||
ACCEPT_ENCODING = ", ".join(
|
||||
[key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
|
||||
)
|
||||
|
||||
4
httpcore/dispatch/__init__.py
Normal file
4
httpcore/dispatch/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
"""
|
||||
Dispatch classes handle the raw network connections and the implementation
|
||||
details of making the HTTP request and receiving the response.
|
||||
"""
|
||||
93
httpcore/dispatch/connection.py
Normal file
93
httpcore/dispatch/connection.py
Normal file
@ -0,0 +1,93 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
import h11
|
||||
|
||||
from ..config import (
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from ..exceptions import ConnectTimeout
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Origin, Request, Response
|
||||
from ..streams import Protocol, connect
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
# Callback signature: async def callback(conn: HTTPConnection) -> None
|
||||
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class HTTPConnection(Adapter):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
release_func: typing.Optional[ReleaseCallback] = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.release_func = release_func
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
await self.connect(**options)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
response = await self.h2_connection.send(request, **options)
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
response = await self.h11_connection.send(request, **options)
|
||||
|
||||
return response
|
||||
|
||||
async def connect(self, **options: typing.Any) -> None:
|
||||
ssl = options.get("ssl", self.ssl)
|
||||
timeout = options.get("timeout", self.timeout)
|
||||
assert isinstance(ssl, SSLConfig)
|
||||
assert isinstance(timeout, TimeoutConfig)
|
||||
|
||||
hostname = self.origin.hostname
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
else:
|
||||
on_release = functools.partial(self.release_func, self)
|
||||
|
||||
reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
|
||||
if protocol == Protocol.HTTP_2:
|
||||
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
|
||||
else:
|
||||
self.h11_connection = HTTP11Connection(
|
||||
reader, writer, on_release=on_release
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.h2_connection is not None:
|
||||
await self.h2_connection.close()
|
||||
elif self.h11_connection is not None:
|
||||
await self.h11_connection.close()
|
||||
|
||||
@property
|
||||
def is_http2(self) -> bool:
|
||||
return self.h2_connection is not None
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
if self.h2_connection is not None:
|
||||
return self.h2_connection.is_closed
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
return self.h11_connection.is_closed
|
||||
158
httpcore/dispatch/connection_pool.py
Normal file
158
httpcore/dispatch/connection_pool.py
Normal file
@ -0,0 +1,158 @@
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
from ..config import (
|
||||
DEFAULT_CA_BUNDLE_PATH,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from ..decoders import ACCEPT_ENCODING
|
||||
from ..exceptions import PoolTimeout
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Origin, Request, Response
|
||||
from ..streams import PoolSemaphore
|
||||
from .connection import HTTPConnection
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
|
||||
|
||||
class ConnectionStore(collections.abc.Sequence):
|
||||
"""
|
||||
We need to maintain collections of connections in a way that allows us to:
|
||||
|
||||
* Lookup connections by origin.
|
||||
* Iterate over connections by insertion time.
|
||||
* Return the total number of connections.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all = {} # type: typing.Dict[HTTPConnection, float]
|
||||
self.by_origin = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
|
||||
|
||||
def pop_by_origin(
|
||||
self, origin: Origin, http2_only: bool = False
|
||||
) -> typing.Optional[HTTPConnection]:
|
||||
try:
|
||||
connections = self.by_origin[origin]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
connection = next(reversed(list(connections.keys())))
|
||||
if http2_only and not connection.is_http2:
|
||||
return None
|
||||
|
||||
del connections[connection]
|
||||
if not connections:
|
||||
del self.by_origin[origin]
|
||||
del self.all[connection]
|
||||
|
||||
return connection
|
||||
|
||||
def add(self, connection: HTTPConnection) -> None:
|
||||
self.all[connection] = 0.0
|
||||
try:
|
||||
self.by_origin[connection.origin][connection] = 0.0
|
||||
except KeyError:
|
||||
self.by_origin[connection.origin] = {connection: 0.0}
|
||||
|
||||
def remove(self, connection: HTTPConnection) -> None:
|
||||
del self.all[connection]
|
||||
del self.by_origin[connection.origin][connection]
|
||||
if not self.by_origin[connection.origin]:
|
||||
del self.by_origin[connection.origin]
|
||||
|
||||
def clear(self) -> None:
|
||||
self.all.clear()
|
||||
self.by_origin.clear()
|
||||
|
||||
def __iter__(self) -> typing.Iterator[HTTPConnection]:
|
||||
return iter(self.all.keys())
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
if key in self.all:
|
||||
return key
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.all)
|
||||
|
||||
|
||||
class ConnectionPool(Adapter):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
):
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.limits = limits
|
||||
self.is_closed = False
|
||||
|
||||
self.max_connections = PoolSemaphore(limits)
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
self.active_connections = ConnectionStore()
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
return len(self.keepalive_connections) + len(self.active_connections)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
connection = await self.acquire_connection(request.url.origin)
|
||||
try:
|
||||
response = await connection.send(request, **options)
|
||||
except BaseException as exc:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
raise exc
|
||||
return response
|
||||
|
||||
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
|
||||
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
|
||||
if connection is None:
|
||||
connection = self.keepalive_connections.pop_by_origin(origin)
|
||||
|
||||
if connection is None:
|
||||
await self.max_connections.acquire()
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
timeout=self.timeout,
|
||||
release_func=self.release_connection,
|
||||
)
|
||||
|
||||
self.active_connections.add(connection)
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, connection: HTTPConnection) -> None:
|
||||
if connection.is_closed:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
elif (
|
||||
self.limits.soft_limit is not None
|
||||
and self.num_connections > self.limits.soft_limit
|
||||
):
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
else:
|
||||
self.active_connections.remove(connection)
|
||||
self.keepalive_connections.add(connection)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
connections = list(self.keepalive_connections)
|
||||
self.keepalive_connections.clear()
|
||||
for connection in connections:
|
||||
await connection.close()
|
||||
148
httpcore/dispatch/http11.py
Normal file
148
httpcore/dispatch/http11.py
Normal file
@ -0,0 +1,148 @@
|
||||
import typing
|
||||
|
||||
import h11
|
||||
|
||||
from ..config import (
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from ..exceptions import ConnectTimeout, ReadTimeout
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Request, Response
|
||||
from ..streams import BaseReader, BaseWriter
|
||||
|
||||
H11Event = typing.Union[
|
||||
h11.Request,
|
||||
h11.Response,
|
||||
h11.InformationalResponse,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
h11.ConnectionClosed,
|
||||
]
|
||||
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
# Callback signature: async def callback() -> None
|
||||
# In practice the callback will be a functools partial, which binds
|
||||
# the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
|
||||
OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class HTTP11Connection(Adapter):
|
||||
READ_NUM_BYTES = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: BaseReader,
|
||||
writer: BaseWriter,
|
||||
on_release: typing.Optional[OnReleaseCallback] = None,
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.on_release = on_release
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
stream = options.get("stream", False)
|
||||
assert timeout is None or isinstance(timeout, TimeoutConfig)
|
||||
|
||||
# Start sending the request.
|
||||
method = request.method.encode()
|
||||
target = request.url.full_path
|
||||
headers = request.headers.raw
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
# Send the request body.
|
||||
async for data in request.stream():
|
||||
event = h11.Data(data=data)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
# Finalize sending the request.
|
||||
event = h11.EndOfMessage()
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
# Start getting the response.
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
event = await self._receive_event(timeout)
|
||||
|
||||
assert isinstance(event, h11.Response)
|
||||
reason = event.reason.decode("latin1")
|
||||
status_code = event.status_code
|
||||
headers = event.headers
|
||||
body = self._body_iter(timeout)
|
||||
|
||||
response = Response(
|
||||
status_code=status_code,
|
||||
reason=reason,
|
||||
protocol="HTTP/1.1",
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self.response_closed,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
self.h11_state.send(event)
|
||||
except h11.ProtocolError:
|
||||
# If we're in some other state then it's a premature close,
|
||||
# and we'll end up in h11.ERROR.
|
||||
pass
|
||||
|
||||
await self.writer.close()
|
||||
|
||||
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
|
||||
event = await self._receive_event(timeout)
|
||||
while isinstance(event, h11.Data):
|
||||
yield event.data
|
||||
event = await self._receive_event(timeout)
|
||||
assert isinstance(event, h11.EndOfMessage)
|
||||
|
||||
async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
|
||||
data = self.h11_state.send(event)
|
||||
await self.writer.write(data, timeout)
|
||||
|
||||
async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
|
||||
event = self.h11_state.next_event()
|
||||
|
||||
while event is h11.NEED_DATA:
|
||||
data = await self.reader.read(self.READ_NUM_BYTES, timeout)
|
||||
self.h11_state.receive_data(data)
|
||||
event = self.h11_state.next_event()
|
||||
|
||||
return event
|
||||
|
||||
async def response_closed(self) -> None:
|
||||
if (
|
||||
self.h11_state.our_state is h11.DONE
|
||||
and self.h11_state.their_state is h11.DONE
|
||||
):
|
||||
self.h11_state.start_next_cycle()
|
||||
else:
|
||||
await self.close()
|
||||
|
||||
if self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
156
httpcore/dispatch/http2.py
Normal file
156
httpcore/dispatch/http2.py
Normal file
@ -0,0 +1,156 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from ..config import (
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from ..exceptions import ConnectTimeout, ReadTimeout
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Request, Response
|
||||
from ..streams import BaseReader, BaseWriter
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
|
||||
class HTTP2Connection(Adapter):
|
||||
READ_NUM_BYTES = 4096
|
||||
|
||||
def __init__(
|
||||
self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
self.initialized = False
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
stream = options.get("stream", False)
|
||||
assert timeout is None or isinstance(timeout, TimeoutConfig)
|
||||
|
||||
# Start sending the request.
|
||||
if not self.initialized:
|
||||
self.initiate_connection()
|
||||
stream_id = await self.send_headers(request, timeout)
|
||||
self.events[stream_id] = []
|
||||
|
||||
# Send the request body.
|
||||
async for data in request.stream():
|
||||
await self.send_data(stream_id, data, timeout)
|
||||
|
||||
# Finalize sending the request.
|
||||
await self.end_stream(stream_id, timeout)
|
||||
|
||||
# Start getting the response.
|
||||
while True:
|
||||
event = await self.receive_event(stream_id, timeout)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode())
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
body = self.body_iter(stream_id, timeout)
|
||||
on_close = functools.partial(self.response_closed, stream_id=stream_id)
|
||||
|
||||
response = Response(
|
||||
status_code=status_code,
|
||||
protocol="HTTP/2",
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=on_close,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
|
||||
def initiate_connection(self) -> None:
|
||||
self.h2_state.initiate_connection()
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write_no_block(data_to_send)
|
||||
self.initialized = True
|
||||
|
||||
async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
|
||||
stream_id = self.h2_state.get_next_available_stream_id()
|
||||
headers = [
|
||||
(b":method", request.method.encode()),
|
||||
(b":authority", request.url.hostname.encode()),
|
||||
(b":scheme", request.url.scheme.encode()),
|
||||
(b":path", request.url.full_path.encode()),
|
||||
] + request.headers.raw
|
||||
self.h2_state.send_headers(stream_id, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
return stream_id
|
||||
|
||||
async def send_data(
|
||||
self, stream_id: int, data: bytes, timeout: OptionalTimeout
|
||||
) -> None:
|
||||
self.h2_state.send_data(stream_id, data)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
|
||||
async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
|
||||
self.h2_state.end_stream(stream_id)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
|
||||
async def body_iter(
|
||||
self, stream_id: int, timeout: OptionalTimeout
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
while True:
|
||||
event = await self.receive_event(stream_id, timeout)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
yield event.data
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
break
|
||||
|
||||
async def receive_event(
|
||||
self, stream_id: int, timeout: OptionalTimeout
|
||||
) -> h2.events.Event:
|
||||
while not self.events[stream_id]:
|
||||
data = await self.reader.read(self.READ_NUM_BYTES, timeout)
|
||||
events = self.h2_state.receive_data(data)
|
||||
for event in events:
|
||||
if getattr(event, "stream_id", 0):
|
||||
self.events[event.stream_id].append(event)
|
||||
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
|
||||
return self.events[stream_id].pop(0)
|
||||
|
||||
async def response_closed(self, stream_id: int) -> None:
|
||||
del self.events[stream_id]
|
||||
|
||||
if not self.events and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
@ -16,12 +16,43 @@ class ReadTimeout(Timeout):
|
||||
"""
|
||||
|
||||
|
||||
class WriteTimeout(Timeout):
|
||||
"""
|
||||
Timeout while writing request data.
|
||||
"""
|
||||
|
||||
|
||||
class PoolTimeout(Timeout):
|
||||
"""
|
||||
Timeout while waiting to acquire a connection from the pool.
|
||||
"""
|
||||
|
||||
|
||||
class RedirectError(Exception):
|
||||
"""
|
||||
Base class for HTTP redirect errors.
|
||||
"""
|
||||
|
||||
|
||||
class TooManyRedirects(RedirectError):
|
||||
"""
|
||||
Too many redirects.
|
||||
"""
|
||||
|
||||
|
||||
class RedirectBodyUnavailable(RedirectError):
|
||||
"""
|
||||
Got a redirect response, but the request body was streaming, and is
|
||||
no longer available.
|
||||
"""
|
||||
|
||||
|
||||
class RedirectLoop(RedirectError):
|
||||
"""
|
||||
Infinite redirect loop.
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
"""
|
||||
Malformed HTTP.
|
||||
@ -40,3 +71,8 @@ class ResponseClosed(Exception):
|
||||
Attempted to read or stream response content, but the request has been
|
||||
closed without loading the body.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURL(Exception):
|
||||
"""
|
||||
"""
|
||||
|
||||
@ -1,163 +0,0 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import h11
|
||||
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import ConnectTimeout, ReadTimeout
|
||||
|
||||
H11Event = typing.Union[
|
||||
h11.Request,
|
||||
h11.Response,
|
||||
h11.InformationalResponse,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
h11.ConnectionClosed,
|
||||
]
|
||||
|
||||
|
||||
class HTTP11Connection(Client):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
assert request.url.origin == self.origin
|
||||
|
||||
if ssl is None:
|
||||
ssl = self.ssl
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
# Make the connection
|
||||
if self._reader is None:
|
||||
await self._connect(ssl, timeout)
|
||||
|
||||
# Start sending the request.
|
||||
method = request.method.encode()
|
||||
target = request.url.target
|
||||
headers = request.headers
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event)
|
||||
|
||||
# Send the request body.
|
||||
if request.is_streaming:
|
||||
async for data in request.stream():
|
||||
event = h11.Data(data=data)
|
||||
await self._send_event(event)
|
||||
elif request.body:
|
||||
event = h11.Data(data=request.body)
|
||||
await self._send_event(event)
|
||||
|
||||
# Finalize sending the request.
|
||||
event = h11.EndOfMessage()
|
||||
await self._send_event(event)
|
||||
|
||||
# Start getting the response.
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
event = await self._receive_event(timeout)
|
||||
assert isinstance(event, h11.Response)
|
||||
reason = event.reason.decode("latin1")
|
||||
status_code = event.status_code
|
||||
headers = event.headers
|
||||
body = self._body_iter(timeout)
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
reason=reason,
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self._release,
|
||||
)
|
||||
|
||||
async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
|
||||
hostname = self.origin.hostname
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
try:
|
||||
self._reader, self._writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
timeout.connect_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
|
||||
event = await self._receive_event(timeout)
|
||||
while isinstance(event, h11.Data):
|
||||
yield event.data
|
||||
event = await self._receive_event(timeout)
|
||||
assert isinstance(event, h11.EndOfMessage)
|
||||
|
||||
async def _send_event(self, event: H11Event) -> None:
|
||||
assert self._writer is not None
|
||||
|
||||
data = self._h11_state.send(event)
|
||||
self._writer.write(data)
|
||||
|
||||
async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
|
||||
assert self._reader is not None
|
||||
|
||||
event = self._h11_state.next_event()
|
||||
|
||||
while event is h11.NEED_DATA:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self._reader.read(2048), timeout.read_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout()
|
||||
self._h11_state.receive_data(data)
|
||||
event = self._h11_state.next_event()
|
||||
|
||||
return event
|
||||
|
||||
async def _release(self) -> None:
|
||||
assert self._writer is not None
|
||||
|
||||
if (
|
||||
self._h11_state.our_state is h11.DONE
|
||||
and self._h11_state.their_state is h11.DONE
|
||||
):
|
||||
self._h11_state.start_next_cycle()
|
||||
else:
|
||||
await self.close()
|
||||
|
||||
if self.on_release is not None:
|
||||
await self.on_release(self)
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
self._h11_state.send(event)
|
||||
except h11.ProtocolError:
|
||||
# If we're in some other state then it's a premature close,
|
||||
# and we'll end up in h11.ERROR.
|
||||
pass
|
||||
|
||||
if self._writer is not None:
|
||||
self._writer.close()
|
||||
67
httpcore/interfaces.py
Normal file
67
httpcore/interfaces.py
Normal file
@ -0,0 +1,67 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .config import TimeoutConfig
|
||||
from .models import URL, Request, Response
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
|
||||
class Adapter:
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
**options: typing.Any,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
self.prepare_request(request)
|
||||
response = await self.send(request, **options)
|
||||
return response
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "Adapter":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
|
||||
class BaseReader:
|
||||
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class BaseWriter:
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class BasePoolSemaphore:
|
||||
async def acquire(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
400
httpcore/models.py
Normal file
400
httpcore/models.py
Normal file
@ -0,0 +1,400 @@
|
||||
import http
|
||||
import typing
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .decoders import (
|
||||
ACCEPT_ENCODING,
|
||||
SUPPORTED_DECODERS,
|
||||
Decoder,
|
||||
IdentityDecoder,
|
||||
MultiDecoder,
|
||||
)
|
||||
from .exceptions import ResponseClosed, StreamConsumed
|
||||
from .utils import normalize_header_key, normalize_header_value
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(self, url: str = "") -> None:
|
||||
self.components = urlsplit(url)
|
||||
if not self.components.scheme:
|
||||
raise ValueError("No scheme included in URL.")
|
||||
if self.components.scheme not in ("http", "https"):
|
||||
raise ValueError('URL scheme must be "http" or "https".')
|
||||
if not self.components.hostname:
|
||||
raise ValueError("No hostname included in URL.")
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
|
||||
@property
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
port = self.components.port
|
||||
if port is None:
|
||||
return {"https": 443, "http": 80}[self.scheme]
|
||||
return port
|
||||
|
||||
@property
|
||||
def full_path(self) -> str:
|
||||
path = self.path or "/"
|
||||
query = self.query
|
||||
if query:
|
||||
return path + "?" + query
|
||||
return path
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.components.scheme == "https"
|
||||
|
||||
@property
|
||||
def origin(self) -> "Origin":
|
||||
return Origin(self)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, URL) and str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.components.geturl()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
url_str = str(self)
|
||||
return f"{class_name}({url_str!r})"
|
||||
|
||||
|
||||
class Origin:
|
||||
def __init__(self, url: typing.Union[str, URL]) -> None:
|
||||
if isinstance(url, str):
|
||||
url = URL(url)
|
||||
self.is_ssl = url.scheme == "https"
|
||||
self.hostname = url.hostname.lower()
|
||||
self.port = url.port
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.is_ssl == other.is_ssl
|
||||
and self.hostname == other.hostname
|
||||
and self.port == other.port
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.is_ssl, self.hostname, self.port))
|
||||
|
||||
|
||||
HeaderTypes = typing.Union[
|
||||
"Headers",
|
||||
typing.Dict[typing.AnyStr, typing.AnyStr],
|
||||
typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
|
||||
]
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
A case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: HeaderTypes = None) -> None:
|
||||
if headers is None:
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif isinstance(headers, Headers):
|
||||
self._list = list(headers.raw)
|
||||
elif isinstance(headers, dict):
|
||||
self._list = [
|
||||
(normalize_header_key(k), normalize_header_value(v))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
else:
|
||||
self._list = [
|
||||
(normalize_header_key(k), normalize_header_value(v)) for k, v in headers
|
||||
]
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
return self._list
|
||||
|
||||
def keys(self) -> typing.List[str]: # type: ignore
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self) -> typing.List[str]: # type: ignore
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
]
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
|
||||
for idx in reversed(found_indexes[1:]):
|
||||
del self._list[idx]
|
||||
|
||||
if found_indexes:
|
||||
idx = found_indexes[0]
|
||||
self._list[idx] = (set_key, set_value)
|
||||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
as_dict = dict(self.items())
|
||||
if len(as_dict) == len(self):
|
||||
return f"{class_name}({as_dict!r})"
|
||||
return f"{class_name}(raw={self.raw!r})"
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
):
|
||||
self.method = method.upper()
|
||||
self.url = URL(url) if isinstance(url, str) else url
|
||||
if isinstance(body, bytes):
|
||||
self.is_streaming = False
|
||||
self.body = body
|
||||
else:
|
||||
self.is_streaming = True
|
||||
self.body_aiter = body
|
||||
self.headers = Headers(headers)
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""
|
||||
Read and return the response content.
|
||||
"""
|
||||
if not hasattr(self, "body"):
|
||||
body = b""
|
||||
async for part in self.stream():
|
||||
body += part
|
||||
self.body = body
|
||||
return self.body
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_streaming:
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
elif self.body:
|
||||
yield self.body
|
||||
|
||||
def prepare(self) -> None:
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
has_host = "host" in self.headers
|
||||
has_content_length = (
|
||||
"content-length" in self.headers or "transfer-encoding" in self.headers
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_host:
|
||||
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
if not has_content_length:
|
||||
if self.is_streaming:
|
||||
auto_headers.append((b"transfer-encoding", b"chunked"))
|
||||
elif self.body:
|
||||
content_length = str(len(self.body)).encode()
|
||||
auto_headers.append((b"content-length", content_length))
|
||||
if not has_accept_encoding:
|
||||
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
|
||||
|
||||
for item in reversed(auto_headers):
|
||||
self.headers.raw.insert(0, item)
|
||||
|
||||
|
||||
class Response:
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
reason: typing.Optional[str] = None,
|
||||
protocol: typing.Optional[str] = None,
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
on_close: typing.Callable = None,
|
||||
request: Request = None,
|
||||
history: typing.List["Response"] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
if not reason:
|
||||
try:
|
||||
self.reason = http.HTTPStatus(status_code).phrase
|
||||
except ValueError as exc:
|
||||
self.reason = ""
|
||||
else:
|
||||
self.reason = reason
|
||||
self.protocol = protocol
|
||||
self.headers = Headers(headers)
|
||||
self.on_close = on_close
|
||||
self.is_closed = False
|
||||
self.is_streamed = False
|
||||
|
||||
decoders = [] # type: typing.List[Decoder]
|
||||
value = self.headers.get("content-encoding", "identity")
|
||||
for part in value.split(","):
|
||||
part = part.strip().lower()
|
||||
decoder_cls = SUPPORTED_DECODERS[part]
|
||||
decoders.append(decoder_cls())
|
||||
|
||||
if len(decoders) == 0:
|
||||
self.decoder = IdentityDecoder() # type: Decoder
|
||||
elif len(decoders) == 1:
|
||||
self.decoder = decoders[0]
|
||||
else:
|
||||
self.decoder = MultiDecoder(decoders)
|
||||
|
||||
if isinstance(body, bytes):
|
||||
self.is_closed = True
|
||||
self.body = self.decoder.decode(body) + self.decoder.flush()
|
||||
else:
|
||||
self.body_aiter = body
|
||||
|
||||
self.request = request
|
||||
self.history = [] if history is None else list(history)
|
||||
|
||||
@property
|
||||
def url(self) -> typing.Optional[URL]:
|
||||
return None if self.request is None else self.request.url
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""
|
||||
Read and return the response content.
|
||||
"""
|
||||
if not hasattr(self, "body"):
|
||||
body = b""
|
||||
async for part in self.stream():
|
||||
body += part
|
||||
self.body = body
|
||||
return self.body
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, and brotli encoded responses.
|
||||
"""
|
||||
if hasattr(self, "body"):
|
||||
yield self.body
|
||||
else:
|
||||
async for chunk in self.raw():
|
||||
yield self.decoder.decode(chunk)
|
||||
yield self.decoder.flush()
|
||||
|
||||
async def raw(self) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
if self.is_streamed:
|
||||
raise StreamConsumed()
|
||||
if self.is_closed:
|
||||
raise ResponseClosed()
|
||||
self.is_streamed = True
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
if not self.is_closed:
|
||||
self.is_closed = True
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return (
|
||||
self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
|
||||
)
|
||||
61
httpcore/status_codes.py
Normal file
61
httpcore/status_codes.py
Normal file
@ -0,0 +1,61 @@
|
||||
import enum
|
||||
|
||||
codes = enum.IntEnum(
|
||||
"StatusCode",
|
||||
[
|
||||
("continue", 100),
|
||||
("switching_protocols", 101),
|
||||
("processing", 102),
|
||||
("ok", 200),
|
||||
("created", 201),
|
||||
("accepted", 202),
|
||||
("non_authoritative_information", 203),
|
||||
("no_content", 204),
|
||||
("reset_content", 205),
|
||||
("partial_content", 206),
|
||||
("multi_status", 207),
|
||||
("already_reported", 208),
|
||||
("im_used", 226),
|
||||
("multiple_choices", 300),
|
||||
("moved_permanently", 301),
|
||||
("found", 302),
|
||||
("see_other", 303),
|
||||
("not_modified", 304),
|
||||
("use_proxy", 305),
|
||||
("temporary_redirect", 307),
|
||||
("permanent_redirect", 308),
|
||||
("bad_request", 400),
|
||||
("unauthorized", 401),
|
||||
("payment_required", 402),
|
||||
("forbidden", 403),
|
||||
("not_found", 404),
|
||||
("method_not_allowed", 405),
|
||||
("not_acceptable", 406),
|
||||
("proxy_authentication_required", 407),
|
||||
("request_timeout", 408),
|
||||
("conflict", 409),
|
||||
("gone", 410),
|
||||
("length_required", 411),
|
||||
("precondition_failed", 412),
|
||||
("request_entity_too_large", 413),
|
||||
("request_uri_too_long", 414),
|
||||
("unsupported_media_type", 415),
|
||||
("requested_range_not_satisfiable", 416),
|
||||
("expectation_failed", 417),
|
||||
("unprocessable_entity", 422),
|
||||
("locked", 423),
|
||||
("failed_dependency", 424),
|
||||
("precondition_required", 428),
|
||||
("too_many_requests", 429),
|
||||
("request_header_fields_too_large", 431),
|
||||
("unavailable_for_legal_reasons", 451),
|
||||
("internal_server_error", 500),
|
||||
("not_implemented", 501),
|
||||
("bad_gateway", 502),
|
||||
("service_unavailable", 503),
|
||||
("gateway_timeout", 504),
|
||||
("http_version_not_supported", 505),
|
||||
("insufficient_storage", 507),
|
||||
("network_authentication_required", 511),
|
||||
],
|
||||
)
|
||||
133
httpcore/streams.py
Normal file
133
httpcore/streams.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
The `Reader` and `Writer` classes here provide a lightweight layer over
|
||||
`asyncio.StreamReader` and `asyncio.StreamWriter`.
|
||||
|
||||
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
|
||||
|
||||
These classes help encapsulate the timeout logic, make it easier to unit-test
|
||||
protocols, and help keep the rest of the package more `async`/`await`
|
||||
based, and less strictly `asyncio`-specific.
|
||||
"""
|
||||
import asyncio
|
||||
import enum
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
|
||||
class Protocol(enum.Enum):
|
||||
HTTP_11 = 1
|
||||
HTTP_2 = 2
|
||||
|
||||
|
||||
class Reader(BaseReader):
|
||||
def __init__(
|
||||
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
|
||||
) -> None:
|
||||
self.stream_reader = stream_reader
|
||||
self.timeout = timeout
|
||||
|
||||
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.stream_reader.read(n), timeout.read_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class Writer(BaseWriter):
|
||||
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
|
||||
self.stream_writer = stream_writer
|
||||
self.timeout = timeout
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
self.stream_writer.write(data)
|
||||
|
||||
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
|
||||
if not data:
|
||||
return
|
||||
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
self.stream_writer.write(data)
|
||||
try:
|
||||
data = await asyncio.wait_for( # type: ignore
|
||||
self.stream_writer.drain(), timeout.write_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise WriteTimeout()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.stream_writer.close()
|
||||
|
||||
|
||||
class PoolSemaphore(BasePoolSemaphore):
|
||||
def __init__(self, limits: PoolLimits):
|
||||
self.limits = limits
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
max_connections = self.limits.hard_limit
|
||||
if max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
timeout = self.limits.pool_timeout
|
||||
try:
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
|
||||
def release(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
self.semaphore.release()
|
||||
|
||||
|
||||
async def connect(
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
) -> typing.Tuple[Reader, Writer, Protocol]:
|
||||
try:
|
||||
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
timeout.connect_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
ssl_object = stream_writer.get_extra_info("ssl_object")
|
||||
if ssl_object is None:
|
||||
ident = "http/1.1"
|
||||
else:
|
||||
ident = ssl_object.selected_alpn_protocol()
|
||||
if ident is None:
|
||||
ident = ssl_object.selected_npn_protocol()
|
||||
|
||||
reader = Reader(stream_reader=stream_reader, timeout=timeout)
|
||||
writer = Writer(stream_writer=stream_writer, timeout=timeout)
|
||||
protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
|
||||
|
||||
return (reader, writer, protocol)
|
||||
@ -3,8 +3,9 @@ import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .connectionpool import ConnectionPool
|
||||
from .datastructures import URL, Client, Response
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .interfaces import Adapter
|
||||
from .models import URL, Headers, Response
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
@ -21,7 +22,7 @@ class SyncResponse:
|
||||
return self._response.reason
|
||||
|
||||
@property
|
||||
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def headers(self) -> Headers:
|
||||
return self._response.headers
|
||||
|
||||
@property
|
||||
@ -44,8 +45,8 @@ class SyncResponse:
|
||||
|
||||
|
||||
class SyncClient:
|
||||
def __init__(self, client: Client):
|
||||
self._client = client
|
||||
def __init__(self, adapter: Adapter):
|
||||
self._client = adapter
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def request(
|
||||
@ -53,22 +54,12 @@ class SyncClient:
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
**options: typing.Any
|
||||
) -> SyncResponse:
|
||||
response = self._loop.run_until_complete(
|
||||
self._client.request(
|
||||
method,
|
||||
url,
|
||||
headers=headers,
|
||||
body=body,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
self._client.request(method, url, headers=headers, body=body, **options)
|
||||
)
|
||||
return SyncResponse(response, self._loop)
|
||||
|
||||
|
||||
71
httpcore/utils.py
Normal file
71
httpcore/utils.py
Normal file
@ -0,0 +1,71 @@
|
||||
import typing
|
||||
from urllib.parse import quote
|
||||
|
||||
from .exceptions import InvalidURL
|
||||
|
||||
# The unreserved URI characters (RFC 3986)
|
||||
UNRESERVED_SET = frozenset(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
|
||||
)
|
||||
|
||||
|
||||
def unquote_unreserved(uri: str) -> str:
|
||||
"""
|
||||
Un-escape any percent-escape sequences in a URI that are unreserved
|
||||
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
|
||||
"""
|
||||
parts = uri.split("%")
|
||||
for i in range(1, len(parts)):
|
||||
h = parts[i][0:2]
|
||||
if len(h) == 2 and h.isalnum():
|
||||
try:
|
||||
c = chr(int(h, 16))
|
||||
except ValueError:
|
||||
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
|
||||
|
||||
if c in UNRESERVED_SET:
|
||||
parts[i] = c + parts[i][2:]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def requote_uri(uri: str) -> str:
|
||||
"""
|
||||
Re-quote the given URI.
|
||||
|
||||
This function passes the given URI through an unquote/quote cycle to
|
||||
ensure that it is fully and consistently quoted.
|
||||
"""
|
||||
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
|
||||
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
|
||||
try:
|
||||
# Unquote only the unreserved characters
|
||||
# Then quote only illegal characters (do not quote reserved,
|
||||
# unreserved, or '%')
|
||||
return quote(unquote_unreserved(uri), safe=safe_with_percent)
|
||||
except InvalidURL:
|
||||
# We couldn't unquote the given URI, so let's try quoting it, but
|
||||
# there may be unquoted '%'s in the URI. We need to make sure they're
|
||||
# properly quoted so they do not cause issues elsewhere.
|
||||
return quote(uri, safe=safe_without_percent)
|
||||
|
||||
|
||||
def normalize_header_key(value: typing.AnyStr) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header key.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
return value.lower()
|
||||
return value.encode("latin-1").lower()
|
||||
|
||||
|
||||
def normalize_header_value(value: typing.AnyStr) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header value.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return value.encode("latin-1")
|
||||
@ -1,5 +1,6 @@
|
||||
certifi
|
||||
h11
|
||||
h2
|
||||
|
||||
# Optional
|
||||
brotlipy
|
||||
|
||||
2
setup.py
2
setup.py
@ -47,7 +47,7 @@ setup(
|
||||
author_email="tom@tomchristie.com",
|
||||
packages=get_packages("httpcore"),
|
||||
data_files=[("", ["LICENSE.md"])],
|
||||
install_requires=["h11", "certifi"],
|
||||
install_requires=["h11", "h2", "certifi"],
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Web Environment",
|
||||
|
||||
219
tests/adapters/test_redirects.py
Normal file
219
tests/adapters/test_redirects.py
Normal file
@ -0,0 +1,219 @@
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
from httpcore import (
|
||||
URL,
|
||||
Adapter,
|
||||
RedirectAdapter,
|
||||
RedirectBodyUnavailable,
|
||||
RedirectLoop,
|
||||
Request,
|
||||
Response,
|
||||
TooManyRedirects,
|
||||
codes,
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Adapter):
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, request: Request, **options) -> Response:
|
||||
if request.url.path == "/redirect_301":
|
||||
status_code = codes.moved_permanently
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_302":
|
||||
status_code = codes.found
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_303":
|
||||
status_code = codes.see_other
|
||||
headers = {"location": "https://example.org/"}
|
||||
return Response(status_code, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/relative_redirect":
|
||||
headers = {"location": "/"}
|
||||
return Response(codes.see_other, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/no_scheme_redirect":
|
||||
headers = {"location": "//example.org/"}
|
||||
return Response(codes.see_other, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/multiple_redirects":
|
||||
params = parse_qs(request.url.query)
|
||||
count = int(params.get("count", "0")[0])
|
||||
redirect_count = count - 1
|
||||
code = codes.see_other if count else codes.ok
|
||||
location = "/multiple_redirects"
|
||||
if redirect_count:
|
||||
location += "?count=" + str(redirect_count)
|
||||
headers = {"location": location} if count else {}
|
||||
return Response(code, headers=headers, request=request)
|
||||
|
||||
if request.url.path == "/redirect_loop":
|
||||
headers = {"location": "/redirect_loop"}
|
||||
return Response(codes.see_other, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/cross_domain":
|
||||
headers = {"location": "https://example.org/cross_domain_target"}
|
||||
return Response(codes.see_other, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/cross_domain_target":
|
||||
headers = dict(request.headers.items())
|
||||
body = json.dumps({"headers": headers}).encode()
|
||||
return Response(codes.ok, body=body, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_body":
|
||||
body = await request.read()
|
||||
headers = {"location": "/redirect_body_target"}
|
||||
return Response(codes.permanent_redirect, headers=headers, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_body_target":
|
||||
body = await request.read()
|
||||
body = json.dumps({"body": body.decode()}).encode()
|
||||
return Response(codes.ok, body=body, request=request)
|
||||
|
||||
return Response(codes.ok, body=b"Hello, world!", request=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_301():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_301")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_302():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_302")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_303():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/redirect_303")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disallow_redirects():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False)
|
||||
assert response.status_code == codes.see_other
|
||||
assert response.url == URL("https://example.org/redirect_303")
|
||||
assert len(response.history) == 0
|
||||
|
||||
response = await response.next()
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relative_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/relative_redirect")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_scheme_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/no_scheme_redirect")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fragment_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.org/relative_redirect#fragment"
|
||||
response = await client.request("GET", url)
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/#fragment")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_redirects():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.org/multiple_redirects?count=20"
|
||||
response = await client.request("GET", url)
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/multiple_redirects")
|
||||
assert len(response.history) == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_many_redirects():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
with pytest.raises(TooManyRedirects):
|
||||
await client.request("GET", "https://example.org/multiple_redirects?count=21")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_loop():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
with pytest.raises(RedirectLoop):
|
||||
await client.request("GET", "https://example.org/redirect_loop")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_domain_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.com/cross_domain"
|
||||
headers = {"Authorization": "abc"}
|
||||
response = await client.request("GET", url, headers=headers)
|
||||
data = json.loads(response.body.decode())
|
||||
assert response.url == URL("https://example.org/cross_domain_target")
|
||||
assert data == {"headers": {}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_domain_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.org/cross_domain"
|
||||
headers = {"Authorization": "abc"}
|
||||
response = await client.request("GET", url, headers=headers)
|
||||
data = json.loads(response.body.decode())
|
||||
assert response.url == URL("https://example.org/cross_domain_target")
|
||||
assert data == {"headers": {"authorization": "abc"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_body_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.org/redirect_body"
|
||||
body = b"Example request body"
|
||||
response = await client.request("POST", url, body=body)
|
||||
data = json.loads(response.body.decode())
|
||||
assert response.url == URL("https://example.org/redirect_body_target")
|
||||
assert data == {"body": "Example request body"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_redirect_streaming_body():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
url = "https://example.org/redirect_body"
|
||||
|
||||
async def body():
|
||||
yield b"Example request body"
|
||||
|
||||
with pytest.raises(RedirectBodyUnavailable):
|
||||
await client.request("POST", url, body=body())
|
||||
@ -5,16 +5,16 @@ import httpcore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request("GET", "http://127.0.0.1:8000/")
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request(
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@ -22,8 +22,8 @@ async def test_post(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert response.status_code == 200
|
||||
assert not hasattr(response, "body")
|
||||
body = await response.read()
|
||||
@ -36,8 +36,8 @@ async def test_stream_request(server):
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request(
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=hello_world()
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
@ -13,13 +13,15 @@ def test_timeout_repr():
|
||||
timeout = httpcore.TimeoutConfig(read_timeout=5.0)
|
||||
assert (
|
||||
repr(timeout)
|
||||
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
|
||||
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)"
|
||||
)
|
||||
|
||||
|
||||
def test_limits_repr():
|
||||
limits = httpcore.PoolLimits(hard_limit=100)
|
||||
assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
|
||||
assert (
|
||||
repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)"
|
||||
)
|
||||
|
||||
|
||||
def test_ssl_eq():
|
||||
@ -35,24 +37,3 @@ def test_timeout_eq():
|
||||
def test_limits_eq():
|
||||
limits = httpcore.PoolLimits(hard_limit=100)
|
||||
assert limits == httpcore.PoolLimits(hard_limit=100)
|
||||
|
||||
|
||||
def test_ssl_hash():
|
||||
cache = {}
|
||||
ssl = httpcore.SSLConfig(verify=False)
|
||||
cache[ssl] = "example"
|
||||
assert cache[httpcore.SSLConfig(verify=False)] == "example"
|
||||
|
||||
|
||||
def test_timeout_hash():
|
||||
cache = {}
|
||||
timeout = httpcore.TimeoutConfig(timeout=5.0)
|
||||
cache[timeout] = "example"
|
||||
assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example"
|
||||
|
||||
|
||||
def test_limits_hash():
|
||||
cache = {}
|
||||
limits = httpcore.PoolLimits(hard_limit=100)
|
||||
cache[limits] = "example"
|
||||
assert cache[httpcore.PoolLimits(hard_limit=100)] == "example"
|
||||
|
||||
@ -10,12 +10,12 @@ async def test_keepalive_connections(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -25,12 +25,12 @@ async def test_differing_connection_keys(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 2
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -42,12 +42,12 @@ async def test_soft_limit(server):
|
||||
|
||||
async with httpcore.ConnectionPool(limits=limits) as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response.read()
|
||||
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 2
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 2
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response_b.read()
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
await response_a.read()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 2
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -97,8 +97,8 @@ async def test_close_connections(server):
|
||||
headers = [(b"connection", b"close")]
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -110,8 +110,8 @@ async def test_standard_response_close(server):
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
await response.read()
|
||||
await response.close()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -122,5 +122,5 @@ async def test_premature_response_close(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
await response.close()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
@ -5,7 +5,7 @@ import httpcore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(server):
|
||||
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
|
||||
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"Hello, world!"
|
||||
@ -13,7 +13,7 @@ async def test_get(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(server):
|
||||
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
|
||||
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
|
||||
response = await http.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
|
||||
)
|
||||
|
||||
116
tests/test_http2.py
Normal file
116
tests/test_http2.py
Normal file
@ -0,0 +1,116 @@
|
||||
import json
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import pytest
|
||||
|
||||
import httpcore
|
||||
|
||||
|
||||
class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
|
||||
"""
|
||||
This class exposes Reader and Writer style interfaces
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
config = h2.config.H2Configuration(client_side=False)
|
||||
self.conn = h2.connection.H2Connection(config=config)
|
||||
self.buffer = b""
|
||||
self.requests = {}
|
||||
|
||||
# BaseReader interface
|
||||
|
||||
async def read(self, n, timeout) -> bytes:
|
||||
send, self.buffer = self.buffer[:n], self.buffer[n:]
|
||||
return send
|
||||
|
||||
# BaseWriter interface
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
events = self.conn.receive_data(data)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
self.stream_complete(event.stream_id)
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
self.write_no_block(data)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
# Server implementation
|
||||
|
||||
def request_received(self, headers, stream_id):
|
||||
if stream_id not in self.requests:
|
||||
self.requests[stream_id] = []
|
||||
self.requests[stream_id].append({"headers": headers, "data": b""})
|
||||
|
||||
def receive_data(self, data, stream_id):
|
||||
self.requests[stream_id][-1]["data"] += data
|
||||
|
||||
def stream_complete(self, stream_id):
|
||||
request = self.requests[stream_id].pop(0)
|
||||
if not self.requests[stream_id]:
|
||||
del self.requests[stream_id]
|
||||
|
||||
request_headers = dict(request["headers"])
|
||||
request_data = request["data"]
|
||||
|
||||
response_body = json.dumps(
|
||||
{
|
||||
"method": request_headers[b":method"].decode(),
|
||||
"path": request_headers[b":path"].decode(),
|
||||
"body": request_data.decode(),
|
||||
}
|
||||
).encode()
|
||||
|
||||
response_headers = ((b":status", b"200"),)
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.conn.send_data(stream_id, response_body, end_stream=True)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_get_request():
|
||||
server = MockServer()
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response = await conn.request("GET", "http://example.org")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_post_request():
|
||||
server = MockServer()
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response = await conn.request("POST", "http://example.org", body=b"<data>")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": "<data>",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_multiple_requests():
|
||||
server = MockServer()
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response_1 = await conn.request("GET", "http://example.org/1")
|
||||
response_2 = await conn.request("GET", "http://example.org/2")
|
||||
response_3 = await conn.request("GET", "http://example.org/3")
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""}
|
||||
|
||||
assert response_2.status_code == 200
|
||||
assert json.loads(response_2.body) == {"method": "GET", "path": "/2", "body": ""}
|
||||
|
||||
assert response_3.status_code == 200
|
||||
assert json.loads(response_3.body) == {"method": "GET", "path": "/3", "body": ""}
|
||||
@ -5,19 +5,22 @@ import httpcore
|
||||
|
||||
def test_host_header():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
|
||||
)
|
||||
|
||||
|
||||
def test_content_length_header():
|
||||
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_transfer_encoding_header():
|
||||
@ -27,31 +30,34 @@ def test_transfer_encoding_header():
|
||||
body = streaming_body(b"test 123")
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_override_host_header():
|
||||
headers = [(b"host", b"1.2.3.4:80")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
assert request.headers == [
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"host", b"1.2.3.4:80"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
|
||||
)
|
||||
|
||||
|
||||
def test_override_accept_encoding_header():
|
||||
headers = [(b"accept-encoding", b"identity")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"identity"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
|
||||
)
|
||||
|
||||
|
||||
def test_override_content_length_header():
|
||||
@ -62,23 +68,26 @@ def test_override_content_length_header():
|
||||
headers = [(b"content-length", b"8")]
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-length", b"8"),
|
||||
]
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-length", b"8"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_url():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
assert request.url.scheme == "http"
|
||||
assert request.url.port == 80
|
||||
assert request.url.target == "/"
|
||||
assert request.url.full_path == "/"
|
||||
|
||||
request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
|
||||
assert request.url.scheme == "https"
|
||||
assert request.url.port == 443
|
||||
assert request.url.target == "/abc?foo=bar"
|
||||
assert request.url.full_path == "/abc?foo=bar"
|
||||
|
||||
|
||||
def test_invalid_urls():
|
||||
|
||||
@ -24,10 +24,9 @@ async def test_connect_timeout(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_timeout(server):
|
||||
timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
|
||||
limits = httpcore.PoolLimits(hard_limit=1)
|
||||
limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001)
|
||||
|
||||
async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
|
||||
async with httpcore.ConnectionPool(limits=limits) as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
|
||||
with pytest.raises(httpcore.PoolTimeout):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user