Swap auth/redirects ordering (#1267)
* Internal refactoring to swap auth/redirects ordering * Test for auth with cross domain redirect
This commit is contained in:
parent
016e4ee210
commit
4d950e5780
@ -725,8 +725,12 @@ class Client(BaseClient):
|
||||
|
||||
auth = self._build_request_auth(request, auth)
|
||||
|
||||
response = self._send_handling_redirects(
|
||||
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
|
||||
response = self._send_handling_auth(
|
||||
request,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
history=[],
|
||||
)
|
||||
|
||||
if not stream:
|
||||
@ -740,23 +744,17 @@ class Client(BaseClient):
|
||||
def _send_handling_redirects(
|
||||
self,
|
||||
request: Request,
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[Response] = None,
|
||||
allow_redirects: bool,
|
||||
history: typing.List[Response],
|
||||
) -> Response:
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
while True:
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects(
|
||||
"Exceeded maximum allowed redirects.", request=request
|
||||
)
|
||||
|
||||
response = self._send_handling_auth(
|
||||
request, auth=auth, timeout=timeout, history=history
|
||||
)
|
||||
response = self._send_single_request(request, timeout)
|
||||
response.history = list(history)
|
||||
|
||||
if not response.is_redirect:
|
||||
@ -771,7 +769,6 @@ class Client(BaseClient):
|
||||
response.call_next = functools.partial(
|
||||
self._send_handling_redirects,
|
||||
request=request,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=False,
|
||||
history=history,
|
||||
@ -781,16 +778,21 @@ class Client(BaseClient):
|
||||
def _send_handling_auth(
|
||||
self,
|
||||
request: Request,
|
||||
history: typing.List[Response],
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
allow_redirects: bool,
|
||||
history: typing.List[Response],
|
||||
) -> Response:
|
||||
auth_flow = auth.sync_auth_flow(request)
|
||||
request = next(auth_flow)
|
||||
|
||||
while True:
|
||||
response = self._send_single_request(request, timeout)
|
||||
|
||||
response = self._send_handling_redirects(
|
||||
request,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
history=history,
|
||||
)
|
||||
try:
|
||||
next_request = auth_flow.send(response)
|
||||
except StopIteration:
|
||||
@ -1346,8 +1348,12 @@ class AsyncClient(BaseClient):
|
||||
|
||||
auth = self._build_request_auth(request, auth)
|
||||
|
||||
response = await self._send_handling_redirects(
|
||||
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
|
||||
response = await self._send_handling_auth(
|
||||
request,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
history=[],
|
||||
)
|
||||
|
||||
if not stream:
|
||||
@ -1361,23 +1367,17 @@ class AsyncClient(BaseClient):
|
||||
async def _send_handling_redirects(
|
||||
self,
|
||||
request: Request,
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[Response] = None,
|
||||
allow_redirects: bool,
|
||||
history: typing.List[Response],
|
||||
) -> Response:
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
while True:
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects(
|
||||
"Exceeded maximum allowed redirects.", request=request
|
||||
)
|
||||
|
||||
response = await self._send_handling_auth(
|
||||
request, auth=auth, timeout=timeout, history=history
|
||||
)
|
||||
response = await self._send_single_request(request, timeout)
|
||||
response.history = list(history)
|
||||
|
||||
if not response.is_redirect:
|
||||
@ -1392,7 +1392,6 @@ class AsyncClient(BaseClient):
|
||||
response.call_next = functools.partial(
|
||||
self._send_handling_redirects,
|
||||
request=request,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=False,
|
||||
history=history,
|
||||
@ -1402,16 +1401,21 @@ class AsyncClient(BaseClient):
|
||||
async def _send_handling_auth(
|
||||
self,
|
||||
request: Request,
|
||||
history: typing.List[Response],
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
allow_redirects: bool,
|
||||
history: typing.List[Response],
|
||||
) -> Response:
|
||||
auth_flow = auth.async_auth_flow(request)
|
||||
request = await auth_flow.__anext__()
|
||||
|
||||
while True:
|
||||
response = await self._send_single_request(request, timeout)
|
||||
|
||||
response = await self._send_handling_redirects(
|
||||
request,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
history=history,
|
||||
)
|
||||
try:
|
||||
next_request = await auth_flow.asend(response)
|
||||
except StopAsyncIteration:
|
||||
|
||||
@ -323,7 +323,7 @@ def test_redirect_loop():
|
||||
client.get("https://example.org/redirect_loop")
|
||||
|
||||
|
||||
def test_cross_domain_redirect():
|
||||
def test_cross_domain_redirect_with_auth_header():
|
||||
client = httpx.Client(transport=SyncMockTransport())
|
||||
url = "https://example.com/cross_domain"
|
||||
headers = {"Authorization": "abc"}
|
||||
@ -332,6 +332,14 @@ def test_cross_domain_redirect():
|
||||
assert "authorization" not in response.json()["headers"]
|
||||
|
||||
|
||||
def test_cross_domain_redirect_with_auth():
|
||||
client = httpx.Client(transport=SyncMockTransport())
|
||||
url = "https://example.com/cross_domain"
|
||||
response = client.get(url, auth=("user", "pass"))
|
||||
assert response.url == "https://example.org/cross_domain_target"
|
||||
assert "authorization" not in response.json()["headers"]
|
||||
|
||||
|
||||
def test_same_domain_redirect():
|
||||
client = httpx.Client(transport=SyncMockTransport())
|
||||
url = "https://example.org/cross_domain"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user