Merge pull request #23 from encode/adapters

Adapter refactoring
This commit is contained in:
Tom Christie 2019-04-29 16:55:42 +01:00 committed by GitHub
commit 696c054968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 2214 additions and 724 deletions

90
API.md Normal file
View File

@ -0,0 +1,90 @@
Client(...)
.request(method, url, ...)
.get(url, ...)
.options(url, ...)
.head(url, ...)
.post(url, ...)
.put(url, ...)
.patch(url, ...)
.delete(url, ...)
.prepare_request(request)
.send(request, ...)
.close()
Adapter()
.prepare_request(request)
.send(request)
.close()
+ EnvironmentAdapter
+ RedirectAdapter
+ CookieAdapter
+ AuthAdapter
+ ConnectionPool
+ HTTPConnection
+ HTTP11Connection
+ HTTP2Connection
Response(...)
.status_code - int
.reason_phrase - str
.protocol - "HTTP/2" or "HTTP/1.1"
.url - URL
.headers - Headers
.content - bytes
.text - str
.encoding - str
.json() - Any
.read() - bytes
.stream() - bytes iterator
.raw() - bytes iterator
.close() - None
.is_redirect - bool
.request - Request
.cookies - Cookies
.history - List[Response]
.raise_for_status()
.next()
Request(...)
.method
.url
.headers
...
Headers
URL
Origin
Cookies
# Sync
SyncClient
SyncResponse
SyncRequest
SyncAdapter
SSE
HTTP/2 server push support
Concurrency

View File

@ -1,16 +1,26 @@
from .adapters.redirects import RedirectAdapter
from .client import Client
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .connectionpool import ConnectionPool
from .datastructures import URL, Origin, Request, Response
from .dispatch.connection import HTTPConnection
from .dispatch.connection_pool import ConnectionPool
from .dispatch.http2 import HTTP2Connection
from .dispatch.http11 import HTTP11Connection
from .exceptions import (
ConnectTimeout,
PoolTimeout,
ProtocolError,
ReadTimeout,
RedirectBodyUnavailable,
RedirectLoop,
ResponseClosed,
StreamConsumed,
Timeout,
TooManyRedirects,
)
from .http11 import HTTP11Connection
from .interfaces import Adapter
from .models import URL, Headers, Origin, Request, Response
from .status_codes import codes
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
from .sync import SyncClient, SyncConnectionPool
__version__ = "0.2.1"

View File

@ -0,0 +1,4 @@
"""
Adapter classes layer additional behavior over the raw dispatching of the
HTTP request/response.
"""

View File

@ -0,0 +1,18 @@
import typing
from ..interfaces import Adapter
from ..models import Request, Response
class AuthenticationAdapter(Adapter):
def __init__(self, dispatch: Adapter):
self.dispatch = dispatch
def prepare_request(self, request: Request) -> None:
self.dispatch.prepare_request(request)
async def send(self, request: Request, **options: typing.Any) -> Response:
return await self.dispatch.send(request, **options)
async def close(self) -> None:
await self.dispatch.close()

View File

@ -0,0 +1,18 @@
import typing
from ..interfaces import Adapter
from ..models import Request, Response
class CookieAdapter(Adapter):
def __init__(self, dispatch: Adapter):
self.dispatch = dispatch
def prepare_request(self, request: Request) -> None:
self.dispatch.prepare_request(request)
async def send(self, request: Request, **options: typing.Any) -> Response:
return await self.dispatch.send(request, **options)
async def close(self) -> None:
await self.dispatch.close()

View File

@ -0,0 +1,27 @@
import typing
from ..interfaces import Adapter
from ..models import Request, Response
class EnvironmentAdapter(Adapter):
def __init__(self, dispatch: Adapter, trust_env: bool = True):
self.dispatch = dispatch
self.trust_env = trust_env
def prepare_request(self, request: Request) -> None:
self.dispatch.prepare_request(request)
async def send(self, request: Request, **options: typing.Any) -> Response:
if self.trust_env:
self.merge_environment_options(options)
return await self.dispatch.send(request, **options)
async def close(self) -> None:
await self.dispatch.close()
def merge_environment_options(self, options: dict) -> None:
"""
Add environment options.
"""
#  TODO

View File

@ -0,0 +1,125 @@
import functools
import typing
from urllib.parse import urljoin, urlparse
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
from ..models import URL, Headers, Request, Response
from ..status_codes import codes
from ..utils import requote_uri
class RedirectAdapter(Adapter):
def __init__(self, dispatch: Adapter, max_redirects: int = DEFAULT_MAX_REDIRECTS):
self.dispatch = dispatch
self.max_redirects = max_redirects
def prepare_request(self, request: Request) -> None:
self.dispatch.prepare_request(request)
async def send(self, request: Request, **options: typing.Any) -> Response:
allow_redirects = options.pop("allow_redirects", True)
history = options.pop("history", []) # type: typing.List[Response]
seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL]
seen_urls.add(request.url)
while True:
response = await self.dispatch.send(request, **options)
response.history = list(history)
if not response.is_redirect:
break
history.append(response)
request = self.build_redirect_request(request, response)
if not allow_redirects:
next_options = dict(options)
next_options["seen_urls"] = seen_urls
next_options["history"] = history
response.next = functools.partial(self.send, request=request, **next_options)
break
if len(history) > self.max_redirects:
raise TooManyRedirects()
if request.url in seen_urls:
raise RedirectLoop()
seen_urls.add(request.url)
return response
async def close(self) -> None:
await self.dispatch.close()
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url)
body = self.redirect_body(request, method)
return Request(method=method, url=url, headers=headers, body=body)
def redirect_method(self, request: Request, response: Response) -> str:
"""
When being redirected we may want to change the method of the request
based on certain specs or browser behavior.
"""
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes.see_other and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
# Turn 302s into GETs.
if response.status_code == codes.found and method != "HEAD":
method = "GET"
# If a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in 'requests' issue 1704.
if response.status_code == codes.moved_permanently and method == "POST":
method = "GET"
return method
def redirect_url(self, request: Request, response: Response) -> URL:
"""
Return the URL for the redirect to follow.
"""
location = response.headers["Location"]
# Handle redirection without scheme (see: RFC 1808 Section 4)
if location.startswith("//"):
location = f"{request.url.scheme}:{location}"
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
parsed = urlparse(location)
if parsed.fragment == "" and request.url.fragment:
parsed = parsed._replace(fragment=request.url.fragment)
url = parsed.geturl()
# Facilitate relative 'location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
# Compliant with RFC3986, we percent encode the url.
if not parsed.netloc:
url = urljoin(str(request.url), requote_uri(url))
else:
url = requote_uri(url)
return URL(url)
def redirect_headers(self, request: Request, url: URL) -> Headers:
"""
Strip Authorization headers when responses are redirected away from
the origin.
"""
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
return headers
def redirect_body(self, request: Request, method: str) -> bytes:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return b""
if request.is_streaming:
raise RedirectBodyUnavailable()
return request.body

124
httpcore/client.py Normal file
View File

