Support cookie persistence
This commit is contained in:
parent
24b6923669
commit
f63148aa73
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from http.cookiejar import CookieJar
|
||||
from types import TracebackType
|
||||
|
||||
from .auth import HTTPBasicAuth
|
||||
@ -19,6 +18,8 @@ from .interfaces import ConcurrencyBackend, Dispatcher
|
||||
from .models import (
|
||||
URL,
|
||||
AuthTypes,
|
||||
Cookies,
|
||||
CookieTypes,
|
||||
Headers,
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
@ -35,6 +36,7 @@ class AsyncClient:
|
||||
def __init__(
|
||||
self,
|
||||
auth: AuthTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
@ -48,6 +50,7 @@ class AsyncClient:
|
||||
)
|
||||
|
||||
self.auth = auth
|
||||
self.cookies = Cookies(cookies)
|
||||
self.max_redirects = max_redirects
|
||||
self.dispatch = dispatch
|
||||
|
||||
@ -57,7 +60,7 @@ class AsyncClient:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -83,7 +86,7 @@ class AsyncClient:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -109,7 +112,7 @@ class AsyncClient:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = False, # Note: Differs to usual default.
|
||||
@ -136,7 +139,7 @@ class AsyncClient:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -164,7 +167,7 @@ class AsyncClient:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -192,7 +195,7 @@ class AsyncClient:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -220,7 +223,7 @@ class AsyncClient:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -249,7 +252,7 @@ class AsyncClient:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -262,7 +265,7 @@ class AsyncClient:
|
||||
data=data,
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
cookies=self.merge_cookies(cookies),
|
||||
)
|
||||
self.prepare_request(request)
|
||||
response = await self.send(
|
||||
@ -278,6 +281,13 @@ class AsyncClient:
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
request.prepare()
|
||||
|
||||
def merge_cookies(self, cookies: CookieTypes = None) -> typing.Optional[CookieTypes]:
|
||||
if cookies or self.cookies:
|
||||
merged_cookies = Cookies(self.cookies)
|
||||
merged_cookies.update(cookies)
|
||||
return merged_cookies
|
||||
return cookies
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
@ -334,6 +344,7 @@ class AsyncClient:
|
||||
request, stream=stream, ssl=ssl, timeout=timeout
|
||||
)
|
||||
response.history = list(history)
|
||||
self.cookies.extract_cookies(response)
|
||||
history = [response] + history
|
||||
if not response.is_redirect:
|
||||
break
|
||||
@ -365,7 +376,8 @@ class AsyncClient:
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url)
|
||||
content = self.redirect_content(request, method)
|
||||
return Request(method=method, url=url, headers=headers, data=content)
|
||||
cookies = self.merge_cookies(request.cookies)
|
||||
return Request(method=method, url=url, headers=headers, data=content, cookies=cookies)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
@ -466,6 +478,10 @@ class Client:
|
||||
)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
@property
|
||||
def cookies(self) -> Cookies:
|
||||
return self._client.cookies
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
@ -474,7 +490,7 @@ class Client:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -487,7 +503,7 @@ class Client:
|
||||
data=data,
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
cookies=self._client.merge_cookies(cookies),
|
||||
)
|
||||
self.prepare_request(request)
|
||||
response = self.send(
|
||||
@ -506,7 +522,7 @@ class Client:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -531,7 +547,7 @@ class Client:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -556,7 +572,7 @@ class Client:
|
||||
*,
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = False, # Note: Differs to usual default.
|
||||
@ -582,7 +598,7 @@ class Client:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -609,7 +625,7 @@ class Client:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -636,7 +652,7 @@ class Client:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
@ -663,7 +679,7 @@ class Client:
|
||||
data: RequestData = b"",
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieJar = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
|
||||
@ -488,8 +488,8 @@ class Request:
|
||||
self.url = URL(url, query_params=query_params)
|
||||
self.headers = Headers(headers)
|
||||
if cookies:
|
||||
cookies = Cookies(cookies)
|
||||
cookies.set_cookie_header(self)
|
||||
self._cookies = Cookies(cookies)
|
||||
self._cookies.set_cookie_header(self)
|
||||
|
||||
if isinstance(data, bytes):
|
||||
self.is_streaming = False
|
||||
@ -547,6 +547,12 @@ class Request:
|
||||
for item in reversed(auto_headers):
|
||||
self.headers.raw.insert(0, item)
|
||||
|
||||
@property
|
||||
def cookies(self) -> "Cookies":
|
||||
if not hasattr(self, "_cookies"):
|
||||
self._cookies = Cookies()
|
||||
return self._cookies
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
url = str(self.url)
|
||||
@ -874,7 +880,9 @@ class Cookies(MutableMapping):
|
||||
for key, value in cookies.items():
|
||||
self.set(key, value)
|
||||
elif isinstance(cookies, Cookies):
|
||||
self.jar = cookies.jar
|
||||
self.jar = CookieJar()
|
||||
for cookie in cookies.jar:
|
||||
self.jar.set_cookie(cookie)
|
||||
else:
|
||||
self.jar = cookies
|
||||
|
||||
|
||||
@ -104,3 +104,23 @@ def test_get_cookie():
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies["example-name"] == "example-value"
|
||||
assert client.cookies["example-name"] == "example-value"
|
||||
|
||||
|
||||
def test_cookie_persistence():
|
||||
"""
|
||||
Ensure that Client instances persist cookies between requests.
|
||||
"""
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get("http://example.org/echo_cookies")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {"cookies": None}
|
||||
|
||||
response = client.get("http://example.org/set_cookie")
|
||||
assert response.status_code == 200
|
||||
assert response.cookies["example-name"] == "example-value"
|
||||
assert client.cookies["example-name"] == "example-value"
|
||||
|
||||
response = client.get("http://example.org/echo_cookies")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {"cookies": "example-name=example-value"}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user