Add Client.auth setter (#1185)
This commit is contained in:
parent
34ba0e14b0
commit
cb620e67c7
@ -36,13 +36,13 @@
|
||||
|
||||
::: httpx.Client
|
||||
:docstring:
|
||||
:members: headers cookies params request get head options post put patch delete build_request send close
|
||||
:members: headers cookies params auth request get head options post put patch delete build_request send close
|
||||
|
||||
## `AsyncClient`
|
||||
|
||||
::: httpx.AsyncClient
|
||||
:docstring:
|
||||
:members: headers cookies params request get head options post put patch delete build_request send aclose
|
||||
:members: headers cookies params auth request get head options post put patch delete build_request send aclose
|
||||
|
||||
|
||||
## `Response`
|
||||
|
||||
@ -71,7 +71,7 @@ class BaseClient:
|
||||
):
|
||||
self._base_url = self._enforce_trailing_slash(URL(base_url))
|
||||
|
||||
self.auth = auth
|
||||
self._auth = self._build_auth(auth)
|
||||
self._params = QueryParams(params)
|
||||
self._headers = Headers(headers)
|
||||
self._cookies = Cookies(cookies)
|
||||
@ -117,6 +117,21 @@ class BaseClient:
|
||||
def timeout(self, timeout: TimeoutTypes) -> None:
|
||||
self._timeout = Timeout(timeout)
|
||||
|
||||
@property
|
||||
def auth(self) -> typing.Optional[Auth]:
|
||||
"""
|
||||
Authentication class used when none is passed at the request-level.
|
||||
|
||||
See also [Authentication][0].
|
||||
|
||||
[0]: /quickstart/#authentication
|
||||
"""
|
||||
return self._auth
|
||||
|
||||
@auth.setter
|
||||
def auth(self, auth: AuthTypes) -> None:
|
||||
self._auth = self._build_auth(auth)
|
||||
|
||||
@property
|
||||
def base_url(self) -> URL:
|
||||
"""
|
||||
@ -284,19 +299,25 @@ class BaseClient:
|
||||
return merged_queryparams
|
||||
return params
|
||||
|
||||
def _build_auth(
|
||||
def _build_auth(self, auth: AuthTypes) -> typing.Optional[Auth]:
|
||||
if auth is None:
|
||||
return None
|
||||
elif isinstance(auth, tuple):
|
||||
return BasicAuth(username=auth[0], password=auth[1])
|
||||
elif isinstance(auth, Auth):
|
||||
return auth
|
||||
elif callable(auth):
|
||||
return FunctionAuth(func=auth)
|
||||
else:
|
||||
raise TypeError('Invalid "auth" argument.')
|
||||
|
||||
def _build_request_auth(
|
||||
self, request: Request, auth: typing.Union[AuthTypes, UnsetType] = UNSET
|
||||
) -> Auth:
|
||||
auth = self.auth if isinstance(auth, UnsetType) else auth
|
||||
auth = self._auth if isinstance(auth, UnsetType) else self._build_auth(auth)
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
return BasicAuth(username=auth[0], password=auth[1])
|
||||
elif isinstance(auth, Auth):
|
||||
return auth
|
||||
elif callable(auth):
|
||||
return FunctionAuth(func=auth)
|
||||
raise TypeError('Invalid "auth" argument.')
|
||||
return auth
|
||||
|
||||
username, password = request.url.username, request.url.password
|
||||
if username or password:
|
||||
@ -667,7 +688,7 @@ class Client(BaseClient):
|
||||
"""
|
||||
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
|
||||
|
||||
auth = self._build_auth(request, auth)
|
||||
auth = self._build_request_auth(request, auth)
|
||||
|
||||
response = self._send_handling_redirects(
|
||||
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
|
||||
@ -1269,7 +1290,7 @@ class AsyncClient(BaseClient):
|
||||
"""
|
||||
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
|
||||
|
||||
auth = self._build_auth(request, auth)
|
||||
auth = self._build_request_auth(request, auth)
|
||||
|
||||
response = await self._send_handling_redirects(
|
||||
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
|
||||
|
||||
@ -9,6 +9,7 @@ from httpx import (
|
||||
URL,
|
||||
AsyncClient,
|
||||
Auth,
|
||||
BasicAuth,
|
||||
Client,
|
||||
DigestAuth,
|
||||
ProtocolError,
|
||||
@ -310,14 +311,34 @@ async def test_auth_hidden_header() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_invalid_type() -> None:
|
||||
async def test_auth_property() -> None:
|
||||
client = AsyncClient(transport=AsyncMockTransport())
|
||||
assert client.auth is None
|
||||
|
||||
client.auth = ("tomchristie", "password123") # type: ignore
|
||||
assert isinstance(client.auth, BasicAuth)
|
||||
|
||||
url = "https://example.org/"
|
||||
client = AsyncClient(
|
||||
transport=AsyncMockTransport(),
|
||||
auth="not a tuple, not a callable", # type: ignore
|
||||
)
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_invalid_type() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
await client.get(url)
|
||||
client = AsyncClient(
|
||||
transport=AsyncMockTransport(),
|
||||
auth="not a tuple, not a callable", # type: ignore
|
||||
)
|
||||
|
||||
client = AsyncClient(transport=AsyncMockTransport())
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
await client.get(auth="not a tuple, not a callable") # type: ignore
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
client.auth = "not a tuple, not a callable" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
Reference in New Issue
Block a user