@ -0,0 +1,124 @@
import typing
from types import TracebackType
from .adapters.authentication import AuthenticationAdapter
from .adapters.cookies import CookieAdapter
from .adapters.environment import EnvironmentAdapter
from .adapters.redirects import RedirectAdapter
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from .dispatch.connection_pool import ConnectionPool
from .models import URL, Request, Response
class Client:
def __init__(
self,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
):
connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits)
cookie_adapter = CookieAdapter(dispatch=connection_pool)
auth_adapter = AuthenticationAdapter(dispatch=cookie_adapter)
redirect_adapter = RedirectAdapter(
dispatch=auth_adapter, max_redirects=max_redirects
)
self.adapter = EnvironmentAdapter(dispatch=redirect_adapter)
async def request(
self,
method: str,
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
allow_redirects: bool = True,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
request = Request(method, url, headers=headers, body=body)
self.prepare_request(request)
response = await self.send(
request,
stream=stream,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
)
return response
async def get(
self,
url: typing.Union[str, URL],
*,
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
return await self.request(
"GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout
)
async def post(
self,
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
return await self.request(
"POST",
url,
body=body,
headers=headers,
stream=stream,
ssl=ssl,
timeout=timeout,
)
def prepare_request(self, request: Request) -> None:
self.adapter.prepare_request(request)
async def send(
self,
request: Request,
*,
stream: bool = False,
allow_redirects: bool = True,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
options = {"stream": stream} # type: typing.Dict[str, typing.Any]
if ssl is not None:
options["ssl"] = ssl
if timeout is not None:
options["timeout"] = timeout
return await self.adapter.send(request, **options)
async def close(self) -> None:
await self.adapter.close()
async def __aenter__(self) -> "Client":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()

View File

@ -27,10 +27,6 @@ class SSLConfig:
and self.verify == other.verify
)
def __hash__(self) -> int:
as_tuple = (self.cert, self.verify)
return hash(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(cert={self.cert}, verify={self.verify})"
@ -73,7 +69,23 @@ class SSLConfig:
"invalid path: {}".format(self.verify)
)
context = ssl.create_default_context()
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
# RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST
# support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the
# blacklist defined in this section allows only the AES GCM and ChaCha20
# cipher suites with ephemeral key negotiation.
context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20")
if ssl.HAS_ALPN:
context.set_alpn_protocols(["h2", "http/1.1"])
if ssl.HAS_NPN:
context.set_npn_protocols(["h2", "http/1.1"])
if os.path.isfile(ca_bundle_path):
context.load_verify_locations(cafile=ca_bundle_path)
elif os.path.isdir(ca_bundle_path):
@ -99,39 +111,35 @@ class TimeoutConfig:
*,
connect_timeout: float = None,
read_timeout: float = None,
pool_timeout: float = None,
write_timeout: float = None,
):
if timeout is not None:
# Specified as a single timeout value
assert connect_timeout is None
assert read_timeout is None
assert pool_timeout is None
assert write_timeout is None
connect_timeout = timeout
read_timeout = timeout
pool_timeout = timeout
write_timeout = timeout
self.timeout = timeout
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.pool_timeout = pool_timeout
self.write_timeout = write_timeout
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.connect_timeout == other.connect_timeout
and self.read_timeout == other.read_timeout
and self.pool_timeout == other.pool_timeout
and self.write_timeout == other.write_timeout
)
def __hash__(self) -> int:
as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
return hash(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
if self.timeout is not None:
return f"{class_name}(timeout={self.timeout})"
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
class PoolLimits:
@ -142,31 +150,29 @@ class PoolLimits:
def __init__(
self,
*,
soft_limit: typing.Optional[int] = None,
hard_limit: typing.Optional[int] = None,
soft_limit: int = None,
hard_limit: int = None,
pool_timeout: float = None,
):
self.soft_limit = soft_limit
self.hard_limit = hard_limit
self.pool_timeout = pool_timeout
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.soft_limit == other.soft_limit
and self.hard_limit == other.hard_limit
and self.pool_timeout == other.pool_timeout
)
def __hash__(self) -> int:
as_tuple = (self.soft_limit, self.hard_limit)
return hash(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
return (
f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
)
return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})"
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
DEFAULT_CA_BUNDLE_PATH = certifi.where()
DEFAULT_MAX_REDIRECTS = 20

View File

@ -1,132 +0,0 @@
import asyncio
import typing
from .config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from .datastructures import Client, Origin, Request, Response
from .exceptions import PoolTimeout
from .http11 import HTTP11Connection
class ConnectionPool(Client):
def __init__(
self,
*,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
):
self.ssl = ssl
self.timeout = timeout
self.limits = limits
self.is_closed = False
self.num_active_connections = 0
self.num_keepalive_connections = 0
self._keepalive_connections = (
{}
) # type: typing.Dict[Origin, typing.List[HTTP11Connection]]
self._max_connections = ConnectionSemaphore(
max_connections=self.limits.hard_limit
)
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
connection = await self.acquire_connection(request.url.origin, timeout=timeout)
response = await connection.send(request, ssl=ssl, timeout=timeout)
return response
@property
def num_connections(self) -> int:
return self.num_active_connections + self.num_keepalive_connections
async def acquire_connection(
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
) -> HTTP11Connection:
try:
connection = self._keepalive_connections[origin].pop()
if not self._keepalive_connections[origin]:
del self._keepalive_connections[origin]
self.num_keepalive_connections -= 1
self.num_active_connections += 1
except (KeyError, IndexError):
if timeout is None:
pool_timeout = self.timeout.pool_timeout
else:
pool_timeout = timeout.pool_timeout
try:
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
connection = HTTP11Connection(
origin,
ssl=self.ssl,
timeout=self.timeout,
on_release=self.release_connection,
)
self.num_active_connections += 1
return connection
async def release_connection(self, connection: HTTP11Connection) -> None:
if connection.is_closed:
self._max_connections.release()
self.num_active_connections -= 1
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
self._max_connections.release()
self.num_active_connections -= 1
await connection.close()
else:
self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
self._keepalive_connections[connection.origin].append(connection)
except KeyError:
self._keepalive_connections[connection.origin] = [connection]
async def close(self) -> None:
self.is_closed = True
all_connections = []
for connections in self._keepalive_connections.values():
all_connections.extend(list(connections))
self._keepalive_connections.clear()
for connection in all_connections:
await connection.close()
class ConnectionSemaphore:
def __init__(self, max_connections: int = None):
self.max_connections = max_connections
@property
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
if not hasattr(self, "_semaphore"):
if self.max_connections is None:
self._semaphore = None
else:
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
return self._semaphore
async def acquire(self) -> None:
if self.semaphore is not None:
await self.semaphore.acquire()
def release(self) -> None:
if self.semaphore is not None:
self.semaphore.release()

View File

