diff --git a/README.md b/README.md index c6b8d257..63eb03dd 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,8 @@ of it, and exposes only plain datastructures that reflect the network response. ```python import httpcore -response = await httpcore.request('GET', 'http://example.com') +http = httpcore.ConnectionPool() +response = await http.request('GET', 'http://example.com') assert response.status_code == 200 assert response.body == b'Hello, world' ``` @@ -71,20 +72,22 @@ assert response.body == b'Hello, world' Top-level API... ```python -response = await httpcore.request(method, url, [headers], [body], [stream]) +http = httpcore.ConnectionPool([ssl], [timeout], [limits]) +response = await http.request(method, url, [headers], [body], [stream]) ``` -Explicit PoolManager... +ConnectionPool as a context-manager... ```python -async with httpcore.PoolManager([ssl], [timeout], [limits]) as pool: - response = await pool.request(method, url, [headers], [body], [stream]) +async with httpcore.ConnectionPool([ssl], [timeout], [limits]) as http: + response = await http.request(method, url, [headers], [body], [stream]) ``` Streaming... ```python -response = await httpcore.request(method, url, stream=True) +http = httpcore.ConnectionPool() +response = await http.request(method, url, stream=True) async for part in response.stream(): ... 
``` @@ -100,7 +103,7 @@ import httpcore class GatewayServer: def __init__(self, base_url): self.base_url = base_url - self.pool = httpcore.PoolManager() + self.http = httpcore.ConnectionPool() async def __call__(self, scope, receive, send): assert scope['type'] == 'http' @@ -122,7 +125,7 @@ class GatewayServer: if not message.get('more_body', False): break - response = await self.pool.request( + response = await self.http.request( method, url, headers=headers, body=body, stream=True ) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 69894f36..24a6bbfb 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,5 +1,6 @@ -from .api import PoolManager, Response, request from .config import PoolLimits, SSLConfig, TimeoutConfig +from .datastructures import URL, Request, Response from .exceptions import ResponseClosed, StreamConsumed +from .pool import ConnectionPool __version__ = "0.0.2" diff --git a/httpcore/api.py b/httpcore/api.py deleted file mode 100644 index c270d569..00000000 --- a/httpcore/api.py +++ /dev/null @@ -1,132 +0,0 @@ -import typing -from types import TracebackType - -from .config import ( - DEFAULT_POOL_LIMITS, - DEFAULT_SSL_CONFIG, - DEFAULT_TIMEOUT_CONFIG, - PoolLimits, - SSLConfig, - TimeoutConfig, -) -from .exceptions import ResponseClosed, StreamConsumed - - -async def request( - method: str, - url: str, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - stream: bool = False, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, -) -> "Response": - async with PoolManager(ssl=ssl, timeout=timeout) as pool: - return await pool.request( - method=method, url=url, headers=headers, body=body, stream=stream - ) - - -class PoolManager: - def __init__( - self, - *, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - limits: PoolLimits = DEFAULT_POOL_LIMITS, - ): - self.ssl = ssl 
- self.timeout = timeout - self.limits = limits - self.is_closed = False - - async def request( - self, - method: str, - url: str, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - stream: bool = False, - ) -> "Response": - if stream: - async def streaming_body(): - yield b"Hello, " - yield b"world!" - return Response(200, body=streaming_body) - return Response(200, body=b"Hello, world!") - - async def close(self) -> None: - self.is_closed = True - - async def __aenter__(self) -> "PoolManager": - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.close() - - -class Response: - def __init__( - self, - status_code: int, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - on_close: typing.Callable = None, - ): - self.status_code = status_code - self.headers = list(headers) - self.on_close = on_close - self.is_closed = False - self.is_streamed = False - if isinstance(body, bytes): - self.is_closed = True - self.body = body - else: - self.body_aiter = body - - async def read(self) -> bytes: - if not hasattr(self, "body"): - body = b"" - async for part in self.stream(): - body += part - self.body = body - return self.body - - async def stream(self) -> typing.AsyncIterator[bytes]: - if hasattr(self, "body"): - yield self.body - else: - if self.is_streamed: - raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() - self.is_streamed = True - async for part in self.body_aiter(): - yield part - await self.close() - - async def close(self) -> None: - if not self.is_closed: - self.is_closed = True - if self.on_close is not None: - await self.on_close() - - async def __aenter__(self) -> "Response": - return self - - async def __aexit__( - self, - exc_type: 
class Connection:
    """
    A single HTTP/1.1 client connection.

    The protocol logic is delegated to an `h11.Connection` state machine;
    the bytes it produces and consumes travel over an asyncio stream
    reader/writer pair created by `open()`.
    """

    def __init__(self, timeout: "TimeoutConfig"):
        # The streams only exist once `open()` has completed.
        self.reader = None
        self.writer = None
        self.state = h11.Connection(our_role=h11.CLIENT)
        self.timeout = timeout

    async def open(
        self,
        hostname: str,
        port: int,
        *,
        ssl: typing.Union[bool, ssl.SSLContext] = False
    ) -> None:
        """
        Establish the underlying stream connection.

        `ssl` is handed straight to `asyncio.open_connection`.

        Raises `ConnectTimeout` if the connection is not established within
        `timeout.connect_timeout`.
        """
        try:
            self.reader, self.writer = await asyncio.wait_for(  # type: ignore
                asyncio.open_connection(hostname, port, ssl=ssl),
                self.timeout.connect_timeout,
            )
        except asyncio.TimeoutError:
            raise ConnectTimeout()

    async def send(self, request: "Request", stream: bool = False) -> "Response":
        """
        Send `request` and return the response.

        With `stream=False` the whole body is read, the connection is closed
        and a fully loaded `Response` is returned. With `stream=True` the
        response is returned as soon as the headers have arrived; its body
        is pulled lazily from `body_iter()`, which closes the connection
        once the body has been exhausted.
        """
        method = request.method.encode()
        target = request.url.target
        host_header = (b"host", request.url.netloc.encode("ascii"))
        # A streaming body has no known length up front, so it is framed
        # with chunked transfer-encoding rather than a content-length.
        if request.is_streaming:
            framing_header = (b"transfer-encoding", b"chunked")
        else:
            framing_header = (
                b"content-length",
                str(len(request.body)).encode(),
            )

        headers = [host_header, framing_header] + request.headers

        # Start sending the request.
        await self._send_event(
            h11.Request(method=method, target=target, headers=headers)
        )

        # Send the request body.
        if request.is_streaming:
            async for data in request.stream():
                await self._send_event(h11.Data(data=data))
        elif request.body:
            await self._send_event(h11.Data(data=request.body))

        # Finalize sending the request.
        await self._send_event(h11.EndOfMessage())

        # Start getting the response. A server may emit any number of 1xx
        # informational responses before the final one, so skip them all.
        event = await self._receive_event()
        while isinstance(event, h11.InformationalResponse):
            event = await self._receive_event()
        assert isinstance(event, h11.Response)
        status_code = event.status_code
        headers = event.headers

        if stream:
            return Response(
                status_code=status_code, headers=headers, body=self.body_iter()
            )

        # Get the response body. Chunks are collected and joined once
        # instead of growing a bytes object on every iteration.
        chunks = []
        event = await self._receive_event()
        while isinstance(event, h11.Data):
            chunks.append(event.data)
            event = await self._receive_event()
        assert isinstance(event, h11.EndOfMessage)
        await self.close()

        return Response(
            status_code=status_code, headers=headers, body=b"".join(chunks)
        )

    async def body_iter(self) -> typing.AsyncIterator[bytes]:
        """
        Yield the response body chunk by chunk, then close the connection.
        """
        event = await self._receive_event()
        while isinstance(event, h11.Data):
            yield event.data
            event = await self._receive_event()
        assert isinstance(event, h11.EndOfMessage)
        await self.close()

    async def _send_event(self, event: "H11Event") -> None:
        """
        Serialise `event` through the h11 state machine and write it out.
        """
        assert self.writer is not None

        data = self.state.send(event)
        self.writer.write(data)
        # `write()` only buffers; awaiting `drain()` applies back-pressure so
        # a large or streamed request body cannot grow the buffer unbounded.
        await self.writer.drain()

    async def _receive_event(self) -> "H11Event":
        """
        Return the next h11 event, reading from the socket as required.

        Raises `ReadTimeout` if a single read takes longer than
        `timeout.read_timeout`.
        """
        assert self.reader is not None

        event = self.state.next_event()

        while event is h11.NEED_DATA:
            try:
                data = await asyncio.wait_for(
                    self.reader.read(2048), self.timeout.read_timeout
                )
            except asyncio.TimeoutError:
                raise ReadTimeout()
            self.state.receive_data(data)
            event = self.state.next_event()

        return event

    async def close(self) -> None:
        """
        Close the writer, if the connection was ever opened.
        """
        if self.writer is not None:
            self.writer.close()
            # Only wait for the close when the writer offers `wait_closed`.
            if hasattr(self.writer, "wait_closed"):
                await self.writer.wait_closed()
class Response:
    """
    An HTTP response whose body is either fully loaded or still streaming.

    When `body` is given as `bytes` the response is created already closed
    and the content is available as `.body`. Otherwise `body` is kept as an
    async iterator that is consumed through `read()`, `stream()` or `raw()`.
    """

    def __init__(
        self,
        status_code: int,
        *,
        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
        on_close: typing.Optional[typing.Callable] = None,
    ):
        self.status_code = status_code
        self.headers = list(headers)
        # Optional callback, awaited once when the response is closed.
        self.on_close = on_close
        self.is_closed = False
        self.is_streamed = False
        self.decoder = IdentityDecoder()
        if isinstance(body, bytes):
            # Nothing is left to read, so there is nothing to keep open.
            self.is_closed = True
            self.body = body
        else:
            self.body_aiter = body

    async def read(self) -> bytes:
        """
        Read and return the response content.

        The content is cached on `.body`, so repeated calls are cheap.
        """
        if not hasattr(self, "body"):
            # Join once rather than growing a bytes object per chunk.
            self.body = b"".join([part async for part in self.stream()])
        return self.body

    async def stream(self) -> typing.AsyncIterator[bytes]:
        """
        A byte-iterator over the decoded response content.
        This will allow us to handle gzip, deflate, and brotli encoded responses.

        Empty chunks produced by the decoder are dropped, so a consumer that
        treats an empty chunk as end-of-stream never sees a spurious one.
        """
        if hasattr(self, "body"):
            yield self.body
        else:
            async for chunk in self.raw():
                decoded = self.decoder.decode(chunk)
                if decoded:
                    yield decoded
            # A decoder may hold buffered output until the input has ended.
            tail = self.decoder.flush()
            if tail:
                yield tail

    async def raw(self) -> typing.AsyncIterator[bytes]:
        """
        A byte-iterator over the raw response content.

        Raises `StreamConsumed` if the body has already been iterated, and
        `ResponseClosed` if the response was closed before being read.
        Closes the response once the body has been exhausted.
        """
        if self.is_streamed:
            raise StreamConsumed()
        if self.is_closed:
            raise ResponseClosed()
        self.is_streamed = True
        async for part in self.body_aiter:
            yield part
        await self.close()

    async def close(self) -> None:
        """
        Mark the response as closed and await `on_close`, at most once.
        """
        if not self.is_closed:
            self.is_closed = True
            if self.on_close is not None:
                await self.on_close()
class ConnectionPool:
    """
    Entry point for issuing HTTP requests.

    NOTE(review): no pooling happens yet. Every request opens a fresh
    `Connection`, and `limits` is stored but not consulted.
    """

    def __init__(
        self,
        *,
        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
        limits: PoolLimits = DEFAULT_POOL_LIMITS,
    ):
        self.ssl_config = ssl
        self.timeout = timeout
        self.limits = limits
        self.is_closed = False

    async def request(
        self,
        method: str,
        url: str,
        *,
        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
        stream: bool = False,
    ) -> Response:
        """
        Send a request to `url` and return the response.

        `URL(url)` raises `ValueError` for a URL without an http/https
        scheme or without a hostname. With `stream=True` the response body
        is left unread for the caller to consume.
        """
        parsed_url = URL(url)
        request = Request(method, parsed_url, headers=headers, body=body)
        ssl_context = await self.get_ssl_context(parsed_url)
        connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
        try:
            response = await connection.send(request, stream=stream)
        except BaseException:
            # The connection was opened for this request only; if sending
            # fails (timeout, protocol error, cancellation) nobody else
            # holds a reference to it, so close it here rather than leak
            # the socket.
            await connection.close()
            raise
        return response

    async def acquire_connection(
        self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
    ) -> Connection:
        """
        Open and return a new connection to the host and port of `url`.
        """
        connection = Connection(timeout=self.timeout)
        await connection.open(url.hostname, url.port, ssl=ssl)
        return connection

    async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
        """
        Return the SSL context to use for `url`, or `False` for plain http.

        The context is built on first use and cached on the instance.
        """
        if not url.is_secure:
            return False

        if not hasattr(self, "ssl_context"):
            if not self.ssl_config.verify:
                self.ssl_context = self.get_ssl_context_no_verify()
            else:
                # Run the SSL loading in a threadpool, since it makes disk accesses.
                loop = asyncio.get_event_loop()
                self.ssl_context = await loop.run_in_executor(
                    None, self.get_ssl_context_verify
                )

        return self.ssl_context

    def get_ssl_context_no_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for unverified connections.
        """
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.options |= ssl.OP_NO_SSLv2
        context.options |= ssl.OP_NO_SSLv3
        context.options |= ssl.OP_NO_COMPRESSION
        context.set_default_verify_paths()
        return context

    def get_ssl_context_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for verified connections.

        `ssl_config.verify` is either a bool, in which case the default CA
        bundle is used, or a path to a CA file or directory. Raises
        `IOError` if a path is given that does not exist.
        """
        cert = self.ssl_config.cert
        verify = self.ssl_config.verify

        if isinstance(verify, bool):
            ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
        elif os.path.exists(verify):
            ca_bundle_path = verify
        else:
            raise IOError(
                "Could not find a suitable TLS CA certificate bundle, "
                "invalid path: {}".format(verify)
            )

        context = ssl.create_default_context()
        if os.path.isfile(ca_bundle_path):
            context.load_verify_locations(cafile=ca_bundle_path)
        elif os.path.isdir(ca_bundle_path):
            context.load_verify_locations(capath=ca_bundle_path)

        # A client certificate is either a single file path or a
        # (certfile, keyfile) pair.
        if cert is not None:
            if isinstance(cert, str):
                context.load_cert_chain(certfile=cert)
            else:
                context.load_cert_chain(certfile=cert[0], keyfile=cert[1])

        return context

    async def close(self) -> None:
        """
        Mark the pool as closed.
        """
        self.is_closed = True

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[TracebackType] = None,
    ) -> None:
        await self.close()
await asyncio.sleep(0.0001) + yield server + finally: + task.cancel() diff --git a/tests/test_api.py b/tests/test_api.py index 2d5df0f3..6b80587d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,93 +4,40 @@ import httpcore @pytest.mark.asyncio -async def test_request(): - response = await httpcore.request("GET", "http://example.com") +async def test_get(server): + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.body == b"Hello, world!" - assert response.is_closed @pytest.mark.asyncio -async def test_read_response(): - response = await httpcore.request("GET", "http://example.com") - +async def test_post(server): + async with httpcore.ConnectionPool() as http: + response = await http.request( + "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" + ) assert response.status_code == 200 - assert response.body == b"Hello, world!" - assert response.is_closed + +@pytest.mark.asyncio +async def test_stream_response(server): + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert response.status_code == 200 + assert not hasattr(response, "body") body = await response.read() - assert body == b"Hello, world!" - assert response.body == b"Hello, world!" - assert response.is_closed @pytest.mark.asyncio -async def test_stream_response(): - response = await httpcore.request("GET", "http://example.com") +async def test_stream_request(server): + async def hello_world(): + yield b"Hello, " + yield b"world!" + async with httpcore.ConnectionPool() as http: + response = await http.request( + "POST", "http://127.0.0.1:8000/", body=hello_world() + ) assert response.status_code == 200 - assert response.body == b"Hello, world!" - assert response.is_closed - - body = b'' - async for part in response.stream(): - body += part - - assert body == b"Hello, world!" 
- assert response.body == b"Hello, world!" - assert response.is_closed - - -@pytest.mark.asyncio -async def test_read_streaming_response(): - response = await httpcore.request("GET", "http://example.com", stream=True) - - assert response.status_code == 200 - assert not hasattr(response, 'body') - assert not response.is_closed - - body = await response.read() - - assert body == b"Hello, world!" - assert response.body == b"Hello, world!" - assert response.is_closed - - -@pytest.mark.asyncio -async def test_stream_streaming_response(): - response = await httpcore.request("GET", "http://example.com", stream=True) - - assert response.status_code == 200 - assert not hasattr(response, 'body') - assert not response.is_closed - - body = b'' - async for part in response.stream(): - body += part - - assert body == b"Hello, world!" - assert not hasattr(response, 'body') - assert response.is_closed - - -@pytest.mark.asyncio -async def test_cannot_read_after_stream_consumed(): - response = await httpcore.request("GET", "http://example.com", stream=True) - - body = b'' - async for part in response.stream(): - body += part - - with pytest.raises(httpcore.StreamConsumed): - await response.read() - -@pytest.mark.asyncio -async def test_cannot_read_after_response_closed(): - response = await httpcore.request("GET", "http://example.com", stream=True) - - await response.close() - - with pytest.raises(httpcore.ResponseClosed): - await response.read() diff --git a/tests/test_responses.py b/tests/test_responses.py new file mode 100644 index 00000000..ae754b40 --- /dev/null +++ b/tests/test_responses.py @@ -0,0 +1,114 @@ +import pytest + +import httpcore + + +class MockHTTP(httpcore.ConnectionPool): + async def request( + self, method, url, *, headers=(), body=b"", stream=False + ) -> httpcore.Response: + if stream: + + async def streaming_body(): + yield b"Hello, " + yield b"world!" 
+ + return httpcore.Response(200, body=streaming_body()) + return httpcore.Response(200, body=b"Hello, world!") + + +http = MockHTTP() + + +@pytest.mark.asyncio +async def test_request(): + response = await http.request("GET", "http://example.com") + assert response.status_code == 200 + assert response.body == b"Hello, world!" + assert response.is_closed + + +@pytest.mark.asyncio +async def test_read_response(): + response = await http.request("GET", "http://example.com") + + assert response.status_code == 200 + assert response.body == b"Hello, world!" + assert response.is_closed + + body = await response.read() + + assert body == b"Hello, world!" + assert response.body == b"Hello, world!" + assert response.is_closed + + +@pytest.mark.asyncio +async def test_stream_response(): + response = await http.request("GET", "http://example.com") + + assert response.status_code == 200 + assert response.body == b"Hello, world!" + assert response.is_closed + + body = b"" + async for part in response.stream(): + body += part + + assert body == b"Hello, world!" + assert response.body == b"Hello, world!" + assert response.is_closed + + +@pytest.mark.asyncio +async def test_read_streaming_response(): + response = await http.request("GET", "http://example.com", stream=True) + + assert response.status_code == 200 + assert not hasattr(response, "body") + assert not response.is_closed + + body = await response.read() + + assert body == b"Hello, world!" + assert response.body == b"Hello, world!" + assert response.is_closed + + +@pytest.mark.asyncio +async def test_stream_streaming_response(): + response = await http.request("GET", "http://example.com", stream=True) + + assert response.status_code == 200 + assert not hasattr(response, "body") + assert not response.is_closed + + body = b"" + async for part in response.stream(): + body += part + + assert body == b"Hello, world!" 
+ assert not hasattr(response, "body") + assert response.is_closed + + +@pytest.mark.asyncio +async def test_cannot_read_after_stream_consumed(): + response = await http.request("GET", "http://example.com", stream=True) + + body = b"" + async for part in response.stream(): + body += part + + with pytest.raises(httpcore.StreamConsumed): + await response.read() + + +@pytest.mark.asyncio +async def test_cannot_read_after_response_closed(): + response = await http.request("GET", "http://example.com", stream=True) + + await response.close() + + with pytest.raises(httpcore.ResponseClosed): + await response.read()