Support cookie persistence

This commit is contained in:
Tom Christie 2019-05-17 12:41:36 +01:00
parent 24b6923669
commit f63148aa73
3 changed files with 67 additions and 23 deletions

View File

@ -1,6 +1,5 @@
import asyncio
import typing
from http.cookiejar import CookieJar
from types import TracebackType
from .auth import HTTPBasicAuth
@ -19,6 +18,8 @@ from .interfaces import ConcurrencyBackend, Dispatcher
from .models import (
URL,
AuthTypes,
Cookies,
CookieTypes,
Headers,
HeaderTypes,
QueryParamTypes,
@ -35,6 +36,7 @@ class AsyncClient:
def __init__(
self,
auth: AuthTypes = None,
cookies: CookieTypes = None,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@ -48,6 +50,7 @@ class AsyncClient:
)
self.auth = auth
self.cookies = Cookies(cookies)
self.max_redirects = max_redirects
self.dispatch = dispatch
@ -57,7 +60,7 @@ class AsyncClient:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -83,7 +86,7 @@ class AsyncClient:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -109,7 +112,7 @@ class AsyncClient:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
@ -136,7 +139,7 @@ class AsyncClient:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -164,7 +167,7 @@ class AsyncClient:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -192,7 +195,7 @@ class AsyncClient:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -220,7 +223,7 @@ class AsyncClient:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -249,7 +252,7 @@ class AsyncClient:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -262,7 +265,7 @@ class AsyncClient:
data=data,
query_params=query_params,
headers=headers,
cookies=cookies,
cookies=self.merge_cookies(cookies),
)
self.prepare_request(request)
response = await self.send(
@ -278,6 +281,13 @@ class AsyncClient:
def prepare_request(self, request: Request) -> None:
request.prepare()
def merge_cookies(self, cookies: CookieTypes = None) -> typing.Optional[CookieTypes]:
if cookies or self.cookies:
merged_cookies = Cookies(self.cookies)
merged_cookies.update(cookies)
return merged_cookies
return cookies
async def send(
self,
request: Request,
@ -334,6 +344,7 @@ class AsyncClient:
request, stream=stream, ssl=ssl, timeout=timeout
)
response.history = list(history)
self.cookies.extract_cookies(response)
history = [response] + history
if not response.is_redirect:
break
@ -365,7 +376,8 @@ class AsyncClient:
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url)
content = self.redirect_content(request, method)
return Request(method=method, url=url, headers=headers, data=content)
cookies = self.merge_cookies(request.cookies)
return Request(method=method, url=url, headers=headers, data=content, cookies=cookies)
def redirect_method(self, request: Request, response: Response) -> str:
"""
@ -466,6 +478,10 @@ class Client:
)
self._loop = asyncio.new_event_loop()
@property
def cookies(self) -> Cookies:
return self._client.cookies
def request(
self,
method: str,
@ -474,7 +490,7 @@ class Client:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -487,7 +503,7 @@ class Client:
data=data,
query_params=query_params,
headers=headers,
cookies=cookies,
cookies=self._client.merge_cookies(cookies),
)
self.prepare_request(request)
response = self.send(
@ -506,7 +522,7 @@ class Client:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -531,7 +547,7 @@ class Client:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -556,7 +572,7 @@ class Client:
*,
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
@ -582,7 +598,7 @@ class Client:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -609,7 +625,7 @@ class Client:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -636,7 +652,7 @@ class Client:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
@ -663,7 +679,7 @@ class Client:
data: RequestData = b"",
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieJar = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,

View File

@ -488,8 +488,8 @@ class Request:
self.url = URL(url, query_params=query_params)
self.headers = Headers(headers)
if cookies:
cookies = Cookies(cookies)
cookies.set_cookie_header(self)
self._cookies = Cookies(cookies)
self._cookies.set_cookie_header(self)
if isinstance(data, bytes):
self.is_streaming = False
@ -547,6 +547,12 @@ class Request:
for item in reversed(auto_headers):
self.headers.raw.insert(0, item)
@property
def cookies(self) -> "Cookies":
if not hasattr(self, "_cookies"):
self._cookies = Cookies()
return self._cookies
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
@ -874,7 +880,9 @@ class Cookies(MutableMapping):
for key, value in cookies.items():
self.set(key, value)
elif isinstance(cookies, Cookies):
self.jar = cookies.jar
self.jar = CookieJar()
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
else:
self.jar = cookies

View File

@ -104,3 +104,23 @@ def test_get_cookie():
assert response.status_code == 200
assert response.cookies["example-name"] == "example-value"
assert client.cookies["example-name"] == "example-value"
def test_cookie_persistence():
"""
Ensure that Client instances persist cookies between requests.
"""
with Client(dispatch=MockDispatch()) as client:
response = client.get("http://example.org/echo_cookies")
assert response.status_code == 200
assert json.loads(response.text) == {"cookies": None}
response = client.get("http://example.org/set_cookie")
assert response.status_code == 200
assert response.cookies["example-name"] == "example-value"
assert client.cookies["example-name"] == "example-value"
response = client.get("http://example.org/echo_cookies")
assert response.status_code == 200
assert json.loads(response.text) == {"cookies": "example-name=example-value"}