@ -1,280 +0,0 @@
import http
import typing
from types import TracebackType
from urllib.parse import urlsplit
from .config import SSLConfig, TimeoutConfig
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
Decoder,
IdentityDecoder,
MultiDecoder,
)
from .exceptions import ResponseClosed, StreamConsumed
class URL:
def __init__(self, url: str = "") -> None:
self.components = urlsplit(url)
if not self.components.scheme:
raise ValueError("No scheme included in URL.")
if self.components.scheme not in ("http", "https"):
raise ValueError('URL scheme must be "http" or "https".')
if not self.components.hostname:
raise ValueError("No hostname included in URL.")
@property
def scheme(self) -> str:
return self.components.scheme
@property
def netloc(self) -> str:
return self.components.netloc
@property
def path(self) -> str:
return self.components.path
@property
def query(self) -> str:
return self.components.query
@property
def hostname(self) -> str:
return self.components.hostname
@property
def port(self) -> int:
port = self.components.port
if port is None:
return {"https": 443, "http": 80}[self.scheme]
return port
@property
def target(self) -> str:
path = self.path or "/"
query = self.query
if query:
return path + "?" + query
return path
@property
def is_secure(self) -> bool:
return self.components.scheme == "https"
@property
def origin(self) -> "Origin":
return Origin(self)
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
if isinstance(url, str):
url = URL(url)
self.is_ssl = url.scheme == "https"
self.hostname = url.hostname.lower()
self.port = url.port
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.is_ssl == other.is_ssl
and self.hostname == other.hostname
and self.port == other.port
)
def __hash__(self) -> int:
return hash((self.is_ssl, self.hostname, self.port))
class Request:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
self.url = URL(url) if isinstance(url, str) else url
self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
self.headers = self._auto_headers() + self.headers
def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
has_host = False
has_content_length = False
has_accept_encoding = False
for header, value in self.headers:
header = header.strip().lower()
if header == b"host":
has_host = True
elif header in (b"content-length", b"transfer-encoding"):
has_content_length = True
elif header == b"accept-encoding":
has_accept_encoding = True
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
if not has_host:
headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if self.is_streaming:
headers.append((b"transfer-encoding", b"chunked"))
elif self.body:
content_length = str(len(self.body)).encode()
headers.append((b"content-length", content_length))
if not has_accept_encoding:
headers.append((b"accept-encoding", ACCEPT_ENCODING))
return headers
async def stream(self) -> typing.AsyncIterator[bytes]:
assert self.is_streaming
async for part in self.body_aiter:
yield part
class Response:
def __init__(
self,
status_code: int,
*,
reason: typing.Optional[str] = None,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
):
self.status_code = status_code
if not reason:
try:
self.reason = http.HTTPStatus(status_code).phrase
except ValueError as exc:
self.reason = ""
else:
self.reason = reason
self.headers = list(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
decoders = [] # type: typing.List[Decoder]
for header, value in self.headers:
if header.strip().lower() == b"content-encoding":
for part in value.split(b","):
part = part.strip().lower()
decoder_cls = SUPPORTED_DECODERS[part]
decoders.append(decoder_cls())
if len(decoders) == 0:
self.decoder = IdentityDecoder() # type: Decoder
elif len(decoders) == 1:
self.decoder = decoders[0]
else:
self.decoder = MultiDecoder(decoders)
if isinstance(body, bytes):
self.is_closed = True
self.body = self.decoder.decode(body) + self.decoder.flush()
else:
self.body_aiter = body
async def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "body"):
body = b""
async for part in self.stream():
body += part
self.body = body
return self.body
async def stream(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
else:
async for chunk in self.raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
async def raw(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if self.is_streamed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_streamed = True
async for part in self.body_aiter:
yield part
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
await self.on_close()
class Client:
async def request(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
stream: bool = False,
) -> Response:
request = Request(method, url, headers=headers, body=body)
response = await self.send(request, ssl=ssl, timeout=timeout)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
raise NotImplementedError() # pragma: nocover
async def close(self) -> None:
raise NotImplementedError() # pragma: nocover
async def __aenter__(self) -> "Client":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()

View File

@ -110,17 +110,17 @@ class MultiDecoder(Decoder):
SUPPORTED_DECODERS = {
b"identity": IdentityDecoder,
b"deflate": DeflateDecoder,
b"gzip": GZipDecoder,
b"br": BrotliDecoder,
"identity": IdentityDecoder,
"deflate": DeflateDecoder,
"gzip": GZipDecoder,
"br": BrotliDecoder,
}
if brotli is None:
SUPPORTED_DECODERS.pop(b"br") # pragma: nocover
SUPPORTED_DECODERS.pop("br") # pragma: nocover
ACCEPT_ENCODING = b", ".join(
[key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
ACCEPT_ENCODING = ", ".join(
[key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
)

View File

@ -0,0 +1,4 @@
"""
Dispatch classes handle the raw network connections and the implementation
details of making the HTTP request and receiving the response.
"""

View File

@ -0,0 +1,93 @@
import functools
import typing
import h2.connection
import h11
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
SSLConfig,
TimeoutConfig,
)
from ..exceptions import ConnectTimeout
from ..interfaces import Adapter
from ..models import Origin, Request, Response
from ..streams import Protocol, connect
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
# Callback signature: async def callback(conn: HTTPConnection) -> None
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
class HTTPConnection(Adapter):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
release_func: typing.Optional[ReleaseCallback] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
def prepare_request(self, request: Request) -> None:
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
if self.h11_connection is None and self.h2_connection is None:
await self.connect(**options)
if self.h2_connection is not None:
response = await self.h2_connection.send(request, **options)
else:
assert self.h11_connection is not None
response = await self.h11_connection.send(request, **options)
return response
async def connect(self, **options: typing.Any) -> None:
ssl = options.get("ssl", self.ssl)
timeout = options.get("timeout", self.timeout)
assert isinstance(ssl, SSLConfig)
assert isinstance(timeout, TimeoutConfig)
hostname = self.origin.hostname
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
if self.release_func is None:
on_release = None
else:
on_release = functools.partial(self.release_func, self)
reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
else:
self.h11_connection = HTTP11Connection(
reader, writer, on_release=on_release
)
async def close(self) -> None:
if self.h2_connection is not None:
await self.h2_connection.close()
elif self.h11_connection is not None:
await self.h11_connection.close()
@property
def is_http2(self) -> bool:
return self.h2_connection is not None
@property
def is_closed(self) -> bool:
if self.h2_connection is not None:
return self.h2_connection.is_closed
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed

View File

@ -0,0 +1,158 @@
import collections.abc
import typing
from ..config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
from ..interfaces import Adapter
from ..models import Origin, Request, Response
from ..streams import PoolSemaphore
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
class ConnectionStore(collections.abc.Sequence):
"""
We need to maintain collections of connections in a way that allows us to:
* Lookup connections by origin.
* Iterate over connections by insertion time.
* Return the total number of connections.
"""
def __init__(self) -> None:
self.all = {} # type: typing.Dict[HTTPConnection, float]
self.by_origin = (
{}
) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
def pop_by_origin(
self, origin: Origin, http2_only: bool = False
) -> typing.Optional[HTTPConnection]:
try:
connections = self.by_origin[origin]
except KeyError:
return None
connection = next(reversed(list(connections.keys())))
if http2_only and not connection.is_http2:
return None
del connections[connection]
if not connections:
del self.by_origin[origin]
del self.all[connection]
return connection
def add(self, connection: HTTPConnection) -> None:
self.all[connection] = 0.0
try:
self.by_origin[connection.origin][connection] = 0.0
except KeyError:
self.by_origin[connection.origin] = {connection: 0.0}
def remove(self, connection: HTTPConnection) -> None:
del self.all[connection]
del self.by_origin[connection.origin][connection]
if not self.by_origin[connection.origin]:
del self.by_origin[connection.origin]
def clear(self) -> None:
self.all.clear()
self.by_origin.clear()
def __iter__(self) -> typing.Iterator[HTTPConnection]:
return iter(self.all.keys())
def __getitem__(self, key: typing.Any) -> typing.Any:
if key in self.all:
return key
return None
def __len__(self) -> int:
return len(self.all)
class ConnectionPool(Adapter):
def __init__(
self,
*,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
):
self.ssl = ssl
self.timeout = timeout
self.limits = limits
self.is_closed = False
self.max_connections = PoolSemaphore(limits)
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
@property
def num_connections(self) -> int:
return len(self.keepalive_connections) + len(self.active_connections)
def prepare_request(self, request: Request) -> None:
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
connection = await self.acquire_connection(request.url.origin)
try:
response = await connection.send(request, **options)
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()
raise exc
return response
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)
if connection is None:
await self.max_connections.acquire()
connection = HTTPConnection(
origin,
ssl=self.ssl,
timeout=self.timeout,
release_func=self.release_connection,
)
self.active_connections.add(connection)
return connection
async def release_connection(self, connection: HTTPConnection) -> None:
if connection.is_closed:
self.active_connections.remove(connection)
self.max_connections.release()
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
self.active_connections.remove(connection)
self.max_connections.release()
await connection.close()
else:
self.active_connections.remove(connection)
self.keepalive_connections.add(connection)
async def close(self) -> None:
self.is_closed = True
connections = list(self.keepalive_connections)
self.keepalive_connections.clear()
for connection in connections:
await connection.close()

148
httpcore/dispatch/http11.py Normal file
View File

@ -0,0 +1,148 @@
import typing
import h11
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
SSLConfig,
TimeoutConfig,
)
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import Adapter
from ..models import Request, Response
from ..streams import BaseReader, BaseWriter
H11Event = typing.Union[
h11.Request,
h11.Response,
h11.InformationalResponse,
h11.Data,
h11.EndOfMessage,
h11.ConnectionClosed,
]
OptionalTimeout = typing.Optional[TimeoutConfig]
# Callback signature: async def callback() -> None
# In practice the callback will be a functools partial, which binds
# the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
class HTTP11Connection(Adapter):
READ_NUM_BYTES = 4096
def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.reader = reader
self.writer = writer
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
def prepare_request(self, request: Request) -> None:
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
stream = options.get("stream", False)
assert timeout is None or isinstance(timeout, TimeoutConfig)
#  Start sending the request.
method = request.method.encode()
target = request.url.full_path
headers = request.headers.raw
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)
# Send the request body.
async for data in request.stream():
event = h11.Data(data=data)
await self._send_event(event, timeout)
# Finalize sending the request.
event = h11.EndOfMessage()
await self._send_event(event, timeout)
# Start getting the response.
event = await self._receive_event(timeout)
if isinstance(event, h11.InformationalResponse):
event = await self._receive_event(timeout)
assert isinstance(event, h11.Response)
reason = event.reason.decode("latin1")
status_code = event.status_code
headers = event.headers
body = self._body_iter(timeout)
response = Response(
status_code=status_code,
reason=reason,
protocol="HTTP/1.1",
headers=headers,
body=body,
on_close=self.response_closed,
request=request,
)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def close(self) -> None:
event = h11.ConnectionClosed()
try:
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
self.h11_state.send(event)
except h11.ProtocolError:
# If we're in some other state then it's a premature close,
# and we'll end up in h11.ERROR.
pass
await self.writer.close()
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event(timeout)
assert isinstance(event, h11.EndOfMessage)
async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
data = self.h11_state.send(event)
await self.writer.write(data, timeout)
async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
event = self.h11_state.next_event()
while event is h11.NEED_DATA:
data = await self.reader.read(self.READ_NUM_BYTES, timeout)
self.h11_state.receive_data(data)
event = self.h11_state.next_event()
return event
async def response_closed(self) -> None:
if (
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
):
self.h11_state.start_next_cycle()
else:
await self.close()
if self.on_release is not None:
await self.on_release()
@property
def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)

