From 81c759a3a06c8bfea84902e67bdbabb7f218a5c0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Oct 2023 14:36:27 -0700 Subject: [PATCH] PYTHON-2878 Allow passing dict to sort/create_index/hint (#1389) --- pymongo/collection.py | 6 ++--- pymongo/cursor.py | 4 +++- pymongo/helpers.py | 51 ++++++++++++++++++++++++----------------- pymongo/operations.py | 6 +++-- test/test_client.py | 22 ++++++++++++++++++ test/test_collection.py | 1 - 6 files changed, 62 insertions(+), 28 deletions(-) diff --git a/pymongo/collection.py b/pymongo/collection.py index 16b4f9b4b..7768d5f52 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -798,7 +798,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) # type: ignore[assignment] + hint = helpers._index_document(hint) update_doc["hint"] = hint command = SON([("update", self.name), ("ordered", ordered), ("updates", [update_doc])]) if let is not None: @@ -1277,7 +1277,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) # type: ignore[assignment] + hint = helpers._index_document(hint) delete_doc["hint"] = hint command = SON([("delete", self.name), ("ordered", ordered), ("deletes", [delete_doc])]) @@ -3097,7 +3097,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, str): - hint = helpers._index_document(hint) # type: ignore[assignment] + hint = helpers._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index be4f6e2c1..4f9234e26 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -157,7 +157,9 @@ class _ConnectionManager: self.conn = None -_Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]] +_Sort = Union[ + Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] +] _Hint = Union[str, _Sort] diff --git a/pymongo/helpers.py b/pymongo/helpers.py index e8c09aaa8..6faaea2db 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -110,8 +110,10 @@ def _index_list( else: if isinstance(key_or_list, str): return [(key_or_list, ASCENDING)] - if isinstance(key_or_list, abc.ItemsView): - return list(key_or_list) + elif isinstance(key_or_list, abc.ItemsView): + return list(key_or_list) # type: ignore[arg-type] + elif isinstance(key_or_list, abc.Mapping): + return list(key_or_list.items()) elif not isinstance(key_or_list, (list, tuple)): raise TypeError("if no direction is specified, key_or_list must be an instance of list") values: list[tuple[str, int]] = [] @@ -127,33 +129,40 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]: Takes a list of (key, direction) pairs. """ - if isinstance(index_list, abc.Mapping): + if not isinstance(index_list, (list, tuple, abc.Mapping)): raise TypeError( - "passing a dict to sort/create_index/hint is not " - "allowed - use a list of tuples instead. did you " - "mean %r?" % list(index_list.items()) + "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) ) - elif not isinstance(index_list, (list, tuple)): - raise TypeError("must use a list of (key, direction) pairs, not: " + repr(index_list)) if not len(index_list): - raise ValueError("key_or_list must not be the empty list") + raise ValueError("key_or_list must not be empty") index: SON[str, Any] = SON() - for item in index_list: - if isinstance(item, str): - item = (item, ASCENDING) - key, value = item - if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") - if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError( - "second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier." - ) - index[key] = value + + if isinstance(index_list, abc.Mapping): + for key in index_list: + value = index_list[key] + _validate_index_key_pair(key, value) + index[key] = value + else: + for item in index_list: + if isinstance(item, str): + item = (item, ASCENDING) + key, value = item + _validate_index_key_pair(key, value) + index[key] = value return index +def _validate_index_key_pair(key: Any, value: Any) -> None: + if not isinstance(key, str): + raise TypeError("first item in each key pair must be an instance of str") + if not isinstance(value, (str, int, abc.Mapping)): + raise TypeError( + "second item in each key pair must be 1, -1, " + "'2d', or another valid MongoDB index specifier." + ) + + def _check_command_response( response: _DocumentOut, max_wire_version: Optional[int], diff --git a/pymongo/operations.py b/pymongo/operations.py index 92d920bf0..adaca5707 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -37,8 +37,10 @@ if TYPE_CHECKING: from bson.son import SON from pymongo.bulk import _Bulk -# Hint supports index name, "myIndex", or list of either strings or index pairs: [('x', 1), ('y', -1), 'z''] -_IndexList = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]] +# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary +_IndexList = Union[ + Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] +] _IndexKeyHint = Union[str, _IndexList] diff --git a/test/test_client.py b/test/test_client.py index 5d5208043..1039031e0 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1836,6 +1836,28 @@ class TestClient(IntegrationTest): None, ) + def test_dict_hints(self): + c = rs_or_single_client() + try: + c.t.t.find(hint={"x": 1}) + except Exception: + self.fail("passing a dictionary hint to find failed!") + + def test_dict_hints_sort(self): + c = rs_or_single_client() + try: + result = c.t.t.find() + result.sort({"x": 1}) + except Exception: + self.fail("passing a dictionary to sort failed!") + + def test_dict_hints_create_index(self): + c = rs_or_single_client() + try: + c.t.t.create_index({"x": pymongo.ASCENDING}) + except Exception: + self.fail("passing a dictionary to create_index failed!") + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" diff --git a/test/test_collection.py b/test/test_collection.py index 08cc0061f..bbaac0123 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -276,7 +276,6 @@ class TestCollection(IntegrationTest): db = self.db self.assertRaises(TypeError, db.test.create_index, 5) - self.assertRaises(TypeError, db.test.create_index, {"hello": 1}) self.assertRaises(ValueError, db.test.create_index, []) db.test.drop_indexes()