Add Client.auth setter (#1185)

This commit is contained in:
Florimond Manca 2020-08-17 14:51:52 +02:00 committed by GitHub
parent 34ba0e14b0
commit cb620e67c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 20 deletions

View File

@ -36,13 +36,13 @@
::: httpx.Client
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send close
:members: headers cookies params auth request get head options post put patch delete build_request send close
## `AsyncClient`
::: httpx.AsyncClient
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send aclose
:members: headers cookies params auth request get head options post put patch delete build_request send aclose
## `Response`

View File

@ -71,7 +71,7 @@ class BaseClient:
):
self._base_url = self._enforce_trailing_slash(URL(base_url))
self.auth = auth
self._auth = self._build_auth(auth)
self._params = QueryParams(params)
self._headers = Headers(headers)
self._cookies = Cookies(cookies)
@ -117,6 +117,21 @@ class BaseClient:
def timeout(self, timeout: TimeoutTypes) -> None:
self._timeout = Timeout(timeout)
@property
def auth(self) -> typing.Optional[Auth]:
"""
Authentication class used when none is passed at the request-level.
See also [Authentication][0].
[0]: /quickstart/#authentication
"""
return self._auth
@auth.setter
def auth(self, auth: AuthTypes) -> None:
self._auth = self._build_auth(auth)
@property
def base_url(self) -> URL:
"""
@ -284,19 +299,25 @@ class BaseClient:
return merged_queryparams
return params
def _build_auth(
def _build_auth(self, auth: AuthTypes) -> typing.Optional[Auth]:
if auth is None:
return None
elif isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
else:
raise TypeError('Invalid "auth" argument.')
def _build_request_auth(
self, request: Request, auth: typing.Union[AuthTypes, UnsetType] = UNSET
) -> Auth:
auth = self.auth if isinstance(auth, UnsetType) else auth
auth = self._auth if isinstance(auth, UnsetType) else self._build_auth(auth)
if auth is not None:
if isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
raise TypeError('Invalid "auth" argument.')
return auth
username, password = request.url.username, request.url.password
if username or password:
@ -667,7 +688,7 @@ class Client(BaseClient):
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)
response = self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
@ -1269,7 +1290,7 @@ class AsyncClient(BaseClient):
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)
response = await self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,

View File

@ -9,6 +9,7 @@ from httpx import (
URL,
AsyncClient,
Auth,
BasicAuth,
Client,
DigestAuth,
ProtocolError,
@ -310,14 +311,34 @@ async def test_auth_hidden_header() -> None:
@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
async def test_auth_property() -> None:
client = AsyncClient(transport=AsyncMockTransport())
assert client.auth is None
client.auth = ("tomchristie", "password123") # type: ignore
assert isinstance(client.auth, BasicAuth)
url = "https://example.org/"
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
with pytest.raises(TypeError):
await client.get(url)
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)
client = AsyncClient(transport=AsyncMockTransport())
with pytest.raises(TypeError):
await client.get(auth="not a tuple, not a callable") # type: ignore
with pytest.raises(TypeError):
client.auth = "not a tuple, not a callable" # type: ignore
@pytest.mark.asyncio