From f757fe39cc12520e299090adf750fc07eb03dec8 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 26 Mar 2024 10:18:54 -0700 Subject: [PATCH] PYTHON-4297 Allow passing arbitrary options to create_search_index/SearchIndexModel (#1561) --- doc/changelog.rst | 2 ++ pymongo/collection.py | 2 +- pymongo/operations.py | 21 +++++++++++++-------- test/test_index_management.py | 9 ++++++++- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 57808e041..647a783c4 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -39,6 +39,8 @@ PyMongo 4.7 brings a number of improvements including: - Added the :attr:`pymongo.monitoring.ConnectionCheckedOutEvent.duration`, :attr:`pymongo.monitoring.ConnectionCheckOutFailedEvent.duration`, and :attr:`pymongo.monitoring.ConnectionReadyEvent.duration` properties. +- Added the ``type`` and ``kwargs`` arguments to :class:`~pymongo.operations.SearchIndexModel` to enable + creating vector search indexes in MongoDB Atlas. Unavoidable breaking changes diff --git a/pymongo/collection.py b/pymongo/collection.py index 958f7ec21..2b771f4f6 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -2410,7 +2410,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): .. versionadded:: 4.5 """ if not isinstance(model, SearchIndexModel): - model = SearchIndexModel(model["definition"], model.get("name"), model.get("type")) + model = SearchIndexModel(**model) return self.create_search_indexes([model], session, comment, **kwargs)[0] def create_search_indexes( diff --git a/pymongo/operations.py b/pymongo/operations.py index 0240b6e17..58655753b 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -593,23 +593,28 @@ class SearchIndexModel: definition: Mapping[str, Any], name: Optional[str] = None, type: Optional[str] = "search", + **kwargs: Any, ) -> None: """Create a Search Index instance. For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`. - :param definition: - The definition for this index. - :param name: - The name for this index, if present. - - .. versionadded:: 4.5 + :param definition: The definition for this index. + :param name: The name for this index, if present. + :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". + :param kwargs: Keyword arguments supplying any additional options. .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. + .. versionadded:: 4.5 + .. versionchanged:: 4.7 + Added the type and kwargs arguments. """ + self.__document: dict[str, Any] = {} if name is not None: - self.__document = dict(name=name, definition=definition) - else: - self.__document = dict(definition=definition) - self.__document["type"] = type # type: ignore[assignment] + self.__document["name"] = name + self.__document["definition"] = definition + self.__document["type"] = type + self.__document.update(kwargs) @property def document(self) -> Mapping[str, Any]: diff --git a/test/test_index_management.py b/test/test_index_management.py index b6d5c4775..25541f980 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -25,6 +25,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, unittest from test.unified_format import generate_test_classes +from test.utils import AllowListEventListener from pymongo import MongoClient from pymongo.errors import OperationFailure @@ -39,7 +40,8 @@ class TestCreateSearchIndex(IntegrationTest): def test_inputs(self): if not os.environ.get("TEST_INDEX_MANAGEMENT"): raise unittest.SkipTest("Skipping index management tests") - client = MongoClient() + listener = AllowListEventListener("createSearchIndexes") + client = MongoClient(event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.drop() @@ -55,6 +57,11 @@ class TestCreateSearchIndex(IntegrationTest): with self.assertRaises(OperationFailure): coll.create_search_index(model_kwargs) + listener.reset() + with self.assertRaises(OperationFailure): + coll.create_search_index({"definition": definition, "arbitraryOption": 1}) + self.assertIn("arbitraryOption", listener.events[0].command["indexes"][0]) + class TestSearchIndexProse(unittest.TestCase): @classmethod