Merge pull request #59 from encode/use-stdlib-status-codes

Use stdlib status codes
This commit is contained in:
Tom Christie 2019-05-13 14:45:01 +01:00 committed by GitHub
commit afc836939c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 44 additions and 104 deletions

View File

@ -1,6 +1,5 @@
from .client import AsyncClient, Client
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .constants import Protocol, codes
from .dispatch.connection import HTTPConnection
from .dispatch.connection_pool import ConnectionPool
from .dispatch.http2 import HTTP2Connection
@ -20,7 +19,8 @@ from .exceptions import (
Timeout,
TooManyRedirects,
)
from .interfaces import BaseReader, BaseWriter, Dispatcher
from .interfaces import BaseReader, BaseWriter, Dispatcher, Protocol
from .models import URL, Headers, Origin, QueryParams, Request, Response
from .status_codes import codes
__version__ = "0.2.1"

View File

@ -11,7 +11,6 @@ from .config import (
SSLConfig,
TimeoutConfig,
)
from .constants import codes
from .dispatch.connection_pool import ConnectionPool
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .interfaces import Dispatcher
@ -26,6 +25,7 @@ from .models import (
SyncResponse,
URLTypes,
)
from .status_codes import codes
class AsyncClient:
@ -300,17 +300,17 @@ class AsyncClient:
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes.see_other and method != "HEAD":
if response.status_code == codes.SEE_OTHER and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
# Turn 302s into GETs.
if response.status_code == codes.found and method != "HEAD":
if response.status_code == codes.FOUND and method != "HEAD":
method = "GET"
# If a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in 'requests' issue 1704.
if response.status_code == codes.moved_permanently and method == "POST":
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
method = "GET"
return method

View File

@ -13,9 +13,8 @@ import ssl
import typing
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .constants import Protocol
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter
from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter, Protocol
OptionalTimeout = typing.Optional[TimeoutConfig]

View File

@ -1,67 +0,0 @@
import enum
class Protocol(str, enum.Enum):
HTTP_11 = "HTTP/1.1"
HTTP_2 = "HTTP/2"
codes = enum.IntEnum(
"StatusCode",
[
("continue", 100),
("switching_protocols", 101),
("processing", 102),
("ok", 200),
("created", 201),
("accepted", 202),
("non_authoritative_information", 203),
("no_content", 204),
("reset_content", 205),
("partial_content", 206),
("multi_status", 207),
("already_reported", 208),
("im_used", 226),
("multiple_choices", 300),
("moved_permanently", 301),
("found", 302),
("see_other", 303),
("not_modified", 304),
("use_proxy", 305),
("temporary_redirect", 307),
("permanent_redirect", 308),
("bad_request", 400),
("unauthorized", 401),
("payment_required", 402),
("forbidden", 403),
("not_found", 404),
("method_not_allowed", 405),
("not_acceptable", 406),
("proxy_authentication_required", 407),
("request_timeout", 408),
("conflict", 409),
("gone", 410),
("length_required", 411),
("precondition_failed", 412),
("request_entity_too_large", 413),
("request_uri_too_long", 414),
("unsupported_media_type", 415),
("requested_range_not_satisfiable", 416),
("expectation_failed", 417),
("unprocessable_entity", 422),
("locked", 423),
("failed_dependency", 424),
("precondition_required", 428),
("too_many_requests", 429),
("request_header_fields_too_large", 431),
("unavailable_for_legal_reasons", 451),
("internal_server_error", 500),
("not_implemented", 501),
("bad_gateway", 502),
("service_unavailable", 503),
("gateway_timeout", 504),
("http_version_not_supported", 505),
("insufficient_storage", 507),
("network_authentication_required", 511),
],
)

View File

@ -11,9 +11,8 @@ from ..config import (
SSLConfig,
TimeoutConfig,
)
from ..constants import Protocol
from ..exceptions import ConnectTimeout
from ..interfaces import Dispatcher
from ..interfaces import Dispatcher, Protocol
from ..models import Origin, Request, Response
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection

View File

