From 264cdd8b5a5db6efdee9dd5838d948067a02aa77 Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Thu, 10 Mar 2016 08:34:02 -0800 Subject: [PATCH] PYTHON-1070 - Make index cache thread safe --- pymongo/collection.py | 6 +++ pymongo/mongo_client.py | 63 +++++++++++++++++------------ pymongo/mongo_replica_set_client.py | 63 +++++++++++++++++------------ test/test_collection.py | 60 +++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 50 deletions(-) diff --git a/pymongo/collection.py b/pymongo/collection.py index cf73ae642..0074b1a6d 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -1594,6 +1594,12 @@ class Collection(common.BaseObject): keys = helpers._index_list(key_or_list) name = kwargs["name"] = _gen_index_name(keys) + # Note that there is a race condition here. One thread could + # check if the index is cached and be preempted before creating + # and caching the index. This means multiple threads attempting + # to create the same index concurrently could send the index + # to the server two or more times. This has no practical impact + # other than wasted round trips. if not self.__database.connection._cached(self.__database.name, self.__name, name): return self.create_index(key_or_list, cache_for, **kwargs) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 74f379eb2..ae75f6a59 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -407,6 +407,7 @@ class MongoClient(common.BaseObject): # cache of existing indexes used by ensure_index ops self.__index_cache = {} + self.__index_cache_lock = threading.Lock() self.__auth_credentials = {} super(MongoClient, self).__init__(**options) @@ -446,10 +447,14 @@ class MongoClient(common.BaseObject): """ cache = self.__index_cache now = datetime.datetime.utcnow() - return (dbname in cache and - coll in cache[dbname] and - index in cache[dbname][coll] and - now < cache[dbname][coll][index]) + self.__index_cache_lock.acquire() + try: + return (dbname in cache and + coll in cache[dbname] and + index in cache[dbname][coll] and + now < cache[dbname][coll][index]) + finally: + self.__index_cache_lock.release() def _cache_index(self, database, collection, index, cache_for): """Add an index to the index cache for ensure_index operations. @@ -457,17 +462,21 @@ class MongoClient(common.BaseObject): now = datetime.datetime.utcnow() expire = datetime.timedelta(seconds=cache_for) + now - if database not in self.__index_cache: - self.__index_cache[database] = {} - self.__index_cache[database][collection] = {} - self.__index_cache[database][collection][index] = expire + self.__index_cache_lock.acquire() + try: + if database not in self.__index_cache: + self.__index_cache[database] = {} + self.__index_cache[database][collection] = {} + self.__index_cache[database][collection][index] = expire - elif collection not in self.__index_cache[database]: - self.__index_cache[database][collection] = {} - self.__index_cache[database][collection][index] = expire + elif collection not in self.__index_cache[database]: + self.__index_cache[database][collection] = {} + self.__index_cache[database][collection][index] = expire - else: - self.__index_cache[database][collection][index] = expire + else: + self.__index_cache[database][collection][index] = expire + finally: + self.__index_cache_lock.release() def _purge_index(self, database_name, collection_name=None, index_name=None): @@ -477,22 +486,26 @@ class MongoClient(common.BaseObject): If `collection_name` is None purge an entire database. """ - if not database_name in self.__index_cache: - return + self.__index_cache_lock.acquire() + try: + if not database_name in self.__index_cache: + return - if collection_name is None: - del self.__index_cache[database_name] - return + if collection_name is None: + del self.__index_cache[database_name] + return - if not collection_name in self.__index_cache[database_name]: - return + if not collection_name in self.__index_cache[database_name]: + return - if index_name is None: - del self.__index_cache[database_name][collection_name] - return + if index_name is None: + del self.__index_cache[database_name][collection_name] + return - if index_name in self.__index_cache[database_name][collection_name]: - del self.__index_cache[database_name][collection_name][index_name] + if index_name in self.__index_cache[database_name][collection_name]: + del self.__index_cache[database_name][collection_name][index_name] + finally: + self.__index_cache_lock.release() def _cache_credentials(self, source, credentials, connect=True): """Add credentials to the database authentication cache diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index 19ba1e730..e30b92ff1 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -584,6 +584,7 @@ class MongoReplicaSetClient(common.BaseObject): self.__opts = {} self.__seeds = set() self.__index_cache = {} + self.__index_cache_lock = threading.Lock() self.__auth_credentials = {} self.__monitor = None @@ -780,10 +781,14 @@ class MongoReplicaSetClient(common.BaseObject): """ cache = self.__index_cache now = datetime.datetime.utcnow() - return (dbname in cache and - coll in cache[dbname] and - index in cache[dbname][coll] and - now < cache[dbname][coll][index]) + self.__index_cache_lock.acquire() + try: + return (dbname in cache and + coll in cache[dbname] and + index in cache[dbname][coll] and + now < cache[dbname][coll][index]) + finally: + self.__index_cache_lock.release() def _cache_index(self, dbase, collection, index, cache_for): """Add an index to the index cache for ensure_index operations. @@ -791,17 +796,21 @@ class MongoReplicaSetClient(common.BaseObject): now = datetime.datetime.utcnow() expire = datetime.timedelta(seconds=cache_for) + now - if dbase not in self.__index_cache: - self.__index_cache[dbase] = {} - self.__index_cache[dbase][collection] = {} - self.__index_cache[dbase][collection][index] = expire + self.__index_cache_lock.acquire() + try: + if dbase not in self.__index_cache: + self.__index_cache[dbase] = {} + self.__index_cache[dbase][collection] = {} + self.__index_cache[dbase][collection][index] = expire - elif collection not in self.__index_cache[dbase]: - self.__index_cache[dbase][collection] = {} - self.__index_cache[dbase][collection][index] = expire + elif collection not in self.__index_cache[dbase]: + self.__index_cache[dbase][collection] = {} + self.__index_cache[dbase][collection][index] = expire - else: - self.__index_cache[dbase][collection][index] = expire + else: + self.__index_cache[dbase][collection][index] = expire + finally: + self.__index_cache_lock.release() def _purge_index(self, database_name, collection_name=None, index_name=None): @@ -811,22 +820,26 @@ class MongoReplicaSetClient(common.BaseObject): If `collection_name` is None purge an entire database. """ - if not database_name in self.__index_cache: - return + self.__index_cache_lock.acquire() + try: + if not database_name in self.__index_cache: + return - if collection_name is None: - del self.__index_cache[database_name] - return + if collection_name is None: + del self.__index_cache[database_name] + return - if not collection_name in self.__index_cache[database_name]: - return + if not collection_name in self.__index_cache[database_name]: + return - if index_name is None: - del self.__index_cache[database_name][collection_name] - return + if index_name is None: + del self.__index_cache[database_name][collection_name] + return - if index_name in self.__index_cache[database_name][collection_name]: - del self.__index_cache[database_name][collection_name][index_name] + if index_name in self.__index_cache[database_name][collection_name]: + del self.__index_cache[database_name][collection_name][index_name] + finally: + self.__index_cache_lock.release() def _cache_credentials(self, source, credentials, connect=True): """Add credentials to the database authentication cache diff --git a/test/test_collection.py b/test/test_collection.py index efecd1f2f..1ab8553d8 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -259,6 +259,66 @@ class TestCollection(unittest.TestCase): ctx.exit() self.assertEqual(None, db.test.ensure_index("goodbye")) + def test_ensure_index_threaded(self): + coll = self.db.threaded_index_creation + index_docs = [] + + class Indexer(threading.Thread): + def run(self): + coll.ensure_index('foo0') + coll.ensure_index('foo1') + coll.ensure_index('foo2') + index_docs.append(coll.index_information()) + + try: + threads = [] + for _ in range(10): + t = Indexer() + t.setDaemon(True) + threads.append(t) + + for thread in threads: + thread.start() + + joinall(threads) + + first = index_docs[0] + for index_doc in index_docs[1:]: + self.assertEqual(index_doc, first) + finally: + coll.drop() + + def test_ensure_purge_index_threaded(self): + coll = self.db.threaded_index_creation + + class Indexer(threading.Thread): + def run(self): + coll.ensure_index('foo') + try: + coll.drop_index('foo') + except OperationFailure: + # The index may have already been dropped. + pass + coll.ensure_index('foo') + coll.drop_indexes() + coll.ensure_index('foo') + + try: + threads = [] + for _ in range(10): + t = Indexer() + t.setDaemon(True) + threads.append(t) + + for thread in threads: + thread.start() + + joinall(threads) + + self.assertTrue('foo_1' in coll.index_information()) + finally: + coll.drop() + def test_ensure_unique_index_threaded(self): coll = self.db.test_unique_threaded coll.drop()