PYTHON-1070 - Make index cache thread safe

This commit is contained in:
Bernie Hackett 2016-03-10 08:34:02 -08:00
parent cb4a80a28a
commit 264cdd8b5a
4 changed files with 142 additions and 50 deletions

View File

@ -1594,6 +1594,12 @@ class Collection(common.BaseObject):
keys = helpers._index_list(key_or_list)
name = kwargs["name"] = _gen_index_name(keys)
# Note that there is a race condition here. One thread could
# check if the index is cached and be preempted before creating
# and caching the index. This means multiple threads attempting
# to create the same index concurrently could send the index
# to the server two or more times. This has no practical impact
# other than wasted round trips.
if not self.__database.connection._cached(self.__database.name,
self.__name, name):
return self.create_index(key_or_list, cache_for, **kwargs)

View File

@ -407,6 +407,7 @@ class MongoClient(common.BaseObject):
# cache of existing indexes used by ensure_index ops
self.__index_cache = {}
self.__index_cache_lock = threading.Lock()
self.__auth_credentials = {}
super(MongoClient, self).__init__(**options)
@ -446,10 +447,14 @@ class MongoClient(common.BaseObject):
"""
cache = self.__index_cache
now = datetime.datetime.utcnow()
return (dbname in cache and
coll in cache[dbname] and
index in cache[dbname][coll] and
now < cache[dbname][coll][index])
self.__index_cache_lock.acquire()
try:
return (dbname in cache and
coll in cache[dbname] and
index in cache[dbname][coll] and
now < cache[dbname][coll][index])
finally:
self.__index_cache_lock.release()
def _cache_index(self, database, collection, index, cache_for):
"""Add an index to the index cache for ensure_index operations.
@ -457,17 +462,21 @@ class MongoClient(common.BaseObject):
now = datetime.datetime.utcnow()
expire = datetime.timedelta(seconds=cache_for) + now
if database not in self.__index_cache:
self.__index_cache[database] = {}
self.__index_cache[database][collection] = {}
self.__index_cache[database][collection][index] = expire
self.__index_cache_lock.acquire()
try:
if database not in self.__index_cache:
self.__index_cache[database] = {}
self.__index_cache[database][collection] = {}
self.__index_cache[database][collection][index] = expire
elif collection not in self.__index_cache[database]:
self.__index_cache[database][collection] = {}
self.__index_cache[database][collection][index] = expire
elif collection not in self.__index_cache[database]:
self.__index_cache[database][collection] = {}
self.__index_cache[database][collection][index] = expire
else:
self.__index_cache[database][collection][index] = expire
else:
self.__index_cache[database][collection][index] = expire
finally:
self.__index_cache_lock.release()
def _purge_index(self, database_name,
collection_name=None, index_name=None):
@ -477,22 +486,26 @@ class MongoClient(common.BaseObject):
If `collection_name` is None purge an entire database.
"""
if not database_name in self.__index_cache:
return
self.__index_cache_lock.acquire()
try:
if not database_name in self.__index_cache:
return
if collection_name is None:
del self.__index_cache[database_name]
return
if collection_name is None:
del self.__index_cache[database_name]
return
if not collection_name in self.__index_cache[database_name]:
return
if not collection_name in self.__index_cache[database_name]:
return
if index_name is None:
del self.__index_cache[database_name][collection_name]
return
if index_name is None:
del self.__index_cache[database_name][collection_name]
return
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
finally:
self.__index_cache_lock.release()
def _cache_credentials(self, source, credentials, connect=True):
"""Add credentials to the database authentication cache

View File

