PYTHON-2878 Allow passing dict to sort/create_index/hint (#1389)

This commit is contained in:
Noah Stapp 2023-10-16 14:36:27 -07:00 committed by GitHub
parent 2f13aee868
commit 81c759a3a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 28 deletions

View File

@ -798,7 +798,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint) # type: ignore[assignment]
hint = helpers._index_document(hint)
update_doc["hint"] = hint
command = SON([("update", self.name), ("ordered", ordered), ("updates", [update_doc])])
if let is not None:
@ -1277,7 +1277,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint) # type: ignore[assignment]
hint = helpers._index_document(hint)
delete_doc["hint"] = hint
command = SON([("delete", self.name), ("ordered", ordered), ("deletes", [delete_doc])])
@ -3097,7 +3097,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd["upsert"] = upsert
if hint is not None:
if not isinstance(hint, str):
hint = helpers._index_document(hint) # type: ignore[assignment]
hint = helpers._index_document(hint)
write_concern = self._write_concern_for_cmd(cmd, session)

View File

@ -157,7 +157,9 @@ class _ConnectionManager:
self.conn = None
_Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
_Sort = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_Hint = Union[str, _Sort]

View File

@ -110,8 +110,10 @@ def _index_list(
else:
if isinstance(key_or_list, str):
return [(key_or_list, ASCENDING)]
if isinstance(key_or_list, abc.ItemsView):
return list(key_or_list)
elif isinstance(key_or_list, abc.ItemsView):
return list(key_or_list) # type: ignore[arg-type]
elif isinstance(key_or_list, abc.Mapping):
return list(key_or_list.items())
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
values: list[tuple[str, int]] = []
@ -127,33 +129,40 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]:
Takes a list of (key, direction) pairs.
"""
if isinstance(index_list, abc.Mapping):
if not isinstance(index_list, (list, tuple, abc.Mapping)):
raise TypeError(
"passing a dict to sort/create_index/hint is not "
"allowed - use a list of tuples instead. did you "
"mean %r?" % list(index_list.items())
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
)
elif not isinstance(index_list, (list, tuple)):
raise TypeError("must use a list of (key, direction) pairs, not: " + repr(index_list))
if not len(index_list):
raise ValueError("key_or_list must not be the empty list")
raise ValueError("key_or_list must not be empty")
index: SON[str, Any] = SON()
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING)
key, value = item
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
)
index[key] = value
if isinstance(index_list, abc.Mapping):
for key in index_list:
value = index_list[key]
_validate_index_key_pair(key, value)
index[key] = value
else:
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING)
key, value = item
_validate_index_key_pair(key, value)
index[key] = value
return index
def _validate_index_key_pair(key: Any, value: Any) -> None:
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
)
def _check_command_response(
response: _DocumentOut,
max_wire_version: Optional[int],

View File

@ -37,8 +37,10 @@ if TYPE_CHECKING:
from bson.son import SON
from pymongo.bulk import _Bulk
# Hint supports index name, "myIndex", or list of either strings or index pairs: [('x', 1), ('y', -1), 'z'']
_IndexList = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
_IndexList = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_IndexKeyHint = Union[str, _IndexList]

View File

@ -1836,6 +1836,28 @@ class TestClient(IntegrationTest):
None,
)
def test_dict_hints(self):
c = rs_or_single_client()
try:
c.t.t.find(hint={"x": 1})
except Exception:
self.fail("passing a dictionary hint to find failed!")
def test_dict_hints_sort(self):
c = rs_or_single_client()
try:
result = c.t.t.find()
result.sort({"x": 1})
except Exception:
self.fail("passing a dictionary to sort failed!")
def test_dict_hints_create_index(self):
c = rs_or_single_client()
try:
c.t.t.create_index({"x": pymongo.ASCENDING})
except Exception:
self.fail("passing a dictionary to create_index failed!")
class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""

View File

@ -276,7 +276,6 @@ class TestCollection(IntegrationTest):
db = self.db
self.assertRaises(TypeError, db.test.create_index, 5)
self.assertRaises(TypeError, db.test.create_index, {"hello": 1})
self.assertRaises(ValueError, db.test.create_index, [])
db.test.drop_indexes()