Allow lists in query params (#386)

This commit is contained in:
Florimond Manca 2019-10-08 22:12:04 +02:00 committed by Seth Michael Larson
parent 31730e7095
commit 57ae7ea22b
3 changed files with 47 additions and 6 deletions

View File

@ -32,6 +32,7 @@ from .exceptions import (
from .multipart import multipart_encode
from .status_codes import StatusCode
from .utils import (
flatten_queryparams,
guess_json_utf,
is_known_encoding,
normalize_header_key,
@ -51,7 +52,7 @@ URLTypes = typing.Union["URL", str]
QueryParamTypes = typing.Union[
"QueryParams",
typing.Mapping[str, PrimitiveData],
typing.Mapping[str, typing.Union[PrimitiveData, typing.Sequence[PrimitiveData]]],
typing.List[typing.Tuple[str, PrimitiveData]],
str,
]
@ -311,14 +312,15 @@ class QueryParams(typing.Mapping[str, str]):
value = args[0] if args else kwargs
items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
if isinstance(value, str):
items = parse_qsl(value)
elif isinstance(value, QueryParams):
items = value.multi_items()
elif isinstance(value, list):
items = value # type: ignore
items = value
else:
items = value.items() # type: ignore
items = flatten_queryparams(value)
self._list = [(str(k), str_query_param(v)) for k, v in items]
self._dict = {str(k): str_query_param(v) for k, v in items}

View File

@ -1,4 +1,5 @@
import codecs
import collections
import logging
import netrc
import os
@ -11,6 +12,9 @@ from time import perf_counter
from types import TracebackType
from urllib.request import getproxies
if typing.TYPE_CHECKING: # pragma: no cover
from .models import PrimitiveData
def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
"""
@ -30,7 +34,7 @@ def normalize_header_value(value: typing.AnyStr, encoding: str = None) -> bytes:
return value.encode(encoding or "ascii")
def str_query_param(value: typing.Optional[typing.Union[str, int, float, bool]]) -> str:
def str_query_param(value: "PrimitiveData") -> str:
"""
Coerce a primitive data type into a string value for query params.
@ -256,6 +260,31 @@ def unquote(value: str) -> str:
return value[1:-1] if value[0] == value[-1] == '"' else value
def flatten_queryparams(
queryparams: typing.Mapping[
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
]
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
"""
Convert a mapping of query params into a flat list of two-tuples
representing each item.
Example:
>>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
[("q", "httpx), ("tag", "python"), ("tag", "dev")]
"""
items = []
for k, v in queryparams.items():
if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
for u in v:
items.append((k, u))
else:
items.append((k, typing.cast("PrimitiveData", v)))
return items
class ElapsedTimer:
def __init__(self) -> None:
self.start: float = perf_counter()

View File

@ -1,8 +1,18 @@
import pytest
from httpx import QueryParams
def test_queryparams():
q = QueryParams("a=123&a=456&b=789")
@pytest.mark.parametrize(
"source",
[
"a=123&a=456&b=789",
{"a": ["123", "456"], "b": 789},
{"a": ("123", "456"), "b": 789},
],
)
def test_queryparams(source):
q = QueryParams(source)
assert "a" in q
assert "A" not in q
assert "c" not in q