156
httpcore/dispatch/http2.py Normal file
View File

@ -0,0 +1,156 @@
import functools
import typing
import h2.connection
import h2.events
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
SSLConfig,
TimeoutConfig,
)
from ..exceptions import ConnectTimeout, ReadTimeout
from ..interfaces import Adapter
from ..models import Request, Response
from ..streams import BaseReader, BaseWriter
OptionalTimeout = typing.Optional[TimeoutConfig]
class HTTP2Connection(Adapter):
READ_NUM_BYTES = 4096
def __init__(
self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
):
self.reader = reader
self.writer = writer
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
self.initialized = False
def prepare_request(self, request: Request) -> None:
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
stream = options.get("stream", False)
assert timeout is None or isinstance(timeout, TimeoutConfig)
#  Start sending the request.
if not self.initialized:
self.initiate_connection()
stream_id = await self.send_headers(request, timeout)
self.events[stream_id] = []
# Send the request body.
async for data in request.stream():
await self.send_data(stream_id, data, timeout)
# Finalize sending the request.
await self.end_stream(stream_id, timeout)
# Start getting the response.
while True:
event = await self.receive_event(stream_id, timeout)
if isinstance(event, h2.events.ResponseReceived):
break
status_code = 200
headers = []
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode())
elif not k.startswith(b":"):
headers.append((k, v))
body = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
response = Response(
status_code=status_code,
protocol="HTTP/2",
headers=headers,
body=body,
on_close=on_close,
request=request,
)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
async def close(self) -> None:
await self.writer.close()
def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write_no_block(data_to_send)
self.initialized = True
async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
(b":method", request.method.encode()),
(b":authority", request.url.hostname.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers.raw
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
return stream_id
async def send_data(
self, stream_id: int, data: bytes, timeout: OptionalTimeout
) -> None:
self.h2_state.send_data(stream_id, data)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
async def body_iter(
self, stream_id: int, timeout: OptionalTimeout
) -> typing.AsyncIterator[bytes]:
while True:
event = await self.receive_event(stream_id, timeout)
if isinstance(event, h2.events.DataReceived):
yield event.data
elif isinstance(event, h2.events.StreamEnded):
break
async def receive_event(
self, stream_id: int, timeout: OptionalTimeout
) -> h2.events.Event:
while not self.events[stream_id]:
data = await self.reader.read(self.READ_NUM_BYTES, timeout)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
return self.events[stream_id].pop(0)
async def response_closed(self, stream_id: int) -> None:
del self.events[stream_id]
if not self.events and self.on_release is not None:
await self.on_release()
@property
def is_closed(self) -> bool:
return False

View File

@ -16,12 +16,43 @@ class ReadTimeout(Timeout):
"""
class WriteTimeout(Timeout):
"""
Timeout while writing request data.
"""
class PoolTimeout(Timeout):
"""
Timeout while waiting to acquire a connection from the pool.
"""
class RedirectError(Exception):
"""
Base class for HTTP redirect errors.
"""
class TooManyRedirects(RedirectError):
"""
Too many redirects.
"""
class RedirectBodyUnavailable(RedirectError):
"""
Got a redirect response, but the request body was streaming, and is
no longer available.
"""
class RedirectLoop(RedirectError):
"""
Infinite redirect loop.
"""
class ProtocolError(Exception):
"""
Malformed HTTP.
@ -40,3 +71,8 @@ class ResponseClosed(Exception):
Attempted to read or stream response content, but the request has been
closed without loading the body.
"""
class InvalidURL(Exception):
"""
"""

View File

@ -1,163 +0,0 @@
import asyncio
import typing
import h11
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
from .datastructures import Client, Origin, Request, Response
from .exceptions import ConnectTimeout, ReadTimeout
H11Event = typing.Union[
h11.Request,
h11.Response,
h11.InformationalResponse,
h11.Data,
h11.EndOfMessage,
h11.ConnectionClosed,
]
class HTTP11Connection(Client):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
self.on_release = on_release
self._reader = None
self._writer = None
self._h11_state = h11.Connection(our_role=h11.CLIENT)
@property
def is_closed(self) -> bool:
return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
assert request.url.origin == self.origin
if ssl is None:
ssl = self.ssl
if timeout is None:
timeout = self.timeout
# Make the connection
if self._reader is None:
await self._connect(ssl, timeout)
#  Start sending the request.
method = request.method.encode()
target = request.url.target
headers = request.headers
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event)
# Send the request body.
if request.is_streaming:
async for data in request.stream():
event = h11.Data(data=data)
await self._send_event(event)
elif request.body:
event = h11.Data(data=request.body)
await self._send_event(event)
# Finalize sending the request.
event = h11.EndOfMessage()
await self._send_event(event)
# Start getting the response.
event = await self._receive_event(timeout)
if isinstance(event, h11.InformationalResponse):
event = await self._receive_event(timeout)
assert isinstance(event, h11.Response)
reason = event.reason.decode("latin1")
status_code = event.status_code
headers = event.headers
body = self._body_iter(timeout)
return Response(
status_code=status_code,
reason=reason,
headers=headers,
body=body,
on_close=self._release,
)
async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
hostname = self.origin.hostname
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
try:
self._reader, self._writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event(timeout)
assert isinstance(event, h11.EndOfMessage)
async def _send_event(self, event: H11Event) -> None:
assert self._writer is not None
data = self._h11_state.send(event)
self._writer.write(data)
async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
assert self._reader is not None
event = self._h11_state.next_event()
while event is h11.NEED_DATA:
try:
data = await asyncio.wait_for(
self._reader.read(2048), timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
self._h11_state.receive_data(data)
event = self._h11_state.next_event()
return event
async def _release(self) -> None:
assert self._writer is not None
if (
self._h11_state.our_state is h11.DONE
and self._h11_state.their_state is h11.DONE
):
self._h11_state.start_next_cycle()
else:
await self.close()
if self.on_release is not None:
await self.on_release(self)
async def close(self) -> None:
event = h11.ConnectionClosed()
try:
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
self._h11_state.send(event)
except h11.ProtocolError:
# If we're in some other state then it's a premature close,
# and we'll end up in h11.ERROR.
pass
if self._writer is not None:
self._writer.close()

67
httpcore/interfaces.py Normal file
View File

@ -0,0 +1,67 @@
import typing
from types import TracebackType
from .config import TimeoutConfig
from .models import URL, Request, Response
OptionalTimeout = typing.Optional[TimeoutConfig]
class Adapter:
async def request(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
**options: typing.Any,
) -> Response:
request = Request(method, url, headers=headers, body=body)
self.prepare_request(request)
response = await self.send(request, **options)
return response
def prepare_request(self, request: Request) -> None:
raise NotImplementedError() # pragma: nocover
async def send(self, request: Request, **options: typing.Any) -> Response:
raise NotImplementedError() # pragma: nocover
async def close(self) -> None:
raise NotImplementedError() # pragma: nocover
async def __aenter__(self) -> "Adapter":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()
class BaseReader:
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
raise NotImplementedError() # pragma: no cover
class BaseWriter:
def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
raise NotImplementedError() # pragma: no cover
async def close(self) -> None:
raise NotImplementedError() # pragma: no cover
class BasePoolSemaphore:
async def acquire(self) -> None:
raise NotImplementedError() # pragma: no cover
def release(self) -> None:
raise NotImplementedError() # pragma: no cover

400
httpcore/models.py Normal file
View File

@ -0,0 +1,400 @@
import http
import typing
from urllib.parse import urlsplit
from .config import SSLConfig, TimeoutConfig
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
Decoder,
IdentityDecoder,
MultiDecoder,
)
from .exceptions import ResponseClosed, StreamConsumed
from .utils import normalize_header_key, normalize_header_value
class URL:
def __init__(self, url: str = "") -> None:
self.components = urlsplit(url)
if not self.components.scheme:
raise ValueError("No scheme included in URL.")
if self.components.scheme not in ("http", "https"):
raise ValueError('URL scheme must be "http" or "https".')
if not self.components.hostname:
raise ValueError("No hostname included in URL.")
@property
def scheme(self) -> str:
return self.components.scheme
@property
def netloc(self) -> str:
return self.components.netloc
@property
def path(self) -> str:
return self.components.path
@property
def query(self) -> str:
return self.components.query
@property
def fragment(self) -> str:
return self.components.fragment
@property
def hostname(self) -> str:
return self.components.hostname
@property
def port(self) -> int:
port = self.components.port
if port is None:
return {"https": 443, "http": 80}[self.scheme]
return port
@property
def full_path(self) -> str:
path = self.path or "/"
query = self.query
if query:
return path + "?" + query
return path
@property
def is_secure(self) -> bool:
return self.components.scheme == "https"
@property
def origin(self) -> "Origin":
return Origin(self)
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, URL) and str(self) == str(other)
def __str__(self) -> str:
return self.components.geturl()
def __repr__(self) -> str:
class_name = self.__class__.__name__
url_str = str(self)
return f"{class_name}({url_str!r})"
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
if isinstance(url, str):
url = URL(url)
self.is_ssl = url.scheme == "https"
self.hostname = url.hostname.lower()
self.port = url.port
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.is_ssl == other.is_ssl
and self.hostname == other.hostname
and self.port == other.port
)
def __hash__(self) -> int:
return hash((self.is_ssl, self.hostname, self.port))
HeaderTypes = typing.Union[
"Headers",
typing.Dict[typing.AnyStr, typing.AnyStr],
typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
]
class Headers(typing.MutableMapping[str, str]):
"""
A case-insensitive multidict.
"""
def __init__(self, headers: HeaderTypes = None) -> None:
if headers is None:
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
elif isinstance(headers, Headers):
self._list = list(headers.raw)
elif isinstance(headers, dict):
self._list = [
(normalize_header_key(k), normalize_header_value(v))
for k, v in headers.items()
]
else:
self._list = [
(normalize_header_key(k), normalize_header_value(v)) for k, v in headers
]
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self._list
def keys(self) -> typing.List[str]: # type: ignore
return [key.decode("latin-1") for key, value in self._list]
def values(self) -> typing.List[str]: # type: ignore
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]
def get(self, key: str, default: typing.Any = None) -> typing.Any:
try:
return self[key]
except KeyError:
return default
def getlist(self, key: str) -> typing.List[str]:
get_header_key = key.lower().encode("latin-1")
return [
item_value.decode("latin-1")
for item_key, item_value in self._list
if item_key == get_header_key
]
def __getitem__(self, key: str) -> str:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return header_value.decode("latin-1")
raise KeyError(key)
def __setitem__(self, key: str, value: str) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
for idx in reversed(found_indexes[1:]):
del self._list[idx]
if found_indexes:
idx = found_indexes[0]
self._list[idx] = (set_key, set_value)
else:
self._list.append((set_key, set_value))
def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
del_key = key.lower().encode("latin-1")
pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del self._list[idx]
def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
as_dict = dict(self.items())
if len(as_dict) == len(self):
return f"{class_name}({as_dict!r})"
return f"{class_name}(raw={self.raw!r})"
class Request:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: HeaderTypes = None,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
self.url = URL(url) if isinstance(url, str) else url
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
self.headers = Headers(headers)
async def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "body"):
body = b""
async for part in self.stream():
body += part
self.body = body
return self.body
async def stream(self) -> typing.AsyncIterator[bytes]:
if self.is_streaming:
async for part in self.body_aiter:
yield part
elif self.body:
yield self.body
def prepare(self) -> None:
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
has_host = "host" in self.headers
has_content_length = (
"content-length" in self.headers or "transfer-encoding" in self.headers
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_host:
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if self.is_streaming:
auto_headers.append((b"transfer-encoding", b"chunked"))
elif self.body:
content_length = str(len(self.body)).encode()
auto_headers.append((b"content-length", content_length))
if not has_accept_encoding:
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
for item in reversed(auto_headers):
self.headers.raw.insert(0, item)
class Response:
def __init__(
self,
status_code: int,
*,
reason: typing.Optional[str] = None,
protocol: typing.Optional[str] = None,
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
request: Request = None,
history: typing.List["Response"] = None,
):
self.status_code = status_code
if not reason:
try:
self.reason = http.HTTPStatus(status_code).phrase
except ValueError as exc:
self.reason = ""
else:
self.reason = reason
self.protocol = protocol
self.headers = Headers(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
decoders = [] # type: typing.List[Decoder]
value = self.headers.get("content-encoding", "identity")
for part in value.split(","):
part = part.strip().lower()
decoder_cls = SUPPORTED_DECODERS[part]
decoders.append(decoder_cls())
if len(decoders) == 0:
self.decoder = IdentityDecoder() # type: Decoder
elif len(decoders) == 1:
self.decoder = decoders[0]
else:
self.decoder = MultiDecoder(decoders)
if isinstance(body, bytes):
self.is_closed = True
self.body = self.decoder.decode(body) + self.decoder.flush()
else:
self.body_aiter = body
self.request = request
self.history = [] if history is None else list(history)
@property
def url(self) -> typing.Optional[URL]:
return None if self.request is None else self.request.url
async def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "body"):
body = b""
async for part in self.stream():
body += part
self.body = body
return self.body
async def stream(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
else:
async for chunk in self.raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
async def raw(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if self.is_streamed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_streamed = True
async for part in self.body_aiter:
yield part
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
await self.on_close()
@property
def is_redirect(self) -> bool:
return (
self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
)

61
httpcore/status_codes.py Normal file
View File

@ -0,0 +1,61 @@
import enum
codes = enum.IntEnum(
"StatusCode",
[
("continue", 100),
("switching_protocols", 101),
("processing", 102),
("ok", 200),
("created", 201),
("accepted", 202),
("non_authoritative_information", 203),
("no_content", 204),
("reset_content", 205),
("partial_content", 206),
("multi_status", 207),
("already_reported", 208),
("im_used", 226),
("multiple_choices", 300),
("moved_permanently", 301),
("found", 302),
("see_other", 303),
("not_modified", 304),
("use_proxy", 305),
("temporary_redirect", 307),
("permanent_redirect", 308),
("bad_request", 400),
("unauthorized", 401),
("payment_required", 402),
("forbidden", 403),
("not_found", 404),
("method_not_allowed", 405),
("not_acceptable", 406),
("proxy_authentication_required", 407),
("request_timeout", 408),
("conflict", 409),
("gone", 410),
("length_required", 411),
("precondition_failed", 412),
("request_entity_too_large", 413),
("request_uri_too_long", 414),
("unsupported_media_type", 415),
("requested_range_not_satisfiable", 416),
("expectation_failed", 417),
("unprocessable_entity", 422),
("locked", 423),
("failed_dependency", 424),
("precondition_required", 428),
("too_many_requests", 429),
("request_header_fields_too_large", 431),
("unavailable_for_legal_reasons", 451),
("internal_server_error", 500),
("not_implemented", 501),
("bad_gateway", 502),
("service_unavailable", 503),
("gateway_timeout", 504),
("http_version_not_supported", 505),
("insufficient_storage", 507),
("network_authentication_required", 511),
],
)

133
httpcore/streams.py Normal file
View File

@ -0,0 +1,133 @@
"""
The `Reader` and `Writer` classes here provide a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
These classes help encapsulate the timeout logic, make it easier to unit-test
protocols, and help keep the rest of the package more `async`/`await`
based, and less strictly `asyncio`-specific.
"""
import asyncio
import enum
import ssl
import typing
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter
OptionalTimeout = typing.Optional[TimeoutConfig]
class Protocol(enum.Enum):
HTTP_11 = 1
HTTP_2 = 2
class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
) -> None:
self.stream_reader = stream_reader
self.timeout = timeout
async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
if timeout is None:
timeout = self.timeout
try:
data = await asyncio.wait_for(
self.stream_reader.read(n), timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
return data
class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
self.stream_writer = stream_writer
self.timeout = timeout
def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data)
async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
if not data:
return
if timeout is None:
timeout = self.timeout
self.stream_writer.write(data)
try:
data = await asyncio.wait_for( # type: ignore
self.stream_writer.drain(), timeout.write_timeout
)
except asyncio.TimeoutError:
raise WriteTimeout()
async def close(self) -> None:
self.stream_writer.close()
class PoolSemaphore(BasePoolSemaphore):
def __init__(self, limits: PoolLimits):
self.limits = limits
@property
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
if not hasattr(self, "_semaphore"):
max_connections = self.limits.hard_limit
if max_connections is None:
self._semaphore = None
else:
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
return self._semaphore
async def acquire(self) -> None:
if self.semaphore is None:
return
timeout = self.limits.pool_timeout
try:
await asyncio.wait_for(self.semaphore.acquire(), timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
def release(self) -> None:
if self.semaphore is None:
return
self.semaphore.release()
async def connect(
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext] = None,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
) -> typing.Tuple[Reader, Writer, Protocol]:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
ssl_object = stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
ident = "http/1.1"
else:
ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()
reader = Reader(stream_reader=stream_reader, timeout=timeout)
writer = Writer(stream_writer=stream_writer, timeout=timeout)
protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
return (reader, writer, protocol)

View File

@ -3,8 +3,9 @@ import typing
from types import TracebackType
from .config import SSLConfig, TimeoutConfig
from .connectionpool import ConnectionPool
from .datastructures import URL, Client, Response
from .dispatch.connection_pool import ConnectionPool
from .interfaces import Adapter
from .models import URL, Headers, Response
class SyncResponse:
@ -21,7 +22,7 @@ class SyncResponse:
return self._response.reason
@property
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def headers(self) -> Headers:
return self._response.headers
@property
@ -44,8 +45,8 @@ class SyncResponse:
class SyncClient:
def __init__(self, client: Client):
self._client = client
def __init__(self, adapter: Adapter):
self._client = adapter
self._loop = asyncio.new_event_loop()
def request(
@ -53,22 +54,12 @@ class SyncClient:
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
stream: bool = False,
**options: typing.Any
) -> SyncResponse:
response = self._loop.run_until_complete(
self._client.request(
method,
url,
headers=headers,
body=body,
ssl=ssl,
timeout=timeout,
stream=stream,
)
self._client.request(method, url, headers=headers, body=body, **options)
)
return SyncResponse(response, self._loop)

71
httpcore/utils.py Normal file
View File

@ -0,0 +1,71 @@
import typing
from urllib.parse import quote
from .exceptions import InvalidURL
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
)
def unquote_unreserved(uri: str) -> str:
"""
Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
"""
parts = uri.split("%")
for i in range(1, len(parts)):
h = parts[i][0:2]
if len(h) == 2 and h.isalnum():
try:
c = chr(int(h, 16))
except ValueError:
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
if c in UNRESERVED_SET:
parts[i] = c + parts[i][2:]
else:
parts[i] = "%" + parts[i]
else:
parts[i] = "%" + parts[i]
return "".join(parts)
def requote_uri(uri: str) -> str:
"""
Re-quote the given URI.
This function passes the given URI through an unquote/quote cycle to
ensure that it is fully and consistently quoted.
"""
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
try:
# Unquote only the unreserved characters
# Then quote only illegal characters (do not quote reserved,
# unreserved, or '%')
return quote(unquote_unreserved(uri), safe=safe_with_percent)
except InvalidURL:
# We couldn't unquote the given URI, so let's try quoting it, but
# there may be unquoted '%'s in the URI. We need to make sure they're
# properly quoted so they do not cause issues elsewhere.
return quote(uri, safe=safe_without_percent)
def normalize_header_key(value: typing.AnyStr) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header key.
"""
if isinstance(value, bytes):
return value.lower()
return value.encode("latin-1").lower()
def normalize_header_value(value: typing.AnyStr) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header value.
"""
if isinstance(value, bytes):
return value
return value.encode("latin-1")

View File

@ -1,5 +1,6 @@
certifi
h11
h2
# Optional
brotlipy

View File

@ -47,7 +47,7 @@ setup(
author_email="tom@tomchristie.com",
packages=get_packages("httpcore"),
data_files=[("", ["LICENSE.md"])],
install_requires=["h11", "certifi"],
install_requires=["h11", "h2", "certifi"],
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",

View File

@ -0,0 +1,219 @@
import json
from urllib.parse import parse_qs
import pytest
from httpcore import (
URL,
Adapter,
RedirectAdapter,
RedirectBodyUnavailable,
RedirectLoop,
Request,
Response,
TooManyRedirects,
codes,
)
class MockDispatch(Adapter):
def prepare_request(self, request: Request) -> None:
pass
async def send(self, request: Request, **options) -> Response:
if request.url.path == "/redirect_301":
status_code = codes.moved_permanently
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_302":
status_code = codes.found
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_303":
status_code = codes.see_other
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/relative_redirect":
headers = {"location": "/"}
return Response(codes.see_other, headers=headers, request=request)
elif request.url.path == "/no_scheme_redirect":
headers = {"location": "//example.org/"}
return Response(codes.see_other, headers=headers, request=request)
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
count = int(params.get("count", "0")[0])
redirect_count = count - 1
code = codes.see_other if count else codes.ok
location = "/multiple_redirects"
if redirect_count:
location += "?count=" + str(redirect_count)
headers = {"location": location} if count else {}
return Response(code, headers=headers, request=request)
if request.url.path == "/redirect_loop":
headers = {"location": "/redirect_loop"}
return Response(codes.see_other, headers=headers, request=request)
elif request.url.path == "/cross_domain":
headers = {"location": "https://example.org/cross_domain_target"}
return Response(codes.see_other, headers=headers, request=request)
elif request.url.path == "/cross_domain_target":
headers = dict(request.headers.items())
body = json.dumps({"headers": headers}).encode()
return Response(codes.ok, body=body, request=request)
elif request.url.path == "/redirect_body":
body = await request.read()
headers = {"location": "/redirect_body_target"}
return Response(codes.permanent_redirect, headers=headers, request=request)
elif request.url.path == "/redirect_body_target":
body = await request.read()
body = json.dumps({"body": body.decode()}).encode()
return Response(codes.ok, body=body, request=request)
return Response(codes.ok, body=b"Hello, world!", request=request)
@pytest.mark.asyncio
async def test_redirect_301():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_301")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_redirect_302():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_302")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_redirect_303():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/redirect_303")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_disallow_redirects():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False)
assert response.status_code == codes.see_other
assert response.url == URL("https://example.org/redirect_303")
assert len(response.history) == 0
response = await response.next()
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_relative_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_no_scheme_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/no_scheme_redirect")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_fragment_redirect():
client = RedirectAdapter(MockDispatch())
url = "https://example.org/relative_redirect#fragment"
response = await client.request("GET", url)
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/#fragment")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_multiple_redirects():
client = RedirectAdapter(MockDispatch())
url = "https://example.org/multiple_redirects?count=20"
response = await client.request("GET", url)
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/multiple_redirects")
assert len(response.history) == 20
@pytest.mark.asyncio
async def test_too_many_redirects():
client = RedirectAdapter(MockDispatch())
with pytest.raises(TooManyRedirects):
await client.request("GET", "https://example.org/multiple_redirects?count=21")
@pytest.mark.asyncio
async def test_redirect_loop():
client = RedirectAdapter(MockDispatch())
with pytest.raises(RedirectLoop):
await client.request("GET", "https://example.org/redirect_loop")
@pytest.mark.asyncio
async def test_cross_domain_redirect():
client = RedirectAdapter(MockDispatch())
url = "https://example.com/cross_domain"
headers = {"Authorization": "abc"}
response = await client.request("GET", url, headers=headers)
data = json.loads(response.body.decode())
assert response.url == URL("https://example.org/cross_domain_target")
assert data == {"headers": {}}
@pytest.mark.asyncio
async def test_same_domain_redirect():
client = RedirectAdapter(MockDispatch())
url = "https://example.org/cross_domain"
headers = {"Authorization": "abc"}
response = await client.request("GET", url, headers=headers)
data = json.loads(response.body.decode())
assert response.url == URL("https://example.org/cross_domain_target")
assert data == {"headers": {"authorization": "abc"}}
@pytest.mark.asyncio
async def test_body_redirect():
client = RedirectAdapter(MockDispatch())
url = "https://example.org/redirect_body"
body = b"Example request body"
response = await client.request("POST", url, body=body)
data = json.loads(response.body.decode())
assert response.url == URL("https://example.org/redirect_body_target")
assert data == {"body": "Example request body"}
@pytest.mark.asyncio
async def test_cannot_redirect_streaming_body():
client = RedirectAdapter(MockDispatch())
url = "https://example.org/redirect_body"
async def body():
yield b"Example request body"
with pytest.raises(RedirectBodyUnavailable):
await client.request("POST", url, body=body())

View File

@ -5,16 +5,16 @@ import httpcore
@pytest.mark.asyncio
async def test_get(server):
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
async with httpcore.Client() as client:
response = await client.request("GET", "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.body == b"Hello, world!"
@pytest.mark.asyncio
async def test_post(server):
async with httpcore.ConnectionPool() as http:
response = await http.request(
async with httpcore.Client() as client:
response = await client.request(
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
)
assert response.status_code == 200
@ -22,8 +22,8 @@ async def test_post(server):
@pytest.mark.asyncio
async def test_stream_response(server):
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
async with httpcore.Client() as client:
response = await client.request("GET", "http://127.0.0.1:8000/", stream=True)
assert response.status_code == 200
assert not hasattr(response, "body")
body = await response.read()
@ -36,8 +36,8 @@ async def test_stream_request(server):
yield b"Hello, "
yield b"world!"
async with httpcore.ConnectionPool() as http:
response = await http.request(
async with httpcore.Client() as client:
response = await client.request(
"POST", "http://127.0.0.1:8000/", body=hello_world()
)
assert response.status_code == 200

View File

@ -13,13 +13,15 @@ def test_timeout_repr():
timeout = httpcore.TimeoutConfig(read_timeout=5.0)
assert (
repr(timeout)
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)"
)
def test_limits_repr():
limits = httpcore.PoolLimits(hard_limit=100)
assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
assert (
repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)"
)
def test_ssl_eq():
@ -35,24 +37,3 @@ def test_timeout_eq():
def test_limits_eq():
limits = httpcore.PoolLimits(hard_limit=100)
assert limits == httpcore.PoolLimits(hard_limit=100)
def test_ssl_hash():
cache = {}
ssl = httpcore.SSLConfig(verify=False)
cache[ssl] = "example"
assert cache[httpcore.SSLConfig(verify=False)] == "example"
def test_timeout_hash():
cache = {}
timeout = httpcore.TimeoutConfig(timeout=5.0)
cache[timeout] = "example"
assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example"
def test_limits_hash():
cache = {}
limits = httpcore.PoolLimits(hard_limit=100)
cache[limits] = "example"
assert cache[httpcore.PoolLimits(hard_limit=100)] == "example"

View File

@ -10,12 +10,12 @@ async def test_keepalive_connections(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -25,12 +25,12 @@ async def test_differing_connection_keys(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
@ -42,12 +42,12 @@ async def test_soft_limit(server):
async with httpcore.ConnectionPool(limits=limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
await response.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server):
"""
async with httpcore.ConnectionPool() as http:
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 2
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 2
assert len(http.keepalive_connections) == 0
await response_b.read()
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 1
await response_a.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
@ -97,8 +97,8 @@ async def test_close_connections(server):
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@pytest.mark.asyncio
@ -110,8 +110,8 @@ async def test_standard_response_close(server):
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.read()
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -122,5 +122,5 @@ async def test_premature_response_close(server):
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0

View File

@ -5,7 +5,7 @@ import httpcore
@pytest.mark.asyncio
async def test_get(server):
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
response = await http.request("GET", "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.body == b"Hello, world!"
@ -13,7 +13,7 @@ async def test_get(server):
@pytest.mark.asyncio
async def test_post(server):
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
response = await http.request(
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
)

116
tests/test_http2.py Normal file
View File

@ -0,0 +1,116 @@
import json
import h2.config
import h2.connection
import h2.events
import pytest
import httpcore
class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
"""
This class exposes Reader and Writer style interfaces
"""
def __init__(self):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.buffer = b""
self.requests = {}
# BaseReader interface
async def read(self, n, timeout) -> bytes:
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
if isinstance(event, h2.events.RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, h2.events.DataReceived):
self.receive_data(event.data, event.stream_id)
elif isinstance(event, h2.events.StreamEnded):
self.stream_complete(event.stream_id)
async def write(self, data: bytes, timeout) -> None:
self.write_no_block(data)
async def close(self) -> None:
pass
# Server implementation
def request_received(self, headers, stream_id):
if stream_id not in self.requests:
self.requests[stream_id] = []
self.requests[stream_id].append({"headers": headers, "data": b""})
def receive_data(self, data, stream_id):
self.requests[stream_id][-1]["data"] += data
def stream_complete(self, stream_id):
request = self.requests[stream_id].pop(0)
if not self.requests[stream_id]:
del self.requests[stream_id]
request_headers = dict(request["headers"])
request_data = request["data"]
response_body = json.dumps(
{
"method": request_headers[b":method"].decode(),
"path": request_headers[b":path"].decode(),
"body": request_data.decode(),
}
).encode()
response_headers = ((b":status", b"200"),)
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response_body, end_stream=True)
self.buffer += self.conn.data_to_send()
@pytest.mark.asyncio
async def test_http2_get_request():
server = MockServer()
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
response = await conn.request("GET", "http://example.org")
assert response.status_code == 200
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
@pytest.mark.asyncio
async def test_http2_post_request():
server = MockServer()
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
response = await conn.request("POST", "http://example.org", body=b"<data>")
assert response.status_code == 200
assert json.loads(response.body) == {
"method": "POST",
"path": "/",
"body": "<data>",
}
@pytest.mark.asyncio
async def test_http2_multiple_requests():
server = MockServer()
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
response_1 = await conn.request("GET", "http://example.org/1")
response_2 = await conn.request("GET", "http://example.org/2")
response_3 = await conn.request("GET", "http://example.org/3")
assert response_1.status_code == 200
assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""}
assert response_2.status_code == 200
assert json.loads(response_2.body) == {"method": "GET", "path": "/2", "body": ""}
assert response_3.status_code == 200
assert json.loads(response_3.body) == {"method": "GET", "path": "/3", "body": ""}

