From fcb7ffb25fed693bfec13d06eaa4dd25727d192f Mon Sep 17 00:00:00 2001 From: behackett Date: Mon, 11 Jul 2011 19:27:20 -0700 Subject: [PATCH] Auth improvements w/pooling and threads PYTHON-4 Also PYTHON-162 and PYTHON-189. Credit goes to James Murty for most of the patch. With this change we cache auth credentials (user, password) in the driver so that each new socket can be automatically authenticated. This solves the problem of each new spawned thread having to re-authenticate. --- pymongo/connection.py | 46 +++++++++++++++++++++--- pymongo/database.py | 42 +++++++++++++--------- test/test_connection.py | 3 ++ test/test_pooling.py | 9 +++-- test/test_threads.py | 80 ++++++++++++++++++++++++++++++++++++++++- 5 files changed, 154 insertions(+), 26 deletions(-) diff --git a/pymongo/connection.py b/pymongo/connection.py index 558570c9b..368204f75 100644 --- a/pymongo/connection.py +++ b/pymongo/connection.py @@ -137,14 +137,15 @@ class _Pool(threading.local): self.pid = pid if self.sock is not None and self.sock[0] == pid: - return self.sock[1] + return (self.sock[1], True) try: self.sock = (pid, self.sockets.pop()) + return (self.sock[1], True) except IndexError: self.sock = (pid, self.connect(host, port)) + return (self.sock[1], False) - return self.sock[1] def return_socket(self): if self.sock is not None and self.sock[0] == os.getpid(): @@ -158,7 +159,7 @@ class _Pool(threading.local): self.sock = None -class Connection(common.BaseObject): # TODO support auth for pooling +class Connection(common.BaseObject): """Connection to MongoDB. """ @@ -319,6 +320,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling # cache of existing indexes used by ensure_index ops self.__index_cache = {} + self.__auth_credentials = {} if _connect: self.__find_node() @@ -413,6 +415,38 @@ class Connection(common.BaseObject): # TODO support auth for pooling if index_name in self.__index_cache[database_name][collection_name]: del self.__index_cache[database_name][collection_name][index_name] + def _cache_credentials(self, db_name, username, password): + """Add credentials to the database authentication cache + for automatic login when a socket is created. + + If credentials are already cached for `db_name` they + will be replaced. + """ + self.__auth_credentials[db_name] = (username, password) + + def _purge_credentials(self, db_name=None): + """Purge credentials from the database authentication cache. + + If `db_name` is None purge credentials for all databases. + """ + if db_name is None: + self.__auth_credentials.clear() + elif db_name in self.__auth_credentials: + del self.__auth_credentials[db_name] + + def __authenticate_socket(self): + """Authenticate using cached database credentials. + + If credentials for the 'admin' database are available only + this database is authenticated, since this gives global access. + """ + if "admin" in self.__auth_credentials: + username, password = self.__auth_credentials["admin"] + self.admin.authenticate(username, password) + else: + for db_name, (u, p) in self.__auth_credentials.iteritems(): + self[db_name].authenticate(u, p) + @property def host(self): """Current connected host. @@ -570,7 +604,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling host, port = self.__find_node() try: - sock = self.__pool.get_socket(host, port) + sock, from_pool = self.__pool.get_socket(host, port) except socket.error: self.disconnect() raise AutoReconnect("could not connect to %s:%d" % (host, port)) @@ -578,8 +612,10 @@ class Connection(common.BaseObject): # TODO support auth for pooling if t - self.__last_checkout > 1: if _closed(sock): self.disconnect() - sock = self.__pool.get_socket(host, port) + sock, from_pool = self.__pool.get_socket(host, port) self.__last_checkout = t + if self.__auth_credentials and not from_pool: + self.__authenticate_socket() return sock def disconnect(self): diff --git a/pymongo/database.py b/pymongo/database.py index 6ec08ec38..6cd71c1d3 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -561,28 +561,28 @@ class Database(common.BaseObject): Once authenticated, the user has full read and write access to this database. Raises :class:`TypeError` if either `name` or `password` is not an instance of ``(str, - unicode)``. Authentication lasts for the life of the database - connection, or until :meth:`logout` is called. + unicode)``. Authentication lasts for the life of the underlying + :class:`~pymongo.connection.Connection`, or until :meth:`logout` + is called. The "admin" database is special. Authenticating on "admin" gives access to *all* databases. Effectively, "admin" access means root access to the database. - .. note:: Currently, authentication is per - :class:`~socket.socket`. This means that there are a couple - of situations in which re-authentication is necessary: - - - On failover (when an - :class:`~pymongo.errors.AutoReconnect` exception is - raised). - - - After a call to - :meth:`~pymongo.connection.Connection.disconnect` or - :meth:`~pymongo.connection.Connection.end_request`. + .. note:: This method authenticates the current connection, and + will also cause all new :class:`~socket.socket` connections + in the underlying :class:`~pymongo.connection.Connection` to + be authenticated automatically. - When sharing a :class:`~pymongo.connection.Connection` - between multiple threads, each thread will need to - authenticate separately. + between multiple threads, all threads will share the + authentication. If you need different authentication profiles + for different purposes (e.g. admin users) you must use + distinct instances of :class:`~pymongo.connection.Connection`. + + - To get authentication to apply immediately to all + existing sockets you may need to reset this Connection's + sockets using :meth:`~pymongo.connection.Connection.disconnect`. .. warning:: Currently, calls to :meth:`~pymongo.connection.Connection.end_request` will @@ -608,16 +608,24 @@ class Database(common.BaseObject): try: self.command("authenticate", user=unicode(name), nonce=nonce, key=key) + self.connection._cache_credentials(self.name, + unicode(name), + unicode(password)) return True except OperationFailure: return False def logout(self): - """Deauthorize use of this database for this connection. + """Deauthorize use of this database for this connection + and future connections. - Note that other databases may still be authorized. + .. note:: Other databases may still be authenticated, and other + existing :class:`~socket.socket` connections may remain + authenticated for this database unless you reset all sockets + with :meth:`~pymongo.connection.Connection.disconnect`. """ self.command("logout") + self.connection._purge_credentials(self.name) def dereference(self, dbref): """Dereference a :class:`~bson.dbref.DBRef`, getting the diff --git a/test/test_connection.py b/test/test_connection.py index f0b760aa1..9e73d1c90 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -245,6 +245,7 @@ class TestConnection(unittest.TestCase): c.pymongo_test.system.users.remove({}) c.admin.add_user("admin", "pass") + c.admin.authenticate("admin", "pass") c.pymongo_test.add_user("user", "pass") self.assertRaises(ConfigurationError, Connection, @@ -269,6 +270,8 @@ class TestConnection(unittest.TestCase): slave_okay=True).slave_okay) self.assert_(Connection("mongodb://%s:%s/?slaveok=true;w=2" % (self.host, self.port)).slave_okay) + c.admin.system.users.remove({}) + c.pymongo_test.system.users.remove({}) def test_fork(self): """Test using a connection before and after a fork. diff --git a/test/test_pooling.py b/test/test_pooling.py index a0ead2760..db78a7353 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -216,8 +216,10 @@ class TestPooling(unittest.TestCase): b_sock = b._Connection__pool.sockets[0] b.test.test.find_one() a.test.test.find_one() - self.assertEqual(b_sock, b._Connection__pool.get_socket(b.host, b.port)) - self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port)) + self.assertEqual(b_sock, + b._Connection__pool.get_socket(b.host, b.port)[0]) + self.assertEqual(a_sock, + a._Connection__pool.get_socket(a.host, a.port)[0]) def test_pool_with_fork(self): if sys.platform == "win32": @@ -269,7 +271,8 @@ class TestPooling(unittest.TestCase): self.assert_(a_sock.getsockname() != b_sock) self.assert_(a_sock.getsockname() != c_sock) self.assert_(b_sock != c_sock) - self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port)) + self.assertEqual(a_sock, + a._Connection__pool.get_socket(a.host, a.port)[0]) def test_max_pool_size(self): c = get_connection(max_pool_size=4) diff --git a/test/test_threads.py b/test/test_threads.py index 42ba52c71..bc0e78fc6 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -20,7 +20,26 @@ import threading from nose.plugins.skip import SkipTest from test_connection import get_connection -from pymongo.errors import AutoReconnect +from pymongo.errors import (AutoReconnect, + OperationFailure, + DuplicateKeyError) + + +class AutoAuthenticateThreads(threading.Thread): + + def __init__(self, collection, num): + threading.Thread.__init__(self) + self.coll = collection + self.num = num + self.success = True + + def run(self): + try: + for i in xrange(self.num): + self.coll.insert({'num':i}, safe=True) + self.coll.find_one({'num':i}) + except Exception: + self.success = False class SaveAndFind(threading.Thread): @@ -177,5 +196,64 @@ class TestThreads(unittest.TestCase): t.join() +class TestThreadsAuth(unittest.TestCase): + + def setUp(self): + self.conn = get_connection() + + # Setup auth users + self.conn.admin.system.users.remove({}) + self.conn.admin.add_user('admin-user', 'password') + try: + self.conn.admin.system.users.find_one() + # If we reach here mongod was likely started + # without --auth. Skip this test since it's + # pointless without auth enabled. + self.tearDown() + raise SkipTest() + except OperationFailure: + pass + self.conn.admin.authenticate("admin-user", "password") + self.conn.auth_test.system.users.remove({}) + self.conn.auth_test.add_user("test-user", "password") + + def tearDown(self): + # Remove auth users from databases + self.conn.admin.authenticate("admin-user", "password") + self.conn.admin.system.users.remove({}) + self.conn.auth_test.system.users.remove({}) + self.conn.drop_database('auth_test') + + def test_auto_auth_login(self): + conn = get_connection() + self.assertRaises(OperationFailure, conn.auth_test.test.find_one) + + # Admin auth + conn = get_connection() + conn.admin.authenticate("admin-user", "password") + + threads = [] + for _ in xrange(10): + t = AutoAuthenticateThreads(conn.auth_test.test, 100) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertTrue(t.success) + + # Database-specific auth + conn = get_connection() + conn.auth_test.authenticate("test-user", "password") + + threads = [] + for _ in xrange(10): + t = AutoAuthenticateThreads(conn.auth_test.test, 100) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertTrue(t.success) + + if __name__ == "__main__": unittest.main()