@ -1,3 +1,4 @@
import enum
import typing
from types import TracebackType
@ -16,6 +17,11 @@ from .models import (
OptionalTimeout = typing.Optional[TimeoutConfig]
class Protocol(str, enum.Enum):
HTTP_11 = "HTTP/1.1"
HTTP_2 = "HTTP/2"
class Dispatcher:
"""
The base class for all adapter or dispatcher classes.

View File

@ -8,7 +8,6 @@ import idna
import rfc3986
from .config import SSLConfig, TimeoutConfig
from .constants import codes
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
@ -23,6 +22,7 @@ from .exceptions import (
ResponseNotRead,
StreamConsumed,
)
from .status_codes import codes
from .utils import (
get_reason_phrase,
is_known_encoding,
@ -715,11 +715,11 @@ class Response:
return (
self.status_code
in (
codes.moved_permanently,
codes.found,
codes.see_other,
codes.temporary_redirect,
codes.permanent_redirect,
codes.MOVED_PERMANENTLY,
codes.FOUND,
codes.SEE_OTHER,
codes.TEMPORARY_REDIRECT,
codes.PERMANENT_REDIRECT,
)
and "location" in self.headers
)

3
httpcore/status_codes.py Normal file
View File

@ -0,0 +1,3 @@
from http import HTTPStatus
codes = HTTPStatus

View File

@ -27,33 +27,33 @@ class MockDispatch(Dispatcher):
timeout: TimeoutConfig = None,
) -> Response:
if request.url.path == "/redirect_301":
status_code = codes.moved_permanently
status_code = codes.MOVED_PERMANENTLY
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_302":
status_code = codes.found
status_code = codes.FOUND
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_303":
status_code = codes.see_other
status_code = codes.SEE_OTHER
headers = {"location": "https://example.org/"}
return Response(status_code, headers=headers, request=request)
elif request.url.path == "/relative_redirect":
headers = {"location": "/"}
return Response(codes.see_other, headers=headers, request=request)
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/no_scheme_redirect":
headers = {"location": "//example.org/"}
return Response(codes.see_other, headers=headers, request=request)
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
count = int(params.get("count", "0")[0])
redirect_count = count - 1
code = codes.see_other if count else codes.ok
code = codes.SEE_OTHER if count else codes.OK
location = "/multiple_redirects"
if redirect_count:
location += "?count=" + str(redirect_count)
@ -62,35 +62,35 @@ class MockDispatch(Dispatcher):
if request.url.path == "/redirect_loop":
headers = {"location": "/redirect_loop"}
return Response(codes.see_other, headers=headers, request=request)
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain":
headers = {"location": "https://example.org/cross_domain_target"}
return Response(codes.see_other, headers=headers, request=request)
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain_target":
headers = dict(request.headers.items())
content = json.dumps({"headers": headers}).encode()
return Response(codes.ok, content=content, request=request)
return Response(codes.OK, content=content, request=request)
elif request.url.path == "/redirect_body":
await request.read()
headers = {"location": "/redirect_body_target"}
return Response(codes.permanent_redirect, headers=headers, request=request)
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
elif request.url.path == "/redirect_body_target":
content = await request.read()
body = json.dumps({"body": content.decode()}).encode()
return Response(codes.ok, content=body, request=request)
return Response(codes.OK, content=body, request=request)
return Response(codes.ok, content=b"Hello, world!", request=request)
return Response(codes.OK, content=b"Hello, world!", request=request)
@pytest.mark.asyncio
async def test_redirect_301():
client = AsyncClient(dispatch=MockDispatch())
response = await client.request("POST", "https://example.org/redirect_301")
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -99,7 +99,7 @@ async def test_redirect_301():
async def test_redirect_302():
client = AsyncClient(dispatch=MockDispatch())
response = await client.request("POST", "https://example.org/redirect_302")
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -108,7 +108,7 @@ async def test_redirect_302():
async def test_redirect_303():
client = AsyncClient(dispatch=MockDispatch())
response = await client.request("GET", "https://example.org/redirect_303")
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -119,12 +119,12 @@ async def test_disallow_redirects():
response = await client.request(
"POST", "https://example.org/redirect_303", allow_redirects=False
)
assert response.status_code == codes.see_other
assert response.status_code == codes.SEE_OTHER
assert response.url == URL("https://example.org/redirect_303")
assert len(response.history) == 0
response = await response.next()
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -133,7 +133,7 @@ async def test_disallow_redirects():
async def test_relative_redirect():
client = AsyncClient(dispatch=MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect")
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -142,7 +142,7 @@ async def test_relative_redirect():
async def test_no_scheme_redirect():
client = AsyncClient(dispatch=MockDispatch())
response = await client.request("GET", "https://example.org/no_scheme_redirect")
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@ -152,7 +152,7 @@ async def test_fragment_redirect():
client = AsyncClient(dispatch=MockDispatch())
url = "https://example.org/relative_redirect#fragment"
response = await client.request("GET", url)
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/#fragment")
assert len(response.history) == 1
@ -162,7 +162,7 @@ async def test_multiple_redirects():
client = AsyncClient(dispatch=MockDispatch())
url = "https://example.org/multiple_redirects?count=20"
response = await client.request("GET", url)
assert response.status_code == codes.ok
assert response.status_code == codes.OK
assert response.url == URL("https://example.org/multiple_redirects")
assert len(response.history) == 20