View File

@ -5,19 +5,22 @@ import httpcore
def test_host_header():
request = httpcore.Request("GET", "http://example.org")
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
)
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
assert request.headers == [
(b"host", b"example.org"),
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
def test_transfer_encoding_header():
@ -27,31 +30,34 @@ def test_transfer_encoding_header():
body = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", body=body)
assert request.headers == [
(b"host", b"example.org"),
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
def test_override_host_header():
headers = [(b"host", b"1.2.3.4:80")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"accept-encoding", b"deflate, gzip, br"),
(b"host", b"1.2.3.4:80"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
)
def test_override_accept_encoding_header():
headers = [(b"accept-encoding", b"identity")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"identity"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
)
def test_override_content_length_header():
@ -62,23 +68,26 @@ def test_override_content_length_header():
headers = [(b"content-length", b"8")]
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]
)
def test_url():
request = httpcore.Request("GET", "http://example.org")
assert request.url.scheme == "http"
assert request.url.port == 80
assert request.url.target == "/"
assert request.url.full_path == "/"
request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
assert request.url.scheme == "https"
assert request.url.port == 443
assert request.url.target == "/abc?foo=bar"
assert request.url.full_path == "/abc?foo=bar"
def test_invalid_urls():

View File

@ -24,10 +24,9 @@ async def test_connect_timeout(server):
@pytest.mark.asyncio
async def test_pool_timeout(server):
timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
limits = httpcore.PoolLimits(hard_limit=1)
limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001)
async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
async with httpcore.ConnectionPool(limits=limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
with pytest.raises(httpcore.PoolTimeout):