Adapters
This commit is contained in:
parent
ddf1cec65a
commit
fab6fcd397
@ -1,3 +1,5 @@
|
||||
from .adapters import Adapter
|
||||
from .client import Client
|
||||
from .config import PoolLimits, SSLConfig, TimeoutConfig
|
||||
from .connection import HTTPConnection
|
||||
from .connection_pool import ConnectionPool
|
||||
|
||||
40
httpcore/adapters.py
Normal file
40
httpcore/adapters.py
Normal file
@ -0,0 +1,40 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .models import URL, Request, Response
|
||||
|
||||
|
||||
class Adapter:
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
**options: typing.Any,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
self.prepare_request(request)
|
||||
response = await self.send(request, **options)
|
||||
return response
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "Adapter":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
18
httpcore/auth.py
Normal file
18
httpcore/auth.py
Normal file
@ -0,0 +1,18 @@
|
||||
import typing
|
||||
|
||||
from .adapters import Adapter
|
||||
from .models import Request, Response
|
||||
|
||||
|
||||
class AuthAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter):
|
||||
self.dispatch = dispatch
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
124
httpcore/client.py
Normal file
124
httpcore/client.py
Normal file
@ -0,0 +1,124 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .auth import AuthAdapter
|
||||
from .config import (
|
||||
DEFAULT_MAX_REDIRECTS,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
DEFAULT_SSL_CONFIG,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from .connection_pool import ConnectionPool
|
||||
from .cookies import CookieAdapter
|
||||
from .environment import EnvironmentAdapter
|
||||
from .models import URL, Request, Response
|
||||
from .redirects import RedirectAdapter
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
):
|
||||
connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits)
|
||||
cookie_adapter = CookieAdapter(dispatch=connection_pool)
|
||||
auth_adapter = AuthAdapter(dispatch=cookie_adapter)
|
||||
redirect_adapter = RedirectAdapter(
|
||||
dispatch=auth_adapter, max_redirects=max_redirects
|
||||
)
|
||||
self.adapter = EnvironmentAdapter(dispatch=redirect_adapter)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
stream: bool = False,
|
||||
allow_redirects: bool = True,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
self.prepare_request(request)
|
||||
response = await self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
return response
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
return await self.request(
|
||||
"GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout
|
||||
)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
return await self.request(
|
||||
"POST",
|
||||
url,
|
||||
body=body,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.adapter.prepare_request(request)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
allow_redirects: bool = True,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
options = {"stream": stream} # type: typing.Dict[str, typing.Any]
|
||||
if ssl is not None:
|
||||
options["ssl"] = ssl
|
||||
if timeout is not None:
|
||||
options["timeout"] = timeout
|
||||
return await self.adapter.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.adapter.close()
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
@ -112,24 +112,20 @@ class TimeoutConfig:
|
||||
connect_timeout: float = None,
|
||||
read_timeout: float = None,
|
||||
write_timeout: float = None,
|
||||
pool_timeout: float = None,
|
||||
):
|
||||
if timeout is not None:
|
||||
# Specified as a single timeout value
|
||||
assert connect_timeout is None
|
||||
assert read_timeout is None
|
||||
assert write_timeout is None
|
||||
assert pool_timeout is None
|
||||
connect_timeout = timeout
|
||||
read_timeout = timeout
|
||||
write_timeout = timeout
|
||||
pool_timeout = timeout
|
||||
|
||||
self.timeout = timeout
|
||||
self.connect_timeout = connect_timeout
|
||||
self.read_timeout = read_timeout
|
||||
self.write_timeout = write_timeout
|
||||
self.pool_timeout = pool_timeout
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
@ -137,14 +133,13 @@ class TimeoutConfig:
|
||||
and self.connect_timeout == other.connect_timeout
|
||||
and self.read_timeout == other.read_timeout
|
||||
and self.write_timeout == other.write_timeout
|
||||
and self.pool_timeout == other.pool_timeout
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
if self.timeout is not None:
|
||||
return f"{class_name}(timeout={self.timeout})"
|
||||
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})"
|
||||
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
|
||||
|
||||
|
||||
class PoolLimits:
|
||||
@ -155,27 +150,29 @@ class PoolLimits:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
soft_limit: typing.Optional[int] = None,
|
||||
hard_limit: typing.Optional[int] = None,
|
||||
soft_limit: int = None,
|
||||
hard_limit: int = None,
|
||||
pool_timeout: float = None,
|
||||
):
|
||||
self.soft_limit = soft_limit
|
||||
self.hard_limit = hard_limit
|
||||
self.pool_timeout = pool_timeout
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.soft_limit == other.soft_limit
|
||||
and self.hard_limit == other.hard_limit
|
||||
and self.pool_timeout == other.pool_timeout
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return (
|
||||
f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
|
||||
)
|
||||
return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})"
|
||||
|
||||
|
||||
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
|
||||
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
|
||||
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
|
||||
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
|
||||
DEFAULT_CA_BUNDLE_PATH = certifi.where()
|
||||
DEFAULT_MAX_REDIRECTS = 30
|
||||
|
||||
@ -4,18 +4,19 @@ import typing
|
||||
import h2.connection
|
||||
import h11
|
||||
|
||||
from .adapters import Adapter
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .models import Origin, Request, Response
|
||||
from .streams import Protocol, connect
|
||||
|
||||
# Callback signature: async def callback(conn: HTTPConnection) -> None
|
||||
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class HTTPConnection(Client):
|
||||
class HTTPConnection(Adapter):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
@ -30,33 +31,26 @@ class HTTPConnection(Client):
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
await self.connect(ssl, timeout)
|
||||
await self.connect(**options)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
|
||||
response = await self.h2_connection.send(request, **options)
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout)
|
||||
response = await self.h11_connection.send(request, **options)
|
||||
|
||||
return response
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> None:
|
||||
if ssl is None:
|
||||
ssl = self.ssl
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
async def connect(self, **options: typing.Any) -> None:
|
||||
ssl = options.get("ssl", self.ssl)
|
||||
timeout = options.get("timeout", self.timeout)
|
||||
assert isinstance(ssl, SSLConfig)
|
||||
assert isinstance(timeout, TimeoutConfig)
|
||||
|
||||
hostname = self.origin.hostname
|
||||
port = self.origin.port
|
||||
@ -69,20 +63,10 @@ class HTTPConnection(Client):
|
||||
|
||||
reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
|
||||
if protocol == Protocol.HTTP_2:
|
||||
self.h2_connection = HTTP2Connection(
|
||||
reader,
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=on_release,
|
||||
)
|
||||
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
|
||||
else:
|
||||
self.h11_connection = HTTP11Connection(
|
||||
reader,
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=on_release,
|
||||
reader, writer, on_release=on_release
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
from .adapters import Adapter
|
||||
from .config import (
|
||||
DEFAULT_CA_BUNDLE_PATH,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
@ -12,7 +13,7 @@ from .config import (
|
||||
)
|
||||
from .connection import HTTPConnection
|
||||
from .exceptions import PoolTimeout
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .models import Origin, Request, Response
|
||||
from .streams import PoolSemaphore
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
@ -81,7 +82,7 @@ class ConnectionStore(collections.abc.Sequence):
|
||||
return len(self.all)
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
class ConnectionPool(Adapter):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -94,7 +95,7 @@ class ConnectionPool(Client):
|
||||
self.limits = limits
|
||||
self.is_closed = False
|
||||
|
||||
self.max_connections = PoolSemaphore(limits, timeout)
|
||||
self.max_connections = PoolSemaphore(limits)
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
self.active_connections = ConnectionStore()
|
||||
|
||||
@ -102,31 +103,26 @@ class ConnectionPool(Client):
|
||||
def num_connections(self) -> int:
|
||||
return len(self.keepalive_connections) + len(self.active_connections)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
connection = await self.acquire_connection(request.url.origin, timeout=timeout)
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
connection = await self.acquire_connection(request.url.origin)
|
||||
try:
|
||||
response = await connection.send(request, ssl=ssl, timeout=timeout)
|
||||
response = await connection.send(request, **options)
|
||||
except BaseException as exc:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
raise exc
|
||||
return response
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> HTTPConnection:
|
||||
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
|
||||
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
|
||||
if connection is None:
|
||||
connection = self.keepalive_connections.pop_by_origin(origin)
|
||||
|
||||
if connection is None:
|
||||
await self.max_connections.acquire(timeout)
|
||||
await self.max_connections.acquire()
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
|
||||
18
httpcore/cookies.py
Normal file
18
httpcore/cookies.py
Normal file
@ -0,0 +1,18 @@
|
||||
import typing
|
||||
|
||||
from .adapters import Adapter
|
||||
from .models import Request, Response
|
||||
|
||||
|
||||
class CookieAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter):
|
||||
self.dispatch = dispatch
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
27
httpcore/environment.py
Normal file
27
httpcore/environment.py
Normal file
@ -0,0 +1,27 @@
|
||||
import typing
|
||||
|
||||
from .adapters import Adapter
|
||||
from .models import Request, Response
|
||||
|
||||
|
||||
class EnvironmentAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter, trust_env: bool = True):
|
||||
self.dispatch = dispatch
|
||||
self.trust_env = trust_env
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
if self.trust_env:
|
||||
self.merge_environment_options(options)
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
|
||||
def merge_environment_options(self, options: dict) -> None:
|
||||
"""
|
||||
Add environment options.
|
||||
"""
|
||||
# TODO
|
||||
@ -28,6 +28,12 @@ class PoolTimeout(Timeout):
|
||||
"""
|
||||
|
||||
|
||||
class TooManyRedirects(Exception):
|
||||
"""
|
||||
Too many redirects.
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
"""
|
||||
Malformed HTTP.
|
||||
|
||||
@ -2,9 +2,10 @@ import typing
|
||||
|
||||
import h11
|
||||
|
||||
from .adapters import Adapter
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, ReadTimeout
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .models import Request, Response
|
||||
from .streams import BaseReader, BaseWriter
|
||||
|
||||
H11Event = typing.Union[
|
||||
@ -25,35 +26,28 @@ OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
|
||||
|
||||
|
||||
class HTTP11Connection(Client):
|
||||
class HTTP11Connection(Adapter):
|
||||
READ_NUM_BYTES = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: BaseReader,
|
||||
writer: BaseWriter,
|
||||
origin: Origin,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Optional[OnReleaseCallback] = None,
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.origin = origin
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
stream = options.get("stream", False)
|
||||
assert timeout is None or isinstance(timeout, TimeoutConfig)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> Response:
|
||||
# Start sending the request.
|
||||
method = request.method.encode()
|
||||
target = request.url.full_path
|
||||
@ -81,7 +75,7 @@ class HTTP11Connection(Client):
|
||||
headers = event.headers
|
||||
body = self._body_iter(timeout)
|
||||
|
||||
return Response(
|
||||
response = Response(
|
||||
status_code=status_code,
|
||||
reason=reason,
|
||||
protocol="HTTP/1.1",
|
||||
@ -90,6 +84,26 @@ class HTTP11Connection(Client):
|
||||
on_close=self.response_closed,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
self.h11_state.send(event)
|
||||
except h11.ProtocolError:
|
||||
# If we're in some other state then it's a premature close,
|
||||
# and we'll end up in h11.ERROR.
|
||||
pass
|
||||
|
||||
await self.writer.close()
|
||||
|
||||
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
|
||||
event = await self._receive_event(timeout)
|
||||
while isinstance(event, h11.Data):
|
||||
@ -123,14 +137,6 @@ class HTTP11Connection(Client):
|
||||
if self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
self.h11_state.send(event)
|
||||
except h11.ProtocolError:
|
||||
# If we're in some other state then it's a premature close,
|
||||
# and we'll end up in h11.ERROR.
|
||||
pass
|
||||
|
||||
await self.writer.close()
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
|
||||
@ -4,52 +4,39 @@ import typing
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from .adapters import Adapter
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, ReadTimeout
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .models import Request, Response
|
||||
from .streams import BaseReader, BaseWriter
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
|
||||
class HTTP2Connection(Client):
|
||||
class HTTP2Connection(Adapter):
|
||||
READ_NUM_BYTES = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: BaseReader,
|
||||
writer: BaseWriter,
|
||||
origin: Origin,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.origin = origin
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
self.initialized = False
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> Response:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
if not self.initialized:
|
||||
self.initiate_connection()
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
stream = options.get("stream", False)
|
||||
assert timeout is None or isinstance(timeout, TimeoutConfig)
|
||||
|
||||
# Start sending the request.
|
||||
if not self.initialized:
|
||||
self.initiate_connection()
|
||||
stream_id = await self.send_headers(request, timeout)
|
||||
self.events[stream_id] = []
|
||||
|
||||
@ -77,7 +64,7 @@ class HTTP2Connection(Client):
|
||||
body = self.body_iter(stream_id, timeout)
|
||||
on_close = functools.partial(self.response_closed, stream_id=stream_id)
|
||||
|
||||
return Response(
|
||||
response = Response(
|
||||
status_code=status_code,
|
||||
protocol="HTTP/2",
|
||||
headers=headers,
|
||||
@ -85,6 +72,17 @@ class HTTP2Connection(Client):
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
|
||||
def initiate_connection(self) -> None:
|
||||
self.h2_state.initiate_connection()
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
@ -147,5 +145,6 @@ class HTTP2Connection(Client):
|
||||
if not self.events and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import http
|
||||
import typing
|
||||
from types import TracebackType
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
@ -237,47 +236,6 @@ class Response:
|
||||
if self.on_close is not None:
|
||||
await self.on_close()
|
||||
|
||||
|
||||
class Client:
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
) -> Response:
|
||||
request = Request(method, url, headers=headers, body=body)
|
||||
response = await self.send(request, ssl=ssl, timeout=timeout)
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
return response
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return self.status_code in (301, 302, 303, 307, 308)
|
||||
|
||||
35
httpcore/redirects.py
Normal file
35
httpcore/redirects.py
Normal file
@ -0,0 +1,35 @@
|
||||
import typing
|
||||
|
||||
from .adapters import Adapter
|
||||
from .exceptions import TooManyRedirects
|
||||
from .models import Request, Response
|
||||
|
||||
|
||||
class RedirectAdapter(Adapter):
|
||||
def __init__(self, dispatch: Adapter, max_redirects: int):
|
||||
self.dispatch = dispatch
|
||||
self.max_redirects = max_redirects
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
self.dispatch.prepare_request(request)
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
allow_redirects = options.pop("allow_redirects", True)
|
||||
history = []
|
||||
|
||||
while True:
|
||||
response = await self.dispatch.send(request, **options)
|
||||
if not allow_redirects or not response.is_redirect:
|
||||
break
|
||||
history.append(response)
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
request = self.build_redirect_request(request, response)
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
raise NotImplementedError()
|
||||
@ -41,10 +41,7 @@ class BaseWriter:
|
||||
|
||||
|
||||
class BasePoolSemaphore:
|
||||
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def acquire(self, timeout: OptionalTimeout = None) -> None:
|
||||
async def acquire(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def release(self) -> None:
|
||||
@ -100,9 +97,8 @@ class Writer(BaseWriter):
|
||||
|
||||
|
||||
class PoolSemaphore(BasePoolSemaphore):
|
||||
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
|
||||
def __init__(self, limits: PoolLimits):
|
||||
self.limits = limits
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
@ -114,15 +110,13 @@ class PoolSemaphore(BasePoolSemaphore):
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self, timeout: OptionalTimeout = None) -> None:
|
||||
async def acquire(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
timeout = self.limits.pool_timeout
|
||||
try:
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
|
||||
|
||||
@ -2,9 +2,10 @@ import asyncio
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .adapters import Adapter
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .connection_pool import ConnectionPool
|
||||
from .models import URL, Client, Response
|
||||
from .models import URL, Response
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
@ -44,8 +45,8 @@ class SyncResponse:
|
||||
|
||||
|
||||
class SyncClient:
|
||||
def __init__(self, client: Client):
|
||||
self._client = client
|
||||
def __init__(self, adapter: Adapter):
|
||||
self._client = adapter
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def request(
|
||||
@ -55,20 +56,10 @@ class SyncClient:
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
**options: typing.Any
|
||||
) -> SyncResponse:
|
||||
response = self._loop.run_until_complete(
|
||||
self._client.request(
|
||||
method,
|
||||
url,
|
||||
headers=headers,
|
||||
body=body,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
self._client.request(method, url, headers=headers, body=body, **options)
|
||||
)
|
||||
return SyncResponse(response, self._loop)
|
||||
|
||||
|
||||
@ -5,16 +5,16 @@ import httpcore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request("GET", "http://127.0.0.1:8000/")
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request(
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@ -22,8 +22,8 @@ async def test_post(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert response.status_code == 200
|
||||
assert not hasattr(response, "body")
|
||||
body = await response.read()
|
||||
@ -36,8 +36,8 @@ async def test_stream_request(server):
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request(
|
||||
async with httpcore.Client() as client:
|
||||
response = await client.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=hello_world()
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
@ -13,13 +13,15 @@ def test_timeout_repr():
|
||||
timeout = httpcore.TimeoutConfig(read_timeout=5.0)
|
||||
assert (
|
||||
repr(timeout)
|
||||
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)"
|
||||
== "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)"
|
||||
)
|
||||
|
||||
|
||||
def test_limits_repr():
|
||||
limits = httpcore.PoolLimits(hard_limit=100)
|
||||
assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
|
||||
assert (
|
||||
repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)"
|
||||
)
|
||||
|
||||
|
||||
def test_ssl_eq():
|
||||
|
||||
@ -79,11 +79,8 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_get_request():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response = await client.request("GET", "http://example.org")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response = await conn.request("GET", "http://example.org")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
|
||||
|
||||
@ -91,11 +88,8 @@ async def test_http2_get_request():
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_post_request():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response = await client.request("POST", "http://example.org", body=b"<data>")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response = await conn.request("POST", "http://example.org", body=b"<data>")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {
|
||||
"method": "POST",
|
||||
@ -107,13 +101,10 @@ async def test_http2_post_request():
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_multiple_requests():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response_1 = await client.request("GET", "http://example.org/1")
|
||||
response_2 = await client.request("GET", "http://example.org/2")
|
||||
response_3 = await client.request("GET", "http://example.org/3")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
|
||||
response_1 = await conn.request("GET", "http://example.org/1")
|
||||
response_2 = await conn.request("GET", "http://example.org/2")
|
||||
response_3 = await conn.request("GET", "http://example.org/3")
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""}
|
||||
|
||||
@ -24,10 +24,9 @@ async def test_connect_timeout(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_timeout(server):
|
||||
timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
|
||||
limits = httpcore.PoolLimits(hard_limit=1)
|
||||
limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001)
|
||||
|
||||
async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
|
||||
async with httpcore.ConnectionPool(limits=limits) as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
|
||||
with pytest.raises(httpcore.PoolTimeout):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user