Swap auth/redirects ordering (#1267)

* Internal refactoring to swap auth/redirects ordering

* Test for auth with cross domain redirect
This commit is contained in:
Tom Christie 2020-09-10 09:12:05 +01:00 committed by GitHub
parent 016e4ee210
commit 4d950e5780
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 31 deletions

View File

@ -725,8 +725,12 @@ class Client(BaseClient):
auth = self._build_request_auth(request, auth)
response = self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
response = self._send_handling_auth(
request,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
history=[],
)
if not stream:
@ -740,23 +744,17 @@ class Client(BaseClient):
def _send_handling_redirects(
self,
request: Request,
auth: Auth,
timeout: Timeout,
allow_redirects: bool = True,
history: typing.List[Response] = None,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
if history is None:
history = []
while True:
if len(history) > self.max_redirects:
raise TooManyRedirects(
"Exceeded maximum allowed redirects.", request=request
)
response = self._send_handling_auth(
request, auth=auth, timeout=timeout, history=history
)
response = self._send_single_request(request, timeout)
response.history = list(history)
if not response.is_redirect:
@ -771,7 +769,6 @@ class Client(BaseClient):
response.call_next = functools.partial(
self._send_handling_redirects,
request=request,
auth=auth,
timeout=timeout,
allow_redirects=False,
history=history,
@ -781,16 +778,21 @@ class Client(BaseClient):
def _send_handling_auth(
self,
request: Request,
history: typing.List[Response],
auth: Auth,
timeout: Timeout,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
auth_flow = auth.sync_auth_flow(request)
request = next(auth_flow)
while True:
response = self._send_single_request(request, timeout)
response = self._send_handling_redirects(
request,
timeout=timeout,
allow_redirects=allow_redirects,
history=history,
)
try:
next_request = auth_flow.send(response)
except StopIteration:
@ -1346,8 +1348,12 @@ class AsyncClient(BaseClient):
auth = self._build_request_auth(request, auth)
response = await self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
response = await self._send_handling_auth(
request,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
history=[],
)
if not stream:
@ -1361,23 +1367,17 @@ class AsyncClient(BaseClient):
async def _send_handling_redirects(
self,
request: Request,
auth: Auth,
timeout: Timeout,
allow_redirects: bool = True,
history: typing.List[Response] = None,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
if history is None:
history = []
while True:
if len(history) > self.max_redirects:
raise TooManyRedirects(
"Exceeded maximum allowed redirects.", request=request
)
response = await self._send_handling_auth(
request, auth=auth, timeout=timeout, history=history
)
response = await self._send_single_request(request, timeout)
response.history = list(history)
if not response.is_redirect:
@ -1392,7 +1392,6 @@ class AsyncClient(BaseClient):
response.call_next = functools.partial(
self._send_handling_redirects,
request=request,
auth=auth,
timeout=timeout,
allow_redirects=False,
history=history,
@ -1402,16 +1401,21 @@ class AsyncClient(BaseClient):
async def _send_handling_auth(
self,
request: Request,
history: typing.List[Response],
auth: Auth,
timeout: Timeout,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.__anext__()
while True:
response = await self._send_single_request(request, timeout)
response = await self._send_handling_redirects(
request,
timeout=timeout,
allow_redirects=allow_redirects,
history=history,
)
try:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:

View File

@ -323,7 +323,7 @@ def test_redirect_loop():
client.get("https://example.org/redirect_loop")
def test_cross_domain_redirect():
def test_cross_domain_redirect_with_auth_header():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.com/cross_domain"
headers = {"Authorization": "abc"}
@ -332,6 +332,14 @@ def test_cross_domain_redirect():
assert "authorization" not in response.json()["headers"]
def test_cross_domain_redirect_with_auth():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.com/cross_domain"
response = client.get(url, auth=("user", "pass"))
assert response.url == "https://example.org/cross_domain_target"
assert "authorization" not in response.json()["headers"]
def test_same_domain_redirect():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.org/cross_domain"