Requests from transport API (#1293)

* Refactoring to support instantiating requests from transport API

* Minor refactoring
This commit is contained in:
Tom Christie 2020-09-17 11:59:42 +01:00 committed by GitHub
parent 09f94edd93
commit e1f7791e97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 17 deletions

View File

@ -41,6 +41,7 @@ from ._types import (
HeaderTypes,
PrimitiveData,
QueryParamTypes,
RawURL,
RequestContent,
RequestData,
RequestFiles,
@ -60,8 +61,18 @@ from ._utils import (
class URL:
def __init__(self, url: URLTypes = "", params: QueryParamTypes = None) -> None:
if isinstance(url, str):
def __init__(
self, url: typing.Union["URL", str, RawURL] = "", params: QueryParamTypes = None
) -> None:
if isinstance(url, (str, tuple)):
if isinstance(url, tuple):
raw_scheme, raw_host, port, raw_path = url
scheme = raw_scheme.decode("ascii")
host = raw_host.decode("ascii")
port_str = "" if port is None else f":{port}"
path = raw_path.decode("ascii")
url = f"{scheme}://{host}{port_str}{path}"
try:
self._uri_reference = rfc3986.iri_reference(url).encode()
except rfc3986.exceptions.InvalidAuthority as exc:
@ -141,7 +152,7 @@ class URL:
return self._uri_reference.fragment or ""
@property
def raw(self) -> typing.Tuple[bytes, bytes, typing.Optional[int], bytes]:
def raw(self) -> RawURL:
return (
self.scheme.encode("ascii"),
self.host.encode("ascii"),
@ -585,8 +596,8 @@ class Headers(typing.MutableMapping[str, str]):
class Request:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
method: typing.Union[str, bytes],
url: typing.Union["URL", str, RawURL],
*,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
@ -597,7 +608,10 @@ class Request:
json: typing.Any = None,
stream: ContentStream = None,
):
self.method = method.upper()
if isinstance(method, bytes):
self.method = method.decode("ascii").upper()
else:
self.method = method.upper()
self.url = URL(url, params=params)
self.headers = Headers(headers)
if cookies:

View File

@ -27,6 +27,8 @@ if TYPE_CHECKING: # pragma: no cover
PrimitiveData = Optional[Union[str, int, float, bool]]
RawURL = Tuple[bytes, bytes, Optional[int], bytes]
URLTypes = Union["URL", str]
QueryParamTypes = Union[

View File

@ -36,12 +36,6 @@ class MockTransport(httpcore.SyncHTTPTransport):
stream: httpcore.SyncByteStream = None,
timeout: Mapping[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
raw_scheme, raw_host, port, raw_path = url
scheme = raw_scheme.decode("ascii")
host = raw_host.decode("ascii")
port_str = "" if port is None else f":{port}"
path = raw_path.decode("ascii")
request_headers = httpx.Headers(headers)
content = (
(item for item in stream)
@ -54,17 +48,15 @@ class MockTransport(httpcore.SyncHTTPTransport):
)
request = httpx.Request(
method=method.decode("ascii"),
url=f"{scheme}://{host}{port_str}{path}",
method=method,
url=url,
headers=request_headers,
content=content,
)
request.read()
response = self.handler(request)
return (
response.http_version.encode("ascii")
if response.http_version
else b"HTTP/1.1",
(response.http_version or "HTTP/1.1").encode("ascii"),
response.status_code,
response.reason_phrase.encode("ascii"),
response.headers.raw,