diff --git a/httpcore/__init__.py b/httpcore/__init__.py index d4f6465f..32ff66d6 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,5 +1,6 @@ from .config import PoolLimits, SSLConfig, TimeoutConfig -from .datastructures import URL, Request, Response +from .connections import Connection +from .datastructures import URL, Origin, Request, Response from .exceptions import ( BadResponse, ConnectTimeout, diff --git a/httpcore/config.py b/httpcore/config.py index e2a18b4e..8cc784b3 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -1,3 +1,6 @@ +import asyncio +import os +import ssl import typing import certifi @@ -32,6 +35,58 @@ class SSLConfig: class_name = self.__class__.__name__ return f"{class_name}(cert={self.cert}, verify={self.verify})" + async def load_ssl_context(self) -> ssl.SSLContext: + if not hasattr(self, "ssl_context"): + if not self.verify: + self.ssl_context = self.load_ssl_context_no_verify() + else: + # Run the SSL loading in a threadpool, since it makes disk accesses. + loop = asyncio.get_event_loop() + self.ssl_context = await loop.run_in_executor( + None, self.load_ssl_context_verify + ) + + return self.ssl_context + + def load_ssl_context_no_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + context.set_default_verify_paths() + return context + + def load_ssl_context_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for verified connections. + """ + if isinstance(self.verify, bool): + ca_bundle_path = DEFAULT_CA_BUNDLE_PATH + elif os.path.exists(self.verify): + ca_bundle_path = self.verify + else: + raise IOError( + "Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(self.verify) + ) + + context = ssl.create_default_context() + if os.path.isfile(ca_bundle_path): + context.load_verify_locations(cafile=ca_bundle_path) + elif os.path.isdir(ca_bundle_path): + context.load_verify_locations(capath=ca_bundle_path) + + if self.cert is not None: + if isinstance(self.cert, str): + context.load_cert_chain(certfile=self.cert) + else: + context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + + return context + class TimeoutConfig: """ diff --git a/httpcore/connections.py b/httpcore/connections.py index 16f91408..6ddea6da 100644 --- a/httpcore/connections.py +++ b/httpcore/connections.py @@ -1,11 +1,10 @@ import asyncio -import ssl import typing import h11 -from .config import TimeoutConfig -from .datastructures import Request, Response +from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig +from .datastructures import Client, Origin, Request, Response from .exceptions import ConnectTimeout, ReadTimeout H11Event = typing.Union[ @@ -18,35 +17,48 @@ H11Event = typing.Union[ ] -class Connection: - def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None): - self.reader = None - self.writer = None - self.state = h11.Connection(our_role=h11.CLIENT) +class Connection(Client): + def __init__( + self, + origin: typing.Union[str, Origin], + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + on_release: typing.Callable = None, + ): + self.origin = Origin(origin) if isinstance(origin, str) else origin + self.ssl = ssl self.timeout = timeout self.on_release = on_release + self._reader = None + self._writer = None + self._h11_state = h11.Connection(our_role=h11.CLIENT) @property def is_closed(self) -> bool: - return self.state.our_state in (h11.CLOSED, h11.ERROR) + return self._h11_state.our_state in (h11.CLOSED, h11.ERROR) - async def open( - self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None - ) -> None: - try: - self.reader, self.writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl), - self.timeout.connect_timeout, - ) - except asyncio.TimeoutError: - raise ConnectTimeout() + async def send( + self, + request: Request, + *, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + assert request.url.origin == self.origin - async def send(self, request: Request) -> Response: + if ssl is None: + ssl = self.ssl + if timeout is None: + timeout = self.timeout + + # Make the connection + if self._reader is None: + await self._connect(ssl, timeout) + + #  Start sending the request. method = request.method.encode() target = request.url.target headers = request.headers - - #  Start sending the request. event = h11.Request(method=method, target=target, headers=headers) await self._send_event(event) @@ -64,69 +76,88 @@ class Connection: await self._send_event(event) # Start getting the response. - event = await self._receive_event() + event = await self._receive_event(timeout) if isinstance(event, h11.InformationalResponse): - event = await self._receive_event() + event = await self._receive_event(timeout) assert isinstance(event, h11.Response) - reason = event.reason.decode('latin1') + reason = event.reason.decode("latin1") status_code = event.status_code headers = event.headers - body = self._body_iter() + body = self._body_iter(timeout) return Response( - status_code=status_code, reason=reason, headers=headers, body=body, on_close=self._release + status_code=status_code, + reason=reason, + headers=headers, + body=body, + on_close=self._release, ) - async def _body_iter(self) -> typing.AsyncIterator[bytes]: - event = await self._receive_event() + async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None: + ssl_context = await ssl.load_ssl_context() if self.origin.is_secure else None + + try: + self._reader, self._writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection( + self.origin.hostname, self.origin.port, ssl=ssl_context + ), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]: + event = await self._receive_event(timeout) while isinstance(event, h11.Data): yield event.data - event = await self._receive_event() + event = await self._receive_event(timeout) assert isinstance(event, h11.EndOfMessage) async def _send_event(self, event: H11Event) -> None: - assert self.writer is not None + assert self._writer is not None - data = self.state.send(event) - self.writer.write(data) + data = self._h11_state.send(event) + self._writer.write(data) - async def _receive_event(self) -> H11Event: - assert self.reader is not None + async def _receive_event(self, timeout: TimeoutConfig) -> H11Event: + assert self._reader is not None - event = self.state.next_event() + event = self._h11_state.next_event() while event is h11.NEED_DATA: try: data = await asyncio.wait_for( - self.reader.read(2048), self.timeout.read_timeout + self._reader.read(2048), timeout.read_timeout ) except asyncio.TimeoutError: raise ReadTimeout() - self.state.receive_data(data) - event = self.state.next_event() + self._h11_state.receive_data(data) + event = self._h11_state.next_event() return event async def _release(self) -> None: - assert self.writer is not None + assert self._writer is not None - if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE: - self.state.start_next_cycle() + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._h11_state.start_next_cycle() else: - self.close() + await self.close() if self.on_release is not None: await self.on_release(self) - def close(self) -> None: - assert self.writer is not None - + async def close(self) -> None: event = h11.ConnectionClosed() try: # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. - self.state.send(event) + self._h11_state.send(event) except h11.ProtocolError: # If we're in some other state then it's a premature close, # and we'll end up in h11.ERROR. pass - self.writer.close() + if self._writer is not None: + self._writer.close() diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py index 49f85bdc..7389d451 100644 --- a/httpcore/datastructures.py +++ b/httpcore/datastructures.py @@ -1,7 +1,9 @@ import http import typing +from types import TracebackType from urllib.parse import urlsplit +from .config import SSLConfig, TimeoutConfig from .decoders import ( ACCEPT_ENCODING, SUPPORTED_DECODERS, @@ -61,6 +63,34 @@ class URL: def is_secure(self) -> bool: return self.components.scheme == "https" + @property + def origin(self) -> "Origin": + return Origin(self) + + +class Origin: + def __init__(self, url: typing.Union[str, URL]) -> None: + if isinstance(url, str): + url = URL(url) + self.scheme = url.scheme + self.hostname = url.hostname + self.port = url.port + + @property + def is_secure(self) -> bool: + return self.scheme == "https" + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.scheme == other.scheme + and self.hostname == other.hostname + and self.port == other.port + ) + + def __hash__(self) -> int: + return hash((self.scheme, self.hostname, self.port)) + class Request: def __init__( @@ -207,3 +237,48 @@ class Response: self.is_closed = True if self.on_close is not None: await self.on_close() + + +class Client: + async def request( + self, + method: str, + url: typing.Union[str, URL], + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + stream: bool = False, + ) -> Response: + request = Request(method, url, headers=headers, body=body) + response = await self.send(request, ssl=ssl, timeout=timeout) + if not stream: + try: + await response.read() + finally: + await response.close() + return response + + async def send( + self, + request: Request, + *, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + raise NotImplementedError() # pragma: nocover + + async def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def __aenter__(self) -> "Client": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() diff --git a/httpcore/pool.py b/httpcore/pool.py index 74c19490..a1365718 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -1,9 +1,5 @@ import asyncio -import functools -import os -import ssl import typing -from types import TracebackType from .config import ( DEFAULT_CA_BUNDLE_PATH, @@ -15,10 +11,97 @@ from .config import ( TimeoutConfig, ) from .connections import Connection -from .datastructures import URL, Request, Response +from .datastructures import Client, Origin, Request, Response from .exceptions import PoolTimeout -ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig] + +class ConnectionPool(Client): + def __init__( + self, + *, + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + limits: PoolLimits = DEFAULT_POOL_LIMITS, + ): + self.ssl = ssl + self.timeout = timeout + self.limits = limits + self.is_closed = False + self.num_active_connections = 0 + self.num_keepalive_connections = 0 + self._keepalive_connections = ( + {} + ) # type: typing.Dict[Origin, typing.List[Connection]] + self._max_connections = ConnectionSemaphore( + max_connections=self.limits.hard_limit + ) + + async def send( + self, + request: Request, + *, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + connection = await self.acquire_connection(request.url.origin, timeout=timeout) + response = await connection.send(request, ssl=ssl, timeout=timeout) + return response + + @property + def num_connections(self) -> int: + return self.num_active_connections + self.num_keepalive_connections + + async def acquire_connection( + self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None + ) -> Connection: + try: + connection = self._keepalive_connections[origin].pop() + if not self._keepalive_connections[origin]: + del self._keepalive_connections[origin] + self.num_keepalive_connections -= 1 + self.num_active_connections += 1 + + except (KeyError, IndexError): + if timeout is None: + pool_timeout = self.timeout.pool_timeout + else: + pool_timeout = timeout.pool_timeout + + try: + await asyncio.wait_for(self._max_connections.acquire(), pool_timeout) + except asyncio.TimeoutError: + raise PoolTimeout() + connection = Connection( + origin, + ssl=self.ssl, + timeout=self.timeout, + on_release=self.release_connection, + ) + self.num_active_connections += 1 + + return connection + + async def release_connection(self, connection: Connection) -> None: + if connection.is_closed: + self._max_connections.release() + self.num_active_connections -= 1 + elif ( + self.limits.soft_limit is not None + and self.num_connections > self.limits.soft_limit + ): + self._max_connections.release() + self.num_active_connections -= 1 + await connection.close() + else: + self.num_active_connections -= 1 + self.num_keepalive_connections += 1 + try: + self._keepalive_connections[connection.origin].append(connection) + except KeyError: + self._keepalive_connections[connection.origin] = [connection] + + async def close(self) -> None: + self.is_closed = True class ConnectionSemaphore: @@ -33,174 +116,3 @@ class ConnectionSemaphore: def release(self) -> None: if hasattr(self, "semaphore"): self.semaphore.release() - - -class ConnectionPool: - def __init__( - self, - *, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - limits: PoolLimits = DEFAULT_POOL_LIMITS, - ): - self.ssl_config = ssl - self.timeout = timeout - self.limits = limits - self.is_closed = False - self.num_active_connections = 0 - self.num_keepalive_connections = 0 - self._keepalive_connections = ( - {} - ) # type: typing.Dict[ConnectionKey, typing.List[Connection]] - self._max_connections = ConnectionSemaphore( - max_connections=self.limits.hard_limit - ) - - async def request( - self, - method: str, - url: str, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - stream: bool = False, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - if ssl is None: - ssl = self.ssl_config - if timeout is None: - timeout = self.timeout - - parsed_url = URL(url) - request = Request(method, parsed_url, headers=headers, body=body) - connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout) - response = await connection.send(request) - if not stream: - try: - await response.read() - finally: - await response.close() - return response - - @property - def num_connections(self) -> int: - return self.num_active_connections + self.num_keepalive_connections - - async def acquire_connection( - self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig - ) -> Connection: - key = (url.scheme, url.hostname, url.port, ssl, timeout) - try: - connection = self._keepalive_connections[key].pop() - if not self._keepalive_connections[key]: - del self._keepalive_connections[key] - self.num_keepalive_connections -= 1 - self.num_active_connections += 1 - - except (KeyError, IndexError): - ssl_context = await self.get_ssl_context(url, ssl) - try: - await asyncio.wait_for( - self._max_connections.acquire(), timeout.pool_timeout - ) - except asyncio.TimeoutError: - raise PoolTimeout() - release = functools.partial(self.release_connection, key=key) - connection = Connection(timeout=timeout, on_release=release) - self.num_active_connections += 1 - await connection.open(url.hostname, url.port, ssl=ssl_context) - - return connection - - async def release_connection( - self, connection: Connection, key: ConnectionKey - ) -> None: - if connection.is_closed: - self._max_connections.release() - self.num_active_connections -= 1 - elif ( - self.limits.soft_limit is not None - and self.num_connections > self.limits.soft_limit - ): - self._max_connections.release() - self.num_active_connections -= 1 - connection.close() - else: - self.num_active_connections -= 1 - self.num_keepalive_connections += 1 - try: - self._keepalive_connections[key].append(connection) - except KeyError: - self._keepalive_connections[key] = [connection] - - async def get_ssl_context( - self, url: URL, config: SSLConfig - ) -> typing.Optional[ssl.SSLContext]: - if not url.is_secure: - return None - - if not hasattr(self, "ssl_context"): - if not config.verify: - self.ssl_context = self.get_ssl_context_no_verify() - else: - # Run the SSL loading in a threadpool, since it makes disk accesses. - loop = asyncio.get_event_loop() - self.ssl_context = await loop.run_in_executor( - None, self.get_ssl_context_verify - ) - - return self.ssl_context - - def get_ssl_context_no_verify(self) -> ssl.SSLContext: - """ - Return an SSL context for unverified connections. - """ - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_COMPRESSION - context.set_default_verify_paths() - return context - - def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext: - """ - Return an SSL context for verified connections. - """ - if isinstance(config.verify, bool): - ca_bundle_path = DEFAULT_CA_BUNDLE_PATH - elif os.path.exists(config.verify): - ca_bundle_path = config.verify - else: - raise IOError( - "Could not find a suitable TLS CA certificate bundle, " - "invalid path: {}".format(config.verify) - ) - - context = ssl.create_default_context() - if os.path.isfile(ca_bundle_path): - context.load_verify_locations(cafile=ca_bundle_path) - elif os.path.isdir(ca_bundle_path): - context.load_verify_locations(capath=ca_bundle_path) - - if config.cert is not None: - if isinstance(config.cert, str): - context.load_cert_chain(certfile=config.cert) - else: - context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1]) - - return context - - async def close(self) -> None: - self.is_closed = True - - async def __aenter__(self) -> "ConnectionPool": - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.close() diff --git a/httpcore/sync.py b/httpcore/sync.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_connections.py b/tests/test_connections.py new file mode 100644 index 00000000..5cfca611 --- /dev/null +++ b/tests/test_connections.py @@ -0,0 +1,20 @@ +import pytest + +import httpcore + + +@pytest.mark.asyncio +async def test_get(server): + http = httpcore.Connection(origin="http://127.0.0.1:8000/") + response = await http.request("GET", "http://127.0.0.1:8000/") + assert response.status_code == 200 + assert response.body == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_post(server): + http = httpcore.Connection(origin="http://127.0.0.1:8000/") + response = await http.request( + "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" + ) + assert response.status_code == 200 diff --git a/tests/test_responses.py b/tests/test_responses.py index 3efef890..bb930bdb 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -3,26 +3,13 @@ import pytest import httpcore -class MockHTTP(httpcore.ConnectionPool): - async def request( - self, method, url, *, headers=(), body=b"", stream=False - ) -> httpcore.Response: - if stream: - - async def streaming_body(): - yield b"Hello, " - yield b"world!" - - return httpcore.Response(200, body=streaming_body()) - return httpcore.Response(200, body=b"Hello, world!") +async def streaming_body(): + yield b"Hello, " + yield b"world!" -http = MockHTTP() - - -@pytest.mark.asyncio -async def test_request(): - response = await http.request("GET", "http://example.com") +def test_response(): + response = httpcore.Response(200, body=b"Hello, world!") assert response.status_code == 200 assert response.reason == "OK" assert response.body == b"Hello, world!" @@ -31,7 +18,7 @@ async def test_request(): @pytest.mark.asyncio async def test_read_response(): - response = await http.request("GET", "http://example.com") + response = httpcore.Response(200, body=b"Hello, world!") assert response.status_code == 200 assert response.body == b"Hello, world!" @@ -45,25 +32,8 @@ async def test_read_response(): @pytest.mark.asyncio -async def test_stream_response(): - response = await http.request("GET", "http://example.com") - - assert response.status_code == 200 - assert response.body == b"Hello, world!" - assert response.is_closed - - body = b"" - async for part in response.stream(): - body += part - - assert body == b"Hello, world!" - assert response.body == b"Hello, world!" - assert response.is_closed - - -@pytest.mark.asyncio -async def test_read_streaming_response(): - response = await http.request("GET", "http://example.com", stream=True) +async def test_streaming_response(): + response = httpcore.Response(200, body=streaming_body()) assert response.status_code == 200 assert not hasattr(response, "body") @@ -76,26 +46,9 @@ async def test_read_streaming_response(): assert response.is_closed -@pytest.mark.asyncio -async def test_stream_streaming_response(): - response = await http.request("GET", "http://example.com", stream=True) - - assert response.status_code == 200 - assert not hasattr(response, "body") - assert not response.is_closed - - body = b"" - async for part in response.stream(): - body += part - - assert body == b"Hello, world!" - assert not hasattr(response, "body") - assert response.is_closed - - @pytest.mark.asyncio async def test_cannot_read_after_stream_consumed(): - response = await http.request("GET", "http://example.com", stream=True) + response = httpcore.Response(200, body=streaming_body()) body = b"" async for part in response.stream(): @@ -107,7 +60,7 @@ async def test_cannot_read_after_stream_consumed(): @pytest.mark.asyncio async def test_cannot_read_after_response_closed(): - response = await http.request("GET", "http://example.com", stream=True) + response = httpcore.Response(200, body=streaming_body()) await response.close()