@ -584,6 +584,7 @@ class MongoReplicaSetClient(common.BaseObject):
self.__opts = {}
self.__seeds = set()
self.__index_cache = {}
self.__index_cache_lock = threading.Lock()
self.__auth_credentials = {}
self.__monitor = None
@ -780,10 +781,14 @@ class MongoReplicaSetClient(common.BaseObject):
"""
cache = self.__index_cache
now = datetime.datetime.utcnow()
return (dbname in cache and
coll in cache[dbname] and
index in cache[dbname][coll] and
now < cache[dbname][coll][index])
self.__index_cache_lock.acquire()
try:
return (dbname in cache and
coll in cache[dbname] and
index in cache[dbname][coll] and
now < cache[dbname][coll][index])
finally:
self.__index_cache_lock.release()
def _cache_index(self, dbase, collection, index, cache_for):
"""Add an index to the index cache for ensure_index operations.
@ -791,17 +796,21 @@ class MongoReplicaSetClient(common.BaseObject):
now = datetime.datetime.utcnow()
expire = datetime.timedelta(seconds=cache_for) + now
if dbase not in self.__index_cache:
self.__index_cache[dbase] = {}
self.__index_cache[dbase][collection] = {}
self.__index_cache[dbase][collection][index] = expire
self.__index_cache_lock.acquire()
try:
if dbase not in self.__index_cache:
self.__index_cache[dbase] = {}
self.__index_cache[dbase][collection] = {}
self.__index_cache[dbase][collection][index] = expire
elif collection not in self.__index_cache[dbase]:
self.__index_cache[dbase][collection] = {}
self.__index_cache[dbase][collection][index] = expire
elif collection not in self.__index_cache[dbase]:
self.__index_cache[dbase][collection] = {}
self.__index_cache[dbase][collection][index] = expire
else:
self.__index_cache[dbase][collection][index] = expire
else:
self.__index_cache[dbase][collection][index] = expire
finally:
self.__index_cache_lock.release()
def _purge_index(self, database_name,
collection_name=None, index_name=None):
@ -811,22 +820,26 @@ class MongoReplicaSetClient(common.BaseObject):
If `collection_name` is None purge an entire database.
"""
if not database_name in self.__index_cache:
return
self.__index_cache_lock.acquire()
try:
if not database_name in self.__index_cache:
return
if collection_name is None:
del self.__index_cache[database_name]
return
if collection_name is None:
del self.__index_cache[database_name]
return
if not collection_name in self.__index_cache[database_name]:
return
if not collection_name in self.__index_cache[database_name]:
return
if index_name is None:
del self.__index_cache[database_name][collection_name]
return
if index_name is None:
del self.__index_cache[database_name][collection_name]
return
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
finally:
self.__index_cache_lock.release()
def _cache_credentials(self, source, credentials, connect=True):
"""Add credentials to the database authentication cache

View File

@ -259,6 +259,66 @@ class TestCollection(unittest.TestCase):
ctx.exit()
self.assertEqual(None, db.test.ensure_index("goodbye"))
def test_ensure_index_threaded(self):
    # Ten threads each call ensure_index for the same three keys on one
    # collection and then record index_information(); every recorded
    # document is compared against the first one.
    coll = self.db.threaded_index_creation
    snapshots = []

    class IndexWorker(threading.Thread):
        def run(self):
            for key in ('foo0', 'foo1', 'foo2'):
                coll.ensure_index(key)
            snapshots.append(coll.index_information())

    try:
        # Build every worker before starting any of them.
        workers = [IndexWorker() for _ in range(10)]
        for worker in workers:
            worker.setDaemon(True)
        for worker in workers:
            worker.start()

        joinall(workers)

        expected = snapshots[0]
        for snapshot in snapshots[1:]:
            self.assertEqual(snapshot, expected)
    finally:
        # Drop the collection whether or not the checks passed.
        coll.drop()
def test_ensure_purge_index_threaded(self):
    # Ten threads interleave ensure_index / drop_index / drop_indexes on
    # the same key; each thread finishes with an ensure_index call, and
    # the test then checks that 'foo_1' is listed in index_information().
    coll = self.db.threaded_index_creation

    class PurgeWorker(threading.Thread):
        def run(self):
            coll.ensure_index('foo')
            try:
                coll.drop_index('foo')
            except OperationFailure:
                # Another thread may have dropped the index already.
                pass
            coll.ensure_index('foo')
            coll.drop_indexes()
            coll.ensure_index('foo')

    try:
        # Build every worker before starting any of them.
        workers = [PurgeWorker() for _ in range(10)]
        for worker in workers:
            worker.setDaemon(True)
        for worker in workers:
            worker.start()

        joinall(workers)

        index_info = coll.index_information()
        self.assertTrue('foo_1' in index_info)
    finally:
        # Drop the collection whether or not the check passed.
        coll.drop()
def test_ensure_unique_index_threaded(self):
coll = self.db.test_unique_threaded
coll.drop()