diff --git a/API.md b/API.md new file mode 100644 index 00000000..56f34fe5 --- /dev/null +++ b/API.md @@ -0,0 +1,90 @@ +Client(...) + + .request(method, url, ...) + + .get(url, ...) + .options(url, ...) + .head(url, ...) + .post(url, ...) + .put(url, ...) + .patch(url, ...) + .delete(url, ...) + + .prepare_request(request) + .send(request, ...) + .close() + + +Adapter() + + .prepare_request(request) + .send(request) + .close() + + ++ EnvironmentAdapter ++ RedirectAdapter ++ CookieAdapter ++ AuthAdapter ++ ConnectionPool + + HTTPConnection + + HTTP11Connection + + HTTP2Connection + + + +Response(...) + .status_code - int + .reason_phrase - str + .protocol - "HTTP/2" or "HTTP/1.1" + .url - URL + .headers - Headers + + .content - bytes + .text - str + .encoding - str + .json() - Any + + .read() - bytes + .stream() - bytes iterator + .raw() - bytes iterator + .close() - None + + .is_redirect - bool + .request - Request + .cookies - Cookies + .history - List[Response] + + .raise_for_status() + .next() + + +Request(...) + .method + .url + .headers + + ... + + +Headers + +URL + +Origin + +Cookies + + +# Sync + +SyncClient +SyncResponse +SyncRequest +SyncAdapter + + + +SSE +HTTP/2 server push support +Concurrency diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 9824e854..813b2b56 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,16 +1,26 @@ +from .adapters.redirects import RedirectAdapter +from .client import Client from .config import PoolLimits, SSLConfig, TimeoutConfig -from .connectionpool import ConnectionPool -from .datastructures import URL, Origin, Request, Response +from .dispatch.connection import HTTPConnection +from .dispatch.connection_pool import ConnectionPool +from .dispatch.http2 import HTTP2Connection +from .dispatch.http11 import HTTP11Connection from .exceptions import ( ConnectTimeout, PoolTimeout, ProtocolError, ReadTimeout, + RedirectBodyUnavailable, + RedirectLoop, ResponseClosed, StreamConsumed, Timeout, + TooManyRedirects, ) -from .http11 import HTTP11Connection +from .interfaces import Adapter +from .models import URL, Headers, Origin, Request, Response +from .status_codes import codes +from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect from .sync import SyncClient, SyncConnectionPool __version__ = "0.2.1" diff --git a/httpcore/adapters/__init__.py b/httpcore/adapters/__init__.py new file mode 100644 index 00000000..8d4629f7 --- /dev/null +++ b/httpcore/adapters/__init__.py @@ -0,0 +1,4 @@ +""" +Adapter classes layer additional behavior over the raw dispatching of the +HTTP request/response. +""" diff --git a/httpcore/adapters/authentication.py b/httpcore/adapters/authentication.py new file mode 100644 index 00000000..cb5ae99a --- /dev/null +++ b/httpcore/adapters/authentication.py @@ -0,0 +1,18 @@ +import typing + +from ..interfaces import Adapter +from ..models import Request, Response + + +class AuthenticationAdapter(Adapter): + def __init__(self, dispatch: Adapter): + self.dispatch = dispatch + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + await self.dispatch.close() diff --git a/httpcore/adapters/cookies.py b/httpcore/adapters/cookies.py new file mode 100644 index 00000000..11051521 --- /dev/null +++ b/httpcore/adapters/cookies.py @@ -0,0 +1,18 @@ +import typing + +from ..interfaces import Adapter +from ..models import Request, Response + + +class CookieAdapter(Adapter): + def __init__(self, dispatch: Adapter): + self.dispatch = dispatch + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + await self.dispatch.close() diff --git a/httpcore/adapters/environment.py b/httpcore/adapters/environment.py new file mode 100644 index 00000000..9840a926 --- /dev/null +++ b/httpcore/adapters/environment.py @@ -0,0 +1,27 @@ +import typing + +from ..interfaces import Adapter +from ..models import Request, Response + + +class EnvironmentAdapter(Adapter): + def __init__(self, dispatch: Adapter, trust_env: bool = True): + self.dispatch = dispatch + self.trust_env = trust_env + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + if self.trust_env: + self.merge_environment_options(options) + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + await self.dispatch.close() + + def merge_environment_options(self, options: dict) -> None: + """ + Add environment options. + """ + #  TODO diff --git a/httpcore/adapters/redirects.py b/httpcore/adapters/redirects.py new file mode 100644 index 00000000..ecbb1728 --- /dev/null +++ b/httpcore/adapters/redirects.py @@ -0,0 +1,125 @@ +import functools +import typing +from urllib.parse import urljoin, urlparse + +from ..config import DEFAULT_MAX_REDIRECTS +from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects +from ..interfaces import Adapter +from ..models import URL, Headers, Request, Response +from ..status_codes import codes +from ..utils import requote_uri + + +class RedirectAdapter(Adapter): + def __init__(self, dispatch: Adapter, max_redirects: int = DEFAULT_MAX_REDIRECTS): + self.dispatch = dispatch + self.max_redirects = max_redirects + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + allow_redirects = options.pop("allow_redirects", True) + history = options.pop("history", []) # type: typing.List[Response] + seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL] + seen_urls.add(request.url) + + while True: + response = await self.dispatch.send(request, **options) + response.history = list(history) + if not response.is_redirect: + break + history.append(response) + request = self.build_redirect_request(request, response) + if not allow_redirects: + next_options = dict(options) + next_options["seen_urls"] = seen_urls + next_options["history"] = history + response.next = functools.partial(self.send, request=request, **next_options) + break + if len(history) > self.max_redirects: + raise TooManyRedirects() + if request.url in seen_urls: + raise RedirectLoop() + seen_urls.add(request.url) + + return response + + async def close(self) -> None: + await self.dispatch.close() + + def build_redirect_request(self, request: Request, response: Response) -> Request: + method = self.redirect_method(request, response) + url = self.redirect_url(request, response) + headers = self.redirect_headers(request, url) + body = self.redirect_body(request, method) + return Request(method=method, url=url, headers=headers, body=body) + + def redirect_method(self, request: Request, response: Response) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. + """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.see_other and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.found and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.moved_permanently and method == "POST": + method = "GET" + + return method + + def redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + # Handle redirection without scheme (see: RFC 1808 Section 4) + if location.startswith("//"): + location = f"{request.url.scheme}:{location}" + + # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2) + parsed = urlparse(location) + if parsed.fragment == "" and request.url.fragment: + parsed = parsed._replace(fragment=request.url.fragment) + url = parsed.geturl() + + # Facilitate relative 'location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + # Compliant with RFC3986, we percent encode the url. + if not parsed.netloc: + url = urljoin(str(request.url), requote_uri(url)) + else: + url = requote_uri(url) + + return URL(url) + + def redirect_headers(self, request: Request, url: URL) -> Headers: + """ + Strip Authorization headers when responses are redirected away from + the origin. + """ + headers = Headers(request.headers) + if url.origin != request.url.origin: + del headers["Authorization"] + return headers + + def redirect_body(self, request: Request, method: str) -> bytes: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return b"" + if request.is_streaming: + raise RedirectBodyUnavailable() + return request.body diff --git a/httpcore/client.py b/httpcore/client.py new file mode 100644 index 00000000..139a5686 --- /dev/null +++ b/httpcore/client.py @@ -0,0 +1,124 @@ +import typing +from types import TracebackType + +from .adapters.authentication import AuthenticationAdapter +from .adapters.cookies import CookieAdapter +from .adapters.environment import EnvironmentAdapter +from .adapters.redirects import RedirectAdapter +from .config import ( + DEFAULT_MAX_REDIRECTS, + DEFAULT_POOL_LIMITS, + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + PoolLimits, + SSLConfig, + TimeoutConfig, +) +from .dispatch.connection_pool import ConnectionPool +from .models import URL, Request, Response + + +class Client: + def __init__( + self, + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + limits: PoolLimits = DEFAULT_POOL_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + ): + connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits) + cookie_adapter = CookieAdapter(dispatch=connection_pool) + auth_adapter = AuthenticationAdapter(dispatch=cookie_adapter) + redirect_adapter = RedirectAdapter( + dispatch=auth_adapter, max_redirects=max_redirects + ) + self.adapter = EnvironmentAdapter(dispatch=redirect_adapter) + + async def request( + self, + method: str, + url: typing.Union[str, URL], + *, + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + headers: typing.List[typing.Tuple[bytes, bytes]] = [], + stream: bool = False, + allow_redirects: bool = True, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + request = Request(method, url, headers=headers, body=body) + self.prepare_request(request) + response = await self.send( + request, + stream=stream, + allow_redirects=allow_redirects, + ssl=ssl, + timeout=timeout, + ) + return response + + async def get( + self, + url: typing.Union[str, URL], + *, + headers: typing.List[typing.Tuple[bytes, bytes]] = [], + stream: bool = False, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + return await self.request( + "GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout + ) + + async def post( + self, + url: typing.Union[str, URL], + *, + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + headers: typing.List[typing.Tuple[bytes, bytes]] = [], + stream: bool = False, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + return await self.request( + "POST", + url, + body=body, + headers=headers, + stream=stream, + ssl=ssl, + timeout=timeout, + ) + + def prepare_request(self, request: Request) -> None: + self.adapter.prepare_request(request) + + async def send( + self, + request: Request, + *, + stream: bool = False, + allow_redirects: bool = True, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + options = {"stream": stream} # type: typing.Dict[str, typing.Any] + if ssl is not None: + options["ssl"] = ssl + if timeout is not None: + options["timeout"] = timeout + return await self.adapter.send(request, **options) + + async def close(self) -> None: + await self.adapter.close() + + async def __aenter__(self) -> "Client": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() diff --git a/httpcore/config.py b/httpcore/config.py index 8cc784b3..ef24a8b1 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -27,10 +27,6 @@ class SSLConfig: and self.verify == other.verify ) - def __hash__(self) -> int: - as_tuple = (self.cert, self.verify) - return hash(as_tuple) - def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}(cert={self.cert}, verify={self.verify})" @@ -73,7 +69,23 @@ class SSLConfig: "invalid path: {}".format(self.verify) ) - context = ssl.create_default_context() + context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + + # RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST + # support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the + # blacklist defined in this section allows only the AES GCM and ChaCha20 + # cipher suites with ephemeral key negotiation. + context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20") + + if ssl.HAS_ALPN: + context.set_alpn_protocols(["h2", "http/1.1"]) + if ssl.HAS_NPN: + context.set_npn_protocols(["h2", "http/1.1"]) + if os.path.isfile(ca_bundle_path): context.load_verify_locations(cafile=ca_bundle_path) elif os.path.isdir(ca_bundle_path): @@ -99,39 +111,35 @@ class TimeoutConfig: *, connect_timeout: float = None, read_timeout: float = None, - pool_timeout: float = None, + write_timeout: float = None, ): if timeout is not None: # Specified as a single timeout value assert connect_timeout is None assert read_timeout is None - assert pool_timeout is None + assert write_timeout is None connect_timeout = timeout read_timeout = timeout - pool_timeout = timeout + write_timeout = timeout self.timeout = timeout self.connect_timeout = connect_timeout self.read_timeout = read_timeout - self.pool_timeout = pool_timeout + self.write_timeout = write_timeout def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) and self.connect_timeout == other.connect_timeout and self.read_timeout == other.read_timeout - and self.pool_timeout == other.pool_timeout + and self.write_timeout == other.write_timeout ) - def __hash__(self) -> int: - as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout) - return hash(as_tuple) - def __repr__(self) -> str: class_name = self.__class__.__name__ if self.timeout is not None: return f"{class_name}(timeout={self.timeout})" - return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})" + return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})" class PoolLimits: @@ -142,31 +150,29 @@ class PoolLimits: def __init__( self, *, - soft_limit: typing.Optional[int] = None, - hard_limit: typing.Optional[int] = None, + soft_limit: int = None, + hard_limit: int = None, + pool_timeout: float = None, ): self.soft_limit = soft_limit self.hard_limit = hard_limit + self.pool_timeout = pool_timeout def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) and self.soft_limit == other.soft_limit and self.hard_limit == other.hard_limit + and self.pool_timeout == other.pool_timeout ) - def __hash__(self) -> int: - as_tuple = (self.soft_limit, self.hard_limit) - return hash(as_tuple) - def __repr__(self) -> str: class_name = self.__class__.__name__ - return ( - f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})" - ) + return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})" DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) -DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100) +DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0) DEFAULT_CA_BUNDLE_PATH = certifi.where() +DEFAULT_MAX_REDIRECTS = 20 diff --git a/httpcore/connectionpool.py b/httpcore/connectionpool.py deleted file mode 100644 index a2040dbc..00000000 --- a/httpcore/connectionpool.py +++ /dev/null @@ -1,132 +0,0 @@ -import asyncio -import typing - -from .config import ( - DEFAULT_CA_BUNDLE_PATH, - DEFAULT_POOL_LIMITS, - DEFAULT_SSL_CONFIG, - DEFAULT_TIMEOUT_CONFIG, - PoolLimits, - SSLConfig, - TimeoutConfig, -) -from .datastructures import Client, Origin, Request, Response -from .exceptions import PoolTimeout -from .http11 import HTTP11Connection - - -class ConnectionPool(Client): - def __init__( - self, - *, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - limits: PoolLimits = DEFAULT_POOL_LIMITS, - ): - self.ssl = ssl - self.timeout = timeout - self.limits = limits - self.is_closed = False - self.num_active_connections = 0 - self.num_keepalive_connections = 0 - self._keepalive_connections = ( - {} - ) # type: typing.Dict[Origin, typing.List[HTTP11Connection]] - self._max_connections = ConnectionSemaphore( - max_connections=self.limits.hard_limit - ) - - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - connection = await self.acquire_connection(request.url.origin, timeout=timeout) - response = await connection.send(request, ssl=ssl, timeout=timeout) - return response - - @property - def num_connections(self) -> int: - return self.num_active_connections + self.num_keepalive_connections - - async def acquire_connection( - self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None - ) -> HTTP11Connection: - try: - connection = self._keepalive_connections[origin].pop() - if not self._keepalive_connections[origin]: - del self._keepalive_connections[origin] - self.num_keepalive_connections -= 1 - self.num_active_connections += 1 - - except (KeyError, IndexError): - if timeout is None: - pool_timeout = self.timeout.pool_timeout - else: - pool_timeout = timeout.pool_timeout - - try: - await asyncio.wait_for(self._max_connections.acquire(), pool_timeout) - except asyncio.TimeoutError: - raise PoolTimeout() - connection = HTTP11Connection( - origin, - ssl=self.ssl, - timeout=self.timeout, - on_release=self.release_connection, - ) - self.num_active_connections += 1 - - return connection - - async def release_connection(self, connection: HTTP11Connection) -> None: - if connection.is_closed: - self._max_connections.release() - self.num_active_connections -= 1 - elif ( - self.limits.soft_limit is not None - and self.num_connections > self.limits.soft_limit - ): - self._max_connections.release() - self.num_active_connections -= 1 - await connection.close() - else: - self.num_active_connections -= 1 - self.num_keepalive_connections += 1 - try: - self._keepalive_connections[connection.origin].append(connection) - except KeyError: - self._keepalive_connections[connection.origin] = [connection] - - async def close(self) -> None: - self.is_closed = True - all_connections = [] - for connections in self._keepalive_connections.values(): - all_connections.extend(list(connections)) - self._keepalive_connections.clear() - for connection in all_connections: - await connection.close() - - -class ConnectionSemaphore: - def __init__(self, max_connections: int = None): - self.max_connections = max_connections - - @property - def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: - if not hasattr(self, "_semaphore"): - if self.max_connections is None: - self._semaphore = None - else: - self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections) - return self._semaphore - - async def acquire(self) -> None: - if self.semaphore is not None: - await self.semaphore.acquire() - - def release(self) -> None: - if self.semaphore is not None: - self.semaphore.release() diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py deleted file mode 100644 index de5ee2f5..00000000 --- a/httpcore/datastructures.py +++ /dev/null @@ -1,280 +0,0 @@ -import http -import typing -from types import TracebackType -from urllib.parse import urlsplit - -from .config import SSLConfig, TimeoutConfig -from .decoders import ( - ACCEPT_ENCODING, - SUPPORTED_DECODERS, - Decoder, - IdentityDecoder, - MultiDecoder, -) -from .exceptions import ResponseClosed, StreamConsumed - - -class URL: - def __init__(self, url: str = "") -> None: - self.components = urlsplit(url) - if not self.components.scheme: - raise ValueError("No scheme included in URL.") - if self.components.scheme not in ("http", "https"): - raise ValueError('URL scheme must be "http" or "https".') - if not self.components.hostname: - raise ValueError("No hostname included in URL.") - - @property - def scheme(self) -> str: - return self.components.scheme - - @property - def netloc(self) -> str: - return self.components.netloc - - @property - def path(self) -> str: - return self.components.path - - @property - def query(self) -> str: - return self.components.query - - @property - def hostname(self) -> str: - return self.components.hostname - - @property - def port(self) -> int: - port = self.components.port - if port is None: - return {"https": 443, "http": 80}[self.scheme] - return port - - @property - def target(self) -> str: - path = self.path or "/" - query = self.query - if query: - return path + "?" + query - return path - - @property - def is_secure(self) -> bool: - return self.components.scheme == "https" - - @property - def origin(self) -> "Origin": - return Origin(self) - - -class Origin: - def __init__(self, url: typing.Union[str, URL]) -> None: - if isinstance(url, str): - url = URL(url) - self.is_ssl = url.scheme == "https" - self.hostname = url.hostname.lower() - self.port = url.port - - def __eq__(self, other: typing.Any) -> bool: - return ( - isinstance(other, self.__class__) - and self.is_ssl == other.is_ssl - and self.hostname == other.hostname - and self.port == other.port - ) - - def __hash__(self) -> int: - return hash((self.is_ssl, self.hostname, self.port)) - - -class Request: - def __init__( - self, - method: str, - url: typing.Union[str, URL], - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - ): - self.method = method.upper() - self.url = URL(url) if isinstance(url, str) else url - self.headers = list(headers) - if isinstance(body, bytes): - self.is_streaming = False - self.body = body - else: - self.is_streaming = True - self.body_aiter = body - self.headers = self._auto_headers() + self.headers - - def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]: - has_host = False - has_content_length = False - has_accept_encoding = False - - for header, value in self.headers: - header = header.strip().lower() - if header == b"host": - has_host = True - elif header in (b"content-length", b"transfer-encoding"): - has_content_length = True - elif header == b"accept-encoding": - has_accept_encoding = True - - headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] - - if not has_host: - headers.append((b"host", self.url.netloc.encode("ascii"))) - if not has_content_length: - if self.is_streaming: - headers.append((b"transfer-encoding", b"chunked")) - elif self.body: - content_length = str(len(self.body)).encode() - headers.append((b"content-length", content_length)) - if not has_accept_encoding: - headers.append((b"accept-encoding", ACCEPT_ENCODING)) - - return headers - - async def stream(self) -> typing.AsyncIterator[bytes]: - assert self.is_streaming - - async for part in self.body_aiter: - yield part - - -class Response: - def __init__( - self, - status_code: int, - *, - reason: typing.Optional[str] = None, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - on_close: typing.Callable = None, - ): - self.status_code = status_code - if not reason: - try: - self.reason = http.HTTPStatus(status_code).phrase - except ValueError as exc: - self.reason = "" - else: - self.reason = reason - self.headers = list(headers) - self.on_close = on_close - self.is_closed = False - self.is_streamed = False - - decoders = [] # type: typing.List[Decoder] - for header, value in self.headers: - if header.strip().lower() == b"content-encoding": - for part in value.split(b","): - part = part.strip().lower() - decoder_cls = SUPPORTED_DECODERS[part] - decoders.append(decoder_cls()) - - if len(decoders) == 0: - self.decoder = IdentityDecoder() # type: Decoder - elif len(decoders) == 1: - self.decoder = decoders[0] - else: - self.decoder = MultiDecoder(decoders) - - if isinstance(body, bytes): - self.is_closed = True - self.body = self.decoder.decode(body) + self.decoder.flush() - else: - self.body_aiter = body - - async def read(self) -> bytes: - """ - Read and return the response content. - """ - if not hasattr(self, "body"): - body = b"" - async for part in self.stream(): - body += part - self.body = body - return self.body - - async def stream(self) -> typing.AsyncIterator[bytes]: - """ - A byte-iterator over the decoded response content. - This allows us to handle gzip, deflate, and brotli encoded responses. - """ - if hasattr(self, "body"): - yield self.body - else: - async for chunk in self.raw(): - yield self.decoder.decode(chunk) - yield self.decoder.flush() - - async def raw(self) -> typing.AsyncIterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - if self.is_streamed: - raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() - self.is_streamed = True - async for part in self.body_aiter: - yield part - await self.close() - - async def close(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - if not self.is_closed: - self.is_closed = True - if self.on_close is not None: - await self.on_close() - - -class Client: - async def request( - self, - method: str, - url: typing.Union[str, URL], - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - stream: bool = False, - ) -> Response: - request = Request(method, url, headers=headers, body=body) - response = await self.send(request, ssl=ssl, timeout=timeout) - if not stream: - try: - await response.read() - finally: - await response.close() - return response - - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - raise NotImplementedError() # pragma: nocover - - async def close(self) -> None: - raise NotImplementedError() # pragma: nocover - - async def __aenter__(self) -> "Client": - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.close() diff --git a/httpcore/decoders.py b/httpcore/decoders.py index b56745c4..e47143d7 100644 --- a/httpcore/decoders.py +++ b/httpcore/decoders.py @@ -110,17 +110,17 @@ class MultiDecoder(Decoder): SUPPORTED_DECODERS = { - b"identity": IdentityDecoder, - b"deflate": DeflateDecoder, - b"gzip": GZipDecoder, - b"br": BrotliDecoder, + "identity": IdentityDecoder, + "deflate": DeflateDecoder, + "gzip": GZipDecoder, + "br": BrotliDecoder, } if brotli is None: - SUPPORTED_DECODERS.pop(b"br") # pragma: nocover + SUPPORTED_DECODERS.pop("br") # pragma: nocover -ACCEPT_ENCODING = b", ".join( - [key for key in SUPPORTED_DECODERS.keys() if key != b"identity"] +ACCEPT_ENCODING = ", ".join( + [key for key in SUPPORTED_DECODERS.keys() if key != "identity"] ) diff --git a/httpcore/dispatch/__init__.py b/httpcore/dispatch/__init__.py new file mode 100644 index 00000000..4057d6ea --- /dev/null +++ b/httpcore/dispatch/__init__.py @@ -0,0 +1,4 @@ +""" +Dispatch classes handle the raw network connections and the implementation +details of making the HTTP request and receiving the response. +""" diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py new file mode 100644 index 00000000..ffe82fcb --- /dev/null +++ b/httpcore/dispatch/connection.py @@ -0,0 +1,93 @@ +import functools +import typing + +import h2.connection +import h11 + +from ..config import ( + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + SSLConfig, + TimeoutConfig, +) +from ..exceptions import ConnectTimeout +from ..interfaces import Adapter +from ..models import Origin, Request, Response +from ..streams import Protocol, connect +from .http2 import HTTP2Connection +from .http11 import HTTP11Connection + +# Callback signature: async def callback(conn: HTTPConnection) -> None +ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]] + + +class HTTPConnection(Adapter): + def __init__( + self, + origin: typing.Union[str, Origin], + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + release_func: typing.Optional[ReleaseCallback] = None, + ): + self.origin = Origin(origin) if isinstance(origin, str) else origin + self.ssl = ssl + self.timeout = timeout + self.release_func = release_func + self.h11_connection = None # type: typing.Optional[HTTP11Connection] + self.h2_connection = None # type: typing.Optional[HTTP2Connection] + + def prepare_request(self, request: Request) -> None: + request.prepare() + + async def send(self, request: Request, **options: typing.Any) -> Response: + if self.h11_connection is None and self.h2_connection is None: + await self.connect(**options) + + if self.h2_connection is not None: + response = await self.h2_connection.send(request, **options) + else: + assert self.h11_connection is not None + response = await self.h11_connection.send(request, **options) + + return response + + async def connect(self, **options: typing.Any) -> None: + ssl = options.get("ssl", self.ssl) + timeout = options.get("timeout", self.timeout) + assert isinstance(ssl, SSLConfig) + assert isinstance(timeout, TimeoutConfig) + + hostname = self.origin.hostname + port = self.origin.port + ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None + + if self.release_func is None: + on_release = None + else: + on_release = functools.partial(self.release_func, self) + + reader, writer, protocol = await connect(hostname, port, ssl_context, timeout) + if protocol == Protocol.HTTP_2: + self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release) + else: + self.h11_connection = HTTP11Connection( + reader, writer, on_release=on_release + ) + + async def close(self) -> None: + if self.h2_connection is not None: + await self.h2_connection.close() + elif self.h11_connection is not None: + await self.h11_connection.close() + + @property + def is_http2(self) -> bool: + return self.h2_connection is not None + + @property + def is_closed(self) -> bool: + if self.h2_connection is not None: + return self.h2_connection.is_closed + else: + assert self.h11_connection is not None + return self.h11_connection.is_closed diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py new file mode 100644 index 00000000..ca7a837f --- /dev/null +++ b/httpcore/dispatch/connection_pool.py @@ -0,0 +1,158 @@ +import collections.abc +import typing + +from ..config import ( + DEFAULT_CA_BUNDLE_PATH, + DEFAULT_POOL_LIMITS, + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + PoolLimits, + SSLConfig, + TimeoutConfig, +) +from ..decoders import ACCEPT_ENCODING +from ..exceptions import PoolTimeout +from ..interfaces import Adapter +from ..models import Origin, Request, Response +from ..streams import PoolSemaphore +from .connection import HTTPConnection + +CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] + + +class ConnectionStore(collections.abc.Sequence): + """ + We need to maintain collections of connections in a way that allows us to: + + * Lookup connections by origin. + * Iterate over connections by insertion time. + * Return the total number of connections. + """ + + def __init__(self) -> None: + self.all = {} # type: typing.Dict[HTTPConnection, float] + self.by_origin = ( + {} + ) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]] + + def pop_by_origin( + self, origin: Origin, http2_only: bool = False + ) -> typing.Optional[HTTPConnection]: + try: + connections = self.by_origin[origin] + except KeyError: + return None + + connection = next(reversed(list(connections.keys()))) + if http2_only and not connection.is_http2: + return None + + del connections[connection] + if not connections: + del self.by_origin[origin] + del self.all[connection] + + return connection + + def add(self, connection: HTTPConnection) -> None: + self.all[connection] = 0.0 + try: + self.by_origin[connection.origin][connection] = 0.0 + except KeyError: + self.by_origin[connection.origin] = {connection: 0.0} + + def remove(self, connection: HTTPConnection) -> None: + del self.all[connection] + del self.by_origin[connection.origin][connection] + if not self.by_origin[connection.origin]: + del self.by_origin[connection.origin] + + def clear(self) -> None: + self.all.clear() + self.by_origin.clear() + + def __iter__(self) -> typing.Iterator[HTTPConnection]: + return iter(self.all.keys()) + + def __getitem__(self, key: typing.Any) -> typing.Any: + if key in self.all: + return key + return None + + def __len__(self) -> int: + return len(self.all) + + +class ConnectionPool(Adapter): + def __init__( + self, + *, + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + limits: PoolLimits = DEFAULT_POOL_LIMITS, + ): + self.ssl = ssl + self.timeout = timeout + self.limits = limits + self.is_closed = False + + self.max_connections = PoolSemaphore(limits) + self.keepalive_connections = ConnectionStore() + self.active_connections = ConnectionStore() + + @property + def num_connections(self) -> int: + return len(self.keepalive_connections) + len(self.active_connections) + + def prepare_request(self, request: Request) -> None: + request.prepare() + + async def send(self, request: Request, **options: typing.Any) -> Response: + connection = await self.acquire_connection(request.url.origin) + try: + response = await connection.send(request, **options) + except BaseException as exc: + self.active_connections.remove(connection) + self.max_connections.release() + raise exc + return response + + async def acquire_connection(self, origin: Origin) -> HTTPConnection: + connection = self.active_connections.pop_by_origin(origin, http2_only=True) + if connection is None: + connection = self.keepalive_connections.pop_by_origin(origin) + + if connection is None: + await self.max_connections.acquire() + connection = HTTPConnection( + origin, + ssl=self.ssl, + timeout=self.timeout, + release_func=self.release_connection, + ) + + self.active_connections.add(connection) + + return connection + + async def release_connection(self, connection: HTTPConnection) -> None: + if connection.is_closed: + self.active_connections.remove(connection) + self.max_connections.release() + elif ( + self.limits.soft_limit is not None + and self.num_connections > self.limits.soft_limit + ): + self.active_connections.remove(connection) + self.max_connections.release() + await connection.close() + else: + self.active_connections.remove(connection) + self.keepalive_connections.add(connection) + + async def close(self) -> None: + self.is_closed = True + connections = list(self.keepalive_connections) + self.keepalive_connections.clear() + for connection in connections: + await connection.close() diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py new file mode 100644 index 00000000..8128dc18 --- /dev/null +++ b/httpcore/dispatch/http11.py @@ -0,0 +1,148 @@ +import typing + +import h11 + +from ..config import ( + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + SSLConfig, + TimeoutConfig, +) +from ..exceptions import ConnectTimeout, ReadTimeout +from ..interfaces import Adapter +from ..models import Request, Response +from ..streams import BaseReader, BaseWriter + +H11Event = typing.Union[ + h11.Request, + h11.Response, + h11.InformationalResponse, + h11.Data, + h11.EndOfMessage, + h11.ConnectionClosed, +] + + +OptionalTimeout = typing.Optional[TimeoutConfig] + +# Callback signature: async def callback() -> None +# In practice the callback will be a functools partial, which binds +# the `ConnectionPool.release_connection(conn: HTTPConnection)` method. +OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]] + + +class HTTP11Connection(Adapter): + READ_NUM_BYTES = 4096 + + def __init__( + self, + reader: BaseReader, + writer: BaseWriter, + on_release: typing.Optional[OnReleaseCallback] = None, + ): + self.reader = reader + self.writer = writer + self.on_release = on_release + self.h11_state = h11.Connection(our_role=h11.CLIENT) + + def prepare_request(self, request: Request) -> None: + request.prepare() + + async def send(self, request: Request, **options: typing.Any) -> Response: + timeout = options.get("timeout") + stream = options.get("stream", False) + assert timeout is None or isinstance(timeout, TimeoutConfig) + + #  Start sending the request. + method = request.method.encode() + target = request.url.full_path + headers = request.headers.raw + event = h11.Request(method=method, target=target, headers=headers) + await self._send_event(event, timeout) + + # Send the request body. + async for data in request.stream(): + event = h11.Data(data=data) + await self._send_event(event, timeout) + + # Finalize sending the request. + event = h11.EndOfMessage() + await self._send_event(event, timeout) + + # Start getting the response. + event = await self._receive_event(timeout) + if isinstance(event, h11.InformationalResponse): + event = await self._receive_event(timeout) + + assert isinstance(event, h11.Response) + reason = event.reason.decode("latin1") + status_code = event.status_code + headers = event.headers + body = self._body_iter(timeout) + + response = Response( + status_code=status_code, + reason=reason, + protocol="HTTP/1.1", + headers=headers, + body=body, + on_close=self.response_closed, + request=request, + ) + + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + async def close(self) -> None: + event = h11.ConnectionClosed() + try: + # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. + self.h11_state.send(event) + except h11.ProtocolError: + # If we're in some other state then it's a premature close, + # and we'll end up in h11.ERROR. + pass + + await self.writer.close() + + async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]: + event = await self._receive_event(timeout) + while isinstance(event, h11.Data): + yield event.data + event = await self._receive_event(timeout) + assert isinstance(event, h11.EndOfMessage) + + async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None: + data = self.h11_state.send(event) + await self.writer.write(data, timeout) + + async def _receive_event(self, timeout: OptionalTimeout) -> H11Event: + event = self.h11_state.next_event() + + while event is h11.NEED_DATA: + data = await self.reader.read(self.READ_NUM_BYTES, timeout) + self.h11_state.receive_data(data) + event = self.h11_state.next_event() + + return event + + async def response_closed(self) -> None: + if ( + self.h11_state.our_state is h11.DONE + and self.h11_state.their_state is h11.DONE + ): + self.h11_state.start_next_cycle() + else: + await self.close() + + if self.on_release is not None: + await self.on_release() + + @property + def is_closed(self) -> bool: + return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py new file mode 100644 index 00000000..d44c5f60 --- /dev/null +++ b/httpcore/dispatch/http2.py @@ -0,0 +1,156 @@ +import functools +import typing + +import h2.connection +import h2.events + +from ..config import ( + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + SSLConfig, + TimeoutConfig, +) +from ..exceptions import ConnectTimeout, ReadTimeout +from ..interfaces import Adapter +from ..models import Request, Response +from ..streams import BaseReader, BaseWriter + +OptionalTimeout = typing.Optional[TimeoutConfig] + + +class HTTP2Connection(Adapter): + READ_NUM_BYTES = 4096 + + def __init__( + self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None + ): + self.reader = reader + self.writer = writer + self.on_release = on_release + self.h2_state = h2.connection.H2Connection() + self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] + self.initialized = False + + def prepare_request(self, request: Request) -> None: + request.prepare() + + async def send(self, request: Request, **options: typing.Any) -> Response: + timeout = options.get("timeout") + stream = options.get("stream", False) + assert timeout is None or isinstance(timeout, TimeoutConfig) + + #  Start sending the request. + if not self.initialized: + self.initiate_connection() + stream_id = await self.send_headers(request, timeout) + self.events[stream_id] = [] + + # Send the request body. + async for data in request.stream(): + await self.send_data(stream_id, data, timeout) + + # Finalize sending the request. + await self.end_stream(stream_id, timeout) + + # Start getting the response. + while True: + event = await self.receive_event(stream_id, timeout) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode()) + elif not k.startswith(b":"): + headers.append((k, v)) + + body = self.body_iter(stream_id, timeout) + on_close = functools.partial(self.response_closed, stream_id=stream_id) + + response = Response( + status_code=status_code, + protocol="HTTP/2", + headers=headers, + body=body, + on_close=on_close, + request=request, + ) + + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + async def close(self) -> None: + await self.writer.close() + + def initiate_connection(self) -> None: + self.h2_state.initiate_connection() + data_to_send = self.h2_state.data_to_send() + self.writer.write_no_block(data_to_send) + self.initialized = True + + async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int: + stream_id = self.h2_state.get_next_available_stream_id() + headers = [ + (b":method", request.method.encode()), + (b":authority", request.url.hostname.encode()), + (b":scheme", request.url.scheme.encode()), + (b":path", request.url.full_path.encode()), + ] + request.headers.raw + self.h2_state.send_headers(stream_id, headers) + data_to_send = self.h2_state.data_to_send() + await self.writer.write(data_to_send, timeout) + return stream_id + + async def send_data( + self, stream_id: int, data: bytes, timeout: OptionalTimeout + ) -> None: + self.h2_state.send_data(stream_id, data) + data_to_send = self.h2_state.data_to_send() + await self.writer.write(data_to_send, timeout) + + async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None: + self.h2_state.end_stream(stream_id) + data_to_send = self.h2_state.data_to_send() + await self.writer.write(data_to_send, timeout) + + async def body_iter( + self, stream_id: int, timeout: OptionalTimeout + ) -> typing.AsyncIterator[bytes]: + while True: + event = await self.receive_event(stream_id, timeout) + if isinstance(event, h2.events.DataReceived): + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + async def receive_event( + self, stream_id: int, timeout: OptionalTimeout + ) -> h2.events.Event: + while not self.events[stream_id]: + data = await self.reader.read(self.READ_NUM_BYTES, timeout) + events = self.h2_state.receive_data(data) + for event in events: + if getattr(event, "stream_id", 0): + self.events[event.stream_id].append(event) + + data_to_send = self.h2_state.data_to_send() + await self.writer.write(data_to_send, timeout) + + return self.events[stream_id].pop(0) + + async def response_closed(self, stream_id: int) -> None: + del self.events[stream_id] + + if not self.events and self.on_release is not None: + await self.on_release() + + @property + def is_closed(self) -> bool: + return False diff --git a/httpcore/exceptions.py b/httpcore/exceptions.py index 30814332..cf8d3f18 100644 --- a/httpcore/exceptions.py +++ b/httpcore/exceptions.py @@ -16,12 +16,43 @@ class ReadTimeout(Timeout): """ +class WriteTimeout(Timeout): + """ + Timeout while writing request data. + """ + + class PoolTimeout(Timeout): """ Timeout while waiting to acquire a connection from the pool. """ +class RedirectError(Exception): + """ + Base class for HTTP redirect errors. + """ + + +class TooManyRedirects(RedirectError): + """ + Too many redirects. + """ + + +class RedirectBodyUnavailable(RedirectError): + """ + Got a redirect response, but the request body was streaming, and is + no longer available. + """ + + +class RedirectLoop(RedirectError): + """ + Infinite redirect loop. + """ + + class ProtocolError(Exception): """ Malformed HTTP. @@ -40,3 +71,8 @@ class ResponseClosed(Exception): Attempted to read or stream response content, but the request has been closed without loading the body. """ + + +class InvalidURL(Exception): + """ + """ diff --git a/httpcore/http11.py b/httpcore/http11.py deleted file mode 100644 index 23cc27ce..00000000 --- a/httpcore/http11.py +++ /dev/null @@ -1,163 +0,0 @@ -import asyncio -import typing - -import h11 - -from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig -from .datastructures import Client, Origin, Request, Response -from .exceptions import ConnectTimeout, ReadTimeout - -H11Event = typing.Union[ - h11.Request, - h11.Response, - h11.InformationalResponse, - h11.Data, - h11.EndOfMessage, - h11.ConnectionClosed, -] - - -class HTTP11Connection(Client): - def __init__( - self, - origin: typing.Union[str, Origin], - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - on_release: typing.Callable = None, - ): - self.origin = Origin(origin) if isinstance(origin, str) else origin - self.ssl = ssl - self.timeout = timeout - self.on_release = on_release - self._reader = None - self._writer = None - self._h11_state = h11.Connection(our_role=h11.CLIENT) - - @property - def is_closed(self) -> bool: - return self._h11_state.our_state in (h11.CLOSED, h11.ERROR) - - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - assert request.url.origin == self.origin - - if ssl is None: - ssl = self.ssl - if timeout is None: - timeout = self.timeout - - # Make the connection - if self._reader is None: - await self._connect(ssl, timeout) - - #  Start sending the request. - method = request.method.encode() - target = request.url.target - headers = request.headers - event = h11.Request(method=method, target=target, headers=headers) - await self._send_event(event) - - # Send the request body. - if request.is_streaming: - async for data in request.stream(): - event = h11.Data(data=data) - await self._send_event(event) - elif request.body: - event = h11.Data(data=request.body) - await self._send_event(event) - - # Finalize sending the request. - event = h11.EndOfMessage() - await self._send_event(event) - - # Start getting the response. - event = await self._receive_event(timeout) - if isinstance(event, h11.InformationalResponse): - event = await self._receive_event(timeout) - assert isinstance(event, h11.Response) - reason = event.reason.decode("latin1") - status_code = event.status_code - headers = event.headers - body = self._body_iter(timeout) - return Response( - status_code=status_code, - reason=reason, - headers=headers, - body=body, - on_close=self._release, - ) - - async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None: - hostname = self.origin.hostname - port = self.origin.port - ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None - - try: - self._reader, self._writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, - ) - except asyncio.TimeoutError: - raise ConnectTimeout() - - async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]: - event = await self._receive_event(timeout) - while isinstance(event, h11.Data): - yield event.data - event = await self._receive_event(timeout) - assert isinstance(event, h11.EndOfMessage) - - async def _send_event(self, event: H11Event) -> None: - assert self._writer is not None - - data = self._h11_state.send(event) - self._writer.write(data) - - async def _receive_event(self, timeout: TimeoutConfig) -> H11Event: - assert self._reader is not None - - event = self._h11_state.next_event() - - while event is h11.NEED_DATA: - try: - data = await asyncio.wait_for( - self._reader.read(2048), timeout.read_timeout - ) - except asyncio.TimeoutError: - raise ReadTimeout() - self._h11_state.receive_data(data) - event = self._h11_state.next_event() - - return event - - async def _release(self) -> None: - assert self._writer is not None - - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._h11_state.start_next_cycle() - else: - await self.close() - - if self.on_release is not None: - await self.on_release(self) - - async def close(self) -> None: - event = h11.ConnectionClosed() - try: - # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. - self._h11_state.send(event) - except h11.ProtocolError: - # If we're in some other state then it's a premature close, - # and we'll end up in h11.ERROR. - pass - - if self._writer is not None: - self._writer.close() diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py new file mode 100644 index 00000000..d3cdca47 --- /dev/null +++ b/httpcore/interfaces.py @@ -0,0 +1,67 @@ +import typing +from types import TracebackType + +from .config import TimeoutConfig +from .models import URL, Request, Response + +OptionalTimeout = typing.Optional[TimeoutConfig] + + +class Adapter: + async def request( + self, + method: str, + url: typing.Union[str, URL], + *, + headers: typing.List[typing.Tuple[bytes, bytes]] = [], + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + **options: typing.Any, + ) -> Response: + request = Request(method, url, headers=headers, body=body) + self.prepare_request(request) + response = await self.send(request, **options) + return response + + def prepare_request(self, request: Request) -> None: + raise NotImplementedError() # pragma: nocover + + async def send(self, request: Request, **options: typing.Any) -> Response: + raise NotImplementedError() # pragma: nocover + + async def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def __aenter__(self) -> "Adapter": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() + + +class BaseReader: + async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + raise NotImplementedError() # pragma: no cover + + +class BaseWriter: + def write_no_block(self, data: bytes) -> None: + raise NotImplementedError() # pragma: no cover + + async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + raise NotImplementedError() # pragma: no cover + + async def close(self) -> None: + raise NotImplementedError() # pragma: no cover + + +class BasePoolSemaphore: + async def acquire(self) -> None: + raise NotImplementedError() # pragma: no cover + + def release(self) -> None: + raise NotImplementedError() # pragma: no cover diff --git a/httpcore/models.py b/httpcore/models.py new file mode 100644 index 00000000..43f799a6 --- /dev/null +++ b/httpcore/models.py @@ -0,0 +1,400 @@ +import http +import typing +from urllib.parse import urlsplit + +from .config import SSLConfig, TimeoutConfig +from .decoders import ( + ACCEPT_ENCODING, + SUPPORTED_DECODERS, + Decoder, + IdentityDecoder, + MultiDecoder, +) +from .exceptions import ResponseClosed, StreamConsumed +from .utils import normalize_header_key, normalize_header_value + + +class URL: + def __init__(self, url: str = "") -> None: + self.components = urlsplit(url) + if not self.components.scheme: + raise ValueError("No scheme included in URL.") + if self.components.scheme not in ("http", "https"): + raise ValueError('URL scheme must be "http" or "https".') + if not self.components.hostname: + raise ValueError("No hostname included in URL.") + + @property + def scheme(self) -> str: + return self.components.scheme + + @property + def netloc(self) -> str: + return self.components.netloc + + @property + def path(self) -> str: + return self.components.path + + @property + def query(self) -> str: + return self.components.query + + @property + def fragment(self) -> str: + return self.components.fragment + + @property + def hostname(self) -> str: + return self.components.hostname + + @property + def port(self) -> int: + port = self.components.port + if port is None: + return {"https": 443, "http": 80}[self.scheme] + return port + + @property + def full_path(self) -> str: + path = self.path or "/" + query = self.query + if query: + return path + "?" + query + return path + + @property + def is_secure(self) -> bool: + return self.components.scheme == "https" + + @property + def origin(self) -> "Origin": + return Origin(self) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, URL) and str(self) == str(other) + + def __str__(self) -> str: + return self.components.geturl() + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url_str = str(self) + return f"{class_name}({url_str!r})" + + +class Origin: + def __init__(self, url: typing.Union[str, URL]) -> None: + if isinstance(url, str): + url = URL(url) + self.is_ssl = url.scheme == "https" + self.hostname = url.hostname.lower() + self.port = url.port + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.is_ssl == other.is_ssl + and self.hostname == other.hostname + and self.port == other.port + ) + + def __hash__(self) -> int: + return hash((self.is_ssl, self.hostname, self.port)) + + +HeaderTypes = typing.Union[ + "Headers", + typing.Dict[typing.AnyStr, typing.AnyStr], + typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]], +] + + +class Headers(typing.MutableMapping[str, str]): + """ + A case-insensitive multidict. + """ + + def __init__(self, headers: HeaderTypes = None) -> None: + if headers is None: + self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif isinstance(headers, Headers): + self._list = list(headers.raw) + elif isinstance(headers, dict): + self._list = [ + (normalize_header_key(k), normalize_header_value(v)) + for k, v in headers.items() + ] + else: + self._list = [ + (normalize_header_key(k), normalize_header_value(v)) for k, v in headers + ] + + @property + def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: + return self._list + + def keys(self) -> typing.List[str]: # type: ignore + return [key.decode("latin-1") for key, value in self._list] + + def values(self) -> typing.List[str]: # type: ignore + return [value.decode("latin-1") for key, value in self._list] + + def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore + return [ + (key.decode("latin-1"), value.decode("latin-1")) + for key, value in self._list + ] + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + try: + return self[key] + except KeyError: + return default + + def getlist(self, key: str) -> typing.List[str]: + get_header_key = key.lower().encode("latin-1") + return [ + item_value.decode("latin-1") + for item_key, item_value in self._list + if item_key == get_header_key + ] + + def __getitem__(self, key: str) -> str: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return header_value.decode("latin-1") + raise KeyError(key) + + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + found_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + found_indexes.append(idx) + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, set_value) + else: + self._list.append((set_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") + + pop_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __contains__(self, key: typing.Any) -> bool: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, Headers): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + as_dict = dict(self.items()) + if len(as_dict) == len(self): + return f"{class_name}({as_dict!r})" + return f"{class_name}(raw={self.raw!r})" + + +class Request: + def __init__( + self, + method: str, + url: typing.Union[str, URL], + *, + headers: HeaderTypes = None, + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + ): + self.method = method.upper() + self.url = URL(url) if isinstance(url, str) else url + if isinstance(body, bytes): + self.is_streaming = False + self.body = body + else: + self.is_streaming = True + self.body_aiter = body + self.headers = Headers(headers) + + async def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "body"): + body = b"" + async for part in self.stream(): + body += part + self.body = body + return self.body + + async def stream(self) -> typing.AsyncIterator[bytes]: + if self.is_streaming: + async for part in self.body_aiter: + yield part + elif self.body: + yield self.body + + def prepare(self) -> None: + auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + + has_host = "host" in self.headers + has_content_length = ( + "content-length" in self.headers or "transfer-encoding" in self.headers + ) + has_accept_encoding = "accept-encoding" in self.headers + + if not has_host: + auto_headers.append((b"host", self.url.netloc.encode("ascii"))) + if not has_content_length: + if self.is_streaming: + auto_headers.append((b"transfer-encoding", b"chunked")) + elif self.body: + content_length = str(len(self.body)).encode() + auto_headers.append((b"content-length", content_length)) + if not has_accept_encoding: + auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) + + for item in reversed(auto_headers): + self.headers.raw.insert(0, item) + + +class Response: + def __init__( + self, + status_code: int, + *, + reason: typing.Optional[str] = None, + protocol: typing.Optional[str] = None, + headers: typing.List[typing.Tuple[bytes, bytes]] = [], + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + on_close: typing.Callable = None, + request: Request = None, + history: typing.List["Response"] = None, + ): + self.status_code = status_code + if not reason: + try: + self.reason = http.HTTPStatus(status_code).phrase + except ValueError as exc: + self.reason = "" + else: + self.reason = reason + self.protocol = protocol + self.headers = Headers(headers) + self.on_close = on_close + self.is_closed = False + self.is_streamed = False + + decoders = [] # type: typing.List[Decoder] + value = self.headers.get("content-encoding", "identity") + for part in value.split(","): + part = part.strip().lower() + decoder_cls = SUPPORTED_DECODERS[part] + decoders.append(decoder_cls()) + + if len(decoders) == 0: + self.decoder = IdentityDecoder() # type: Decoder + elif len(decoders) == 1: + self.decoder = decoders[0] + else: + self.decoder = MultiDecoder(decoders) + + if isinstance(body, bytes): + self.is_closed = True + self.body = self.decoder.decode(body) + self.decoder.flush() + else: + self.body_aiter = body + + self.request = request + self.history = [] if history is None else list(history) + + @property + def url(self) -> typing.Optional[URL]: + return None if self.request is None else self.request.url + + async def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "body"): + body = b"" + async for part in self.stream(): + body += part + self.body = body + return self.body + + async def stream(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "body"): + yield self.body + else: + async for chunk in self.raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() + + async def raw(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_streamed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() + self.is_streamed = True + async for part in self.body_aiter: + yield part + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + if self.on_close is not None: + await self.on_close() + + @property + def is_redirect(self) -> bool: + return ( + self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers + ) diff --git a/httpcore/status_codes.py b/httpcore/status_codes.py new file mode 100644 index 00000000..6a224d04 --- /dev/null +++ b/httpcore/status_codes.py @@ -0,0 +1,61 @@ +import enum + +codes = enum.IntEnum( + "StatusCode", + [ + ("continue", 100), + ("switching_protocols", 101), + ("processing", 102), + ("ok", 200), + ("created", 201), + ("accepted", 202), + ("non_authoritative_information", 203), + ("no_content", 204), + ("reset_content", 205), + ("partial_content", 206), + ("multi_status", 207), + ("already_reported", 208), + ("im_used", 226), + ("multiple_choices", 300), + ("moved_permanently", 301), + ("found", 302), + ("see_other", 303), + ("not_modified", 304), + ("use_proxy", 305), + ("temporary_redirect", 307), + ("permanent_redirect", 308), + ("bad_request", 400), + ("unauthorized", 401), + ("payment_required", 402), + ("forbidden", 403), + ("not_found", 404), + ("method_not_allowed", 405), + ("not_acceptable", 406), + ("proxy_authentication_required", 407), + ("request_timeout", 408), + ("conflict", 409), + ("gone", 410), + ("length_required", 411), + ("precondition_failed", 412), + ("request_entity_too_large", 413), + ("request_uri_too_long", 414), + ("unsupported_media_type", 415), + ("requested_range_not_satisfiable", 416), + ("expectation_failed", 417), + ("unprocessable_entity", 422), + ("locked", 423), + ("failed_dependency", 424), + ("precondition_required", 428), + ("too_many_requests", 429), + ("request_header_fields_too_large", 431), + ("unavailable_for_legal_reasons", 451), + ("internal_server_error", 500), + ("not_implemented", 501), + ("bad_gateway", 502), + ("service_unavailable", 503), + ("gateway_timeout", 504), + ("http_version_not_supported", 505), + ("insufficient_storage", 507), + ("network_authentication_required", 511), + ], +) diff --git a/httpcore/streams.py b/httpcore/streams.py new file mode 100644 index 00000000..a8590db4 --- /dev/null +++ b/httpcore/streams.py @@ -0,0 +1,133 @@ +""" +The `Reader` and `Writer` classes here provide a lightweight layer over +`asyncio.StreamReader` and `asyncio.StreamWriter`. + +Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`. + +These classes help encapsulate the timeout logic, make it easier to unit-test +protocols, and help keep the rest of the package more `async`/`await` +based, and less strictly `asyncio`-specific. +""" +import asyncio +import enum +import ssl +import typing + +from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig +from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter + +OptionalTimeout = typing.Optional[TimeoutConfig] + + +class Protocol(enum.Enum): + HTTP_11 = 1 + HTTP_2 = 2 + + +class Reader(BaseReader): + def __init__( + self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig + ) -> None: + self.stream_reader = stream_reader + self.timeout = timeout + + async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + if timeout is None: + timeout = self.timeout + + try: + data = await asyncio.wait_for( + self.stream_reader.read(n), timeout.read_timeout + ) + except asyncio.TimeoutError: + raise ReadTimeout() + + return data + + +class Writer(BaseWriter): + def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): + self.stream_writer = stream_writer + self.timeout = timeout + + def write_no_block(self, data: bytes) -> None: + self.stream_writer.write(data) + + async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + if not data: + return + + if timeout is None: + timeout = self.timeout + + self.stream_writer.write(data) + try: + data = await asyncio.wait_for( # type: ignore + self.stream_writer.drain(), timeout.write_timeout + ) + except asyncio.TimeoutError: + raise WriteTimeout() + + async def close(self) -> None: + self.stream_writer.close() + + +class PoolSemaphore(BasePoolSemaphore): + def __init__(self, limits: PoolLimits): + self.limits = limits + + @property + def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: + if not hasattr(self, "_semaphore"): + max_connections = self.limits.hard_limit + if max_connections is None: + self._semaphore = None + else: + self._semaphore = asyncio.BoundedSemaphore(value=max_connections) + return self._semaphore + + async def acquire(self) -> None: + if self.semaphore is None: + return + + timeout = self.limits.pool_timeout + try: + await asyncio.wait_for(self.semaphore.acquire(), timeout) + except asyncio.TimeoutError: + raise PoolTimeout() + + def release(self) -> None: + if self.semaphore is None: + return + + self.semaphore.release() + + +async def connect( + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext] = None, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, +) -> typing.Tuple[Reader, Writer, Protocol]: + try: + stream_reader, stream_writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection(hostname, port, ssl=ssl_context), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + ssl_object = stream_writer.get_extra_info("ssl_object") + if ssl_object is None: + ident = "http/1.1" + else: + ident = ssl_object.selected_alpn_protocol() + if ident is None: + ident = ssl_object.selected_npn_protocol() + + reader = Reader(stream_reader=stream_reader, timeout=timeout) + writer = Writer(stream_writer=stream_writer, timeout=timeout) + protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 + + return (reader, writer, protocol) diff --git a/httpcore/sync.py b/httpcore/sync.py index 1d560faf..2d58f9a1 100644 --- a/httpcore/sync.py +++ b/httpcore/sync.py @@ -3,8 +3,9 @@ import typing from types import TracebackType from .config import SSLConfig, TimeoutConfig -from .connectionpool import ConnectionPool -from .datastructures import URL, Client, Response +from .dispatch.connection_pool import ConnectionPool +from .interfaces import Adapter +from .models import URL, Headers, Response class SyncResponse: @@ -21,7 +22,7 @@ class SyncResponse: return self._response.reason @property - def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]: + def headers(self) -> Headers: return self._response.headers @property @@ -44,8 +45,8 @@ class SyncResponse: class SyncClient: - def __init__(self, client: Client): - self._client = client + def __init__(self, adapter: Adapter): + self._client = adapter self._loop = asyncio.new_event_loop() def request( @@ -53,22 +54,12 @@ class SyncClient: method: str, url: typing.Union[str, URL], *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + headers: typing.List[typing.Tuple[bytes, bytes]] = [], body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - stream: bool = False, + **options: typing.Any ) -> SyncResponse: response = self._loop.run_until_complete( - self._client.request( - method, - url, - headers=headers, - body=body, - ssl=ssl, - timeout=timeout, - stream=stream, - ) + self._client.request(method, url, headers=headers, body=body, **options) ) return SyncResponse(response, self._loop) diff --git a/httpcore/utils.py b/httpcore/utils.py new file mode 100644 index 00000000..419e7ec2 --- /dev/null +++ b/httpcore/utils.py @@ -0,0 +1,71 @@ +import typing +from urllib.parse import quote + +from .exceptions import InvalidURL + +# The unreserved URI characters (RFC 3986) +UNRESERVED_SET = frozenset( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~" +) + + +def unquote_unreserved(uri: str) -> str: + """ + Un-escape any percent-escape sequences in a URI that are unreserved + characters. This leaves all reserved, illegal and non-ASCII bytes encoded. + """ + parts = uri.split("%") + for i in range(1, len(parts)): + h = parts[i][0:2] + if len(h) == 2 and h.isalnum(): + try: + c = chr(int(h, 16)) + except ValueError: + raise InvalidURL("Invalid percent-escape sequence: '%s'" % h) + + if c in UNRESERVED_SET: + parts[i] = c + parts[i][2:] + else: + parts[i] = "%" + parts[i] + else: + parts[i] = "%" + parts[i] + return "".join(parts) + + +def requote_uri(uri: str) -> str: + """ + Re-quote the given URI. + + This function passes the given URI through an unquote/quote cycle to + ensure that it is fully and consistently quoted. + """ + safe_with_percent = "!#$%&'()*+,/:;=?@[]~" + safe_without_percent = "!#$&'()*+,/:;=?@[]~" + try: + # Unquote only the unreserved characters + # Then quote only illegal characters (do not quote reserved, + # unreserved, or '%') + return quote(unquote_unreserved(uri), safe=safe_with_percent) + except InvalidURL: + # We couldn't unquote the given URI, so let's try quoting it, but + # there may be unquoted '%'s in the URI. We need to make sure they're + # properly quoted so they do not cause issues elsewhere. + return quote(uri, safe=safe_without_percent) + + +def normalize_header_key(value: typing.AnyStr) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header key. + """ + if isinstance(value, bytes): + return value.lower() + return value.encode("latin-1").lower() + + +def normalize_header_value(value: typing.AnyStr) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header value. + """ + if isinstance(value, bytes): + return value + return value.encode("latin-1") diff --git a/requirements.txt b/requirements.txt index 5108a8d6..18f9c5fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ certifi h11 +h2 # Optional brotlipy diff --git a/setup.py b/setup.py index 6d709be2..93d02ad5 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( author_email="tom@tomchristie.com", packages=get_packages("httpcore"), data_files=[("", ["LICENSE.md"])], - install_requires=["h11", "certifi"], + install_requires=["h11", "h2", "certifi"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Web Environment", diff --git a/tests/adapters/test_redirects.py b/tests/adapters/test_redirects.py new file mode 100644 index 00000000..ce6d4d3f --- /dev/null +++ b/tests/adapters/test_redirects.py @@ -0,0 +1,219 @@ +import json +from urllib.parse import parse_qs + +import pytest + +from httpcore import ( + URL, + Adapter, + RedirectAdapter, + RedirectBodyUnavailable, + RedirectLoop, + Request, + Response, + TooManyRedirects, + codes, +) + + +class MockDispatch(Adapter): + def prepare_request(self, request: Request) -> None: + pass + + async def send(self, request: Request, **options) -> Response: + if request.url.path == "/redirect_301": + status_code = codes.moved_permanently + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) + + elif request.url.path == "/redirect_302": + status_code = codes.found + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) + + elif request.url.path == "/redirect_303": + status_code = codes.see_other + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) + + elif request.url.path == "/relative_redirect": + headers = {"location": "/"} + return Response(codes.see_other, headers=headers, request=request) + + elif request.url.path == "/no_scheme_redirect": + headers = {"location": "//example.org/"} + return Response(codes.see_other, headers=headers, request=request) + + elif request.url.path == "/multiple_redirects": + params = parse_qs(request.url.query) + count = int(params.get("count", "0")[0]) + redirect_count = count - 1 + code = codes.see_other if count else codes.ok + location = "/multiple_redirects" + if redirect_count: + location += "?count=" + str(redirect_count) + headers = {"location": location} if count else {} + return Response(code, headers=headers, request=request) + + if request.url.path == "/redirect_loop": + headers = {"location": "/redirect_loop"} + return Response(codes.see_other, headers=headers, request=request) + + elif request.url.path == "/cross_domain": + headers = {"location": "https://example.org/cross_domain_target"} + return Response(codes.see_other, headers=headers, request=request) + + elif request.url.path == "/cross_domain_target": + headers = dict(request.headers.items()) + body = json.dumps({"headers": headers}).encode() + return Response(codes.ok, body=body, request=request) + + elif request.url.path == "/redirect_body": + body = await request.read() + headers = {"location": "/redirect_body_target"} + return Response(codes.permanent_redirect, headers=headers, request=request) + + elif request.url.path == "/redirect_body_target": + body = await request.read() + body = json.dumps({"body": body.decode()}).encode() + return Response(codes.ok, body=body, request=request) + + return Response(codes.ok, body=b"Hello, world!", request=request) + + +@pytest.mark.asyncio +async def test_redirect_301(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("POST", "https://example.org/redirect_301") + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_redirect_302(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("POST", "https://example.org/redirect_302") + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_redirect_303(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("GET", "https://example.org/redirect_303") + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_disallow_redirects(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False) + assert response.status_code == codes.see_other + assert response.url == URL("https://example.org/redirect_303") + assert len(response.history) == 0 + + response = await response.next() + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_relative_redirect(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("GET", "https://example.org/relative_redirect") + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_no_scheme_redirect(): + client = RedirectAdapter(MockDispatch()) + response = await client.request("GET", "https://example.org/no_scheme_redirect") + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_fragment_redirect(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/relative_redirect#fragment" + response = await client.request("GET", url) + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/#fragment") + assert len(response.history) == 1 + + +@pytest.mark.asyncio +async def test_multiple_redirects(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/multiple_redirects?count=20" + response = await client.request("GET", url) + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/multiple_redirects") + assert len(response.history) == 20 + + +@pytest.mark.asyncio +async def test_too_many_redirects(): + client = RedirectAdapter(MockDispatch()) + with pytest.raises(TooManyRedirects): + await client.request("GET", "https://example.org/multiple_redirects?count=21") + + +@pytest.mark.asyncio +async def test_redirect_loop(): + client = RedirectAdapter(MockDispatch()) + with pytest.raises(RedirectLoop): + await client.request("GET", "https://example.org/redirect_loop") + + +@pytest.mark.asyncio +async def test_cross_domain_redirect(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.com/cross_domain" + headers = {"Authorization": "abc"} + response = await client.request("GET", url, headers=headers) + data = json.loads(response.body.decode()) + assert response.url == URL("https://example.org/cross_domain_target") + assert data == {"headers": {}} + + +@pytest.mark.asyncio +async def test_same_domain_redirect(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/cross_domain" + headers = {"Authorization": "abc"} + response = await client.request("GET", url, headers=headers) + data = json.loads(response.body.decode()) + assert response.url == URL("https://example.org/cross_domain_target") + assert data == {"headers": {"authorization": "abc"}} + + +@pytest.mark.asyncio +async def test_body_redirect(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/redirect_body" + body = b"Example request body" + response = await client.request("POST", url, body=body) + data = json.loads(response.body.decode()) + assert response.url == URL("https://example.org/redirect_body_target") + assert data == {"body": "Example request body"} + + +@pytest.mark.asyncio +async def test_cannot_redirect_streaming_body(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/redirect_body" + + async def body(): + yield b"Example request body" + + with pytest.raises(RedirectBodyUnavailable): + await client.request("POST", url, body=body()) diff --git a/tests/test_api.py b/tests/test_api.py index 6b80587d..4622849b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,16 +5,16 @@ import httpcore @pytest.mark.asyncio async def test_get(server): - async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/") + async with httpcore.Client() as client: + response = await client.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.body == b"Hello, world!" @pytest.mark.asyncio async def test_post(server): - async with httpcore.ConnectionPool() as http: - response = await http.request( + async with httpcore.Client() as client: + response = await client.request( "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" ) assert response.status_code == 200 @@ -22,8 +22,8 @@ async def test_post(server): @pytest.mark.asyncio async def test_stream_response(server): - async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + async with httpcore.Client() as client: + response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 assert not hasattr(response, "body") body = await response.read() @@ -36,8 +36,8 @@ async def test_stream_request(server): yield b"Hello, " yield b"world!" - async with httpcore.ConnectionPool() as http: - response = await http.request( + async with httpcore.Client() as client: + response = await client.request( "POST", "http://127.0.0.1:8000/", body=hello_world() ) assert response.status_code == 200 diff --git a/tests/test_config.py b/tests/test_config.py index daf0e1ec..e4ce64a4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,13 +13,15 @@ def test_timeout_repr(): timeout = httpcore.TimeoutConfig(read_timeout=5.0) assert ( repr(timeout) - == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)" + == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)" ) def test_limits_repr(): limits = httpcore.PoolLimits(hard_limit=100) - assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)" + assert ( + repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)" + ) def test_ssl_eq(): @@ -35,24 +37,3 @@ def test_timeout_eq(): def test_limits_eq(): limits = httpcore.PoolLimits(hard_limit=100) assert limits == httpcore.PoolLimits(hard_limit=100) - - -def test_ssl_hash(): - cache = {} - ssl = httpcore.SSLConfig(verify=False) - cache[ssl] = "example" - assert cache[httpcore.SSLConfig(verify=False)] == "example" - - -def test_timeout_hash(): - cache = {} - timeout = httpcore.TimeoutConfig(timeout=5.0) - cache[timeout] = "example" - assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example" - - -def test_limits_hash(): - cache = {} - limits = httpcore.PoolLimits(hard_limit=100) - cache[limits] = "example" - assert cache[httpcore.PoolLimits(hard_limit=100)] == "example" diff --git a/tests/test_connection_pools.py b/tests/test_connection_pools.py index 77a22157..7d478c5a 100644 --- a/tests/test_connection_pools.py +++ b/tests/test_connection_pools.py @@ -10,12 +10,12 @@ async def test_keepalive_connections(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -25,12 +25,12 @@ async def test_differing_connection_keys(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 2 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 2 @pytest.mark.asyncio @@ -42,12 +42,12 @@ async def test_soft_limit(server): async with httpcore.ConnectionPool(limits=limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 0 await response.read() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server): """ async with httpcore.ConnectionPool() as http: response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 0 response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) - assert http.num_active_connections == 2 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 2 + assert len(http.keepalive_connections) == 0 await response_b.read() - assert http.num_active_connections == 1 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 1 + assert len(http.keepalive_connections) == 1 await response_a.read() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 2 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 2 @pytest.mark.asyncio @@ -97,8 +97,8 @@ async def test_close_connections(server): headers = [(b"connection", b"close")] async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 0 @pytest.mark.asyncio @@ -110,8 +110,8 @@ async def test_standard_response_close(server): response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) await response.read() await response.close() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 1 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 @pytest.mark.asyncio @@ -122,5 +122,5 @@ async def test_premature_response_close(server): async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) await response.close() - assert http.num_active_connections == 0 - assert http.num_keepalive_connections == 0 + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 0 diff --git a/tests/test_connections.py b/tests/test_connections.py index f1590140..11031106 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -5,7 +5,7 @@ import httpcore @pytest.mark.asyncio async def test_get(server): - http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/") + http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/") response = await http.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.body == b"Hello, world!" @@ -13,7 +13,7 @@ async def test_get(server): @pytest.mark.asyncio async def test_post(server): - http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/") + http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/") response = await http.request( "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" ) diff --git a/tests/test_http2.py b/tests/test_http2.py new file mode 100644 index 00000000..dc287bce --- /dev/null +++ b/tests/test_http2.py @@ -0,0 +1,116 @@ +import json + +import h2.config +import h2.connection +import h2.events +import pytest + +import httpcore + + +class MockServer(httpcore.BaseReader, httpcore.BaseWriter): + """ + This class exposes Reader and Writer style interfaces + """ + + def __init__(self): + config = h2.config.H2Configuration(client_side=False) + self.conn = h2.connection.H2Connection(config=config) + self.buffer = b"" + self.requests = {} + + # BaseReader interface + + async def read(self, n, timeout) -> bytes: + send, self.buffer = self.buffer[:n], self.buffer[n:] + return send + + # BaseWriter interface + + def write_no_block(self, data: bytes) -> None: + events = self.conn.receive_data(data) + self.buffer += self.conn.data_to_send() + for event in events: + if isinstance(event, h2.events.RequestReceived): + self.request_received(event.headers, event.stream_id) + elif isinstance(event, h2.events.DataReceived): + self.receive_data(event.data, event.stream_id) + elif isinstance(event, h2.events.StreamEnded): + self.stream_complete(event.stream_id) + + async def write(self, data: bytes, timeout) -> None: + self.write_no_block(data) + + async def close(self) -> None: + pass + + # Server implementation + + def request_received(self, headers, stream_id): + if stream_id not in self.requests: + self.requests[stream_id] = [] + self.requests[stream_id].append({"headers": headers, "data": b""}) + + def receive_data(self, data, stream_id): + self.requests[stream_id][-1]["data"] += data + + def stream_complete(self, stream_id): + request = self.requests[stream_id].pop(0) + if not self.requests[stream_id]: + del self.requests[stream_id] + + request_headers = dict(request["headers"]) + request_data = request["data"] + + response_body = json.dumps( + { + "method": request_headers[b":method"].decode(), + "path": request_headers[b":path"].decode(), + "body": request_data.decode(), + } + ).encode() + + response_headers = ((b":status", b"200"),) + self.conn.send_headers(stream_id, response_headers) + self.conn.send_data(stream_id, response_body, end_stream=True) + self.buffer += self.conn.data_to_send() + + +@pytest.mark.asyncio +async def test_http2_get_request(): + server = MockServer() + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response = await conn.request("GET", "http://example.org") + assert response.status_code == 200 + assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""} + + +@pytest.mark.asyncio +async def test_http2_post_request(): + server = MockServer() + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response = await conn.request("POST", "http://example.org", body=b"") + assert response.status_code == 200 + assert json.loads(response.body) == { + "method": "POST", + "path": "/", + "body": "", + } + + +@pytest.mark.asyncio +async def test_http2_multiple_requests(): + server = MockServer() + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response_1 = await conn.request("GET", "http://example.org/1") + response_2 = await conn.request("GET", "http://example.org/2") + response_3 = await conn.request("GET", "http://example.org/3") + + assert response_1.status_code == 200 + assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""} + + assert response_2.status_code == 200 + assert json.loads(response_2.body) == {"method": "GET", "path": "/2", "body": ""} + + assert response_3.status_code == 200 + assert json.loads(response_3.body) == {"method": "GET", "path": "/3", "body": ""} diff --git a/tests/test_requests.py b/tests/test_requests.py index bdbf2caa..4df3529c 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -5,19 +5,22 @@ import httpcore def test_host_header(): request = httpcore.Request("GET", "http://example.org") - assert request.headers == [ - (b"host", b"example.org"), - (b"accept-encoding", b"deflate, gzip, br"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")] + ) def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", body=b"test 123") - assert request.headers == [ - (b"host", b"example.org"), - (b"content-length", b"8"), - (b"accept-encoding", b"deflate, gzip, br"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [ + (b"host", b"example.org"), + (b"content-length", b"8"), + (b"accept-encoding", b"deflate, gzip, br"), + ] + ) def test_transfer_encoding_header(): @@ -27,31 +30,34 @@ def test_transfer_encoding_header(): body = streaming_body(b"test 123") request = httpcore.Request("POST", "http://example.org", body=body) - assert request.headers == [ - (b"host", b"example.org"), - (b"transfer-encoding", b"chunked"), - (b"accept-encoding", b"deflate, gzip, br"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [ + (b"host", b"example.org"), + (b"transfer-encoding", b"chunked"), + (b"accept-encoding", b"deflate, gzip, br"), + ] + ) def test_override_host_header(): headers = [(b"host", b"1.2.3.4:80")] request = httpcore.Request("GET", "http://example.org", headers=headers) - assert request.headers == [ - (b"accept-encoding", b"deflate, gzip, br"), - (b"host", b"1.2.3.4:80"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")] + ) def test_override_accept_encoding_header(): headers = [(b"accept-encoding", b"identity")] request = httpcore.Request("GET", "http://example.org", headers=headers) - assert request.headers == [ - (b"host", b"example.org"), - (b"accept-encoding", b"identity"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [(b"host", b"example.org"), (b"accept-encoding", b"identity")] + ) def test_override_content_length_header(): @@ -62,23 +68,26 @@ def test_override_content_length_header(): headers = [(b"content-length", b"8")] request = httpcore.Request("POST", "http://example.org", body=body, headers=headers) - assert request.headers == [ - (b"host", b"example.org"), - (b"accept-encoding", b"deflate, gzip, br"), - (b"content-length", b"8"), - ] + request.prepare() + assert request.headers == httpcore.Headers( + [ + (b"host", b"example.org"), + (b"accept-encoding", b"deflate, gzip, br"), + (b"content-length", b"8"), + ] + ) def test_url(): request = httpcore.Request("GET", "http://example.org") assert request.url.scheme == "http" assert request.url.port == 80 - assert request.url.target == "/" + assert request.url.full_path == "/" request = httpcore.Request("GET", "https://example.org/abc?foo=bar") assert request.url.scheme == "https" assert request.url.port == 443 - assert request.url.target == "/abc?foo=bar" + assert request.url.full_path == "/abc?foo=bar" def test_invalid_urls(): diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index b1ceef93..d91cb799 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -24,10 +24,9 @@ async def test_connect_timeout(server): @pytest.mark.asyncio async def test_pool_timeout(server): - timeout = httpcore.TimeoutConfig(pool_timeout=0.0001) - limits = httpcore.PoolLimits(hard_limit=1) + limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001) - async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http: + async with httpcore.ConnectionPool(limits=limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) with pytest.raises(httpcore.PoolTimeout):