diff --git a/pymongo/connection.py b/pymongo/connection.py index 558570c9b..368204f75 100644 --- a/pymongo/connection.py +++ b/pymongo/connection.py @@ -137,14 +137,15 @@ class _Pool(threading.local): self.pid = pid if self.sock is not None and self.sock[0] == pid: - return self.sock[1] + return (self.sock[1], True) try: self.sock = (pid, self.sockets.pop()) + return (self.sock[1], True) except IndexError: self.sock = (pid, self.connect(host, port)) + return (self.sock[1], False) - return self.sock[1] def return_socket(self): if self.sock is not None and self.sock[0] == os.getpid(): @@ -158,7 +159,7 @@ class _Pool(threading.local): self.sock = None -class Connection(common.BaseObject): # TODO support auth for pooling +class Connection(common.BaseObject): """Connection to MongoDB. """ @@ -319,6 +320,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling # cache of existing indexes used by ensure_index ops self.__index_cache = {} + self.__auth_credentials = {} if _connect: self.__find_node() @@ -413,6 +415,38 @@ class Connection(common.BaseObject): # TODO support auth for pooling if index_name in self.__index_cache[database_name][collection_name]: del self.__index_cache[database_name][collection_name][index_name] + def _cache_credentials(self, db_name, username, password): + """Add credentials to the database authentication cache + for automatic login when a socket is created. + + If credentials are already cached for `db_name` they + will be replaced. + """ + self.__auth_credentials[db_name] = (username, password) + + def _purge_credentials(self, db_name=None): + """Purge credentials from the database authentication cache. + + If `db_name` is None purge credentials for all databases. + """ + if db_name is None: + self.__auth_credentials.clear() + elif db_name in self.__auth_credentials: + del self.__auth_credentials[db_name] + + def __authenticate_socket(self): + """Authenticate using cached database credentials. + + If credentials for the 'admin' database are available only + this database is authenticated, since this gives global access. + """ + if "admin" in self.__auth_credentials: + username, password = self.__auth_credentials["admin"] + self.admin.authenticate(username, password) + else: + for db_name, (u, p) in self.__auth_credentials.iteritems(): + self[db_name].authenticate(u, p) + @property def host(self): """Current connected host. @@ -570,7 +604,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling host, port = self.__find_node() try: - sock = self.__pool.get_socket(host, port) + sock, from_pool = self.__pool.get_socket(host, port) except socket.error: self.disconnect() raise AutoReconnect("could not connect to %s:%d" % (host, port)) @@ -578,8 +612,10 @@ class Connection(common.BaseObject): # TODO support auth for pooling if t - self.__last_checkout > 1: if _closed(sock): self.disconnect() - sock = self.__pool.get_socket(host, port) + sock, from_pool = self.__pool.get_socket(host, port) self.__last_checkout = t + if self.__auth_credentials and not from_pool: + self.__authenticate_socket() return sock def disconnect(self): diff --git a/pymongo/database.py b/pymongo/database.py index 6ec08ec38..6cd71c1d3 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -561,28 +561,28 @@ class Database(common.BaseObject): Once authenticated, the user has full read and write access to this database. Raises :class:`TypeError` if either `name` or `password` is not an instance of ``(str, - unicode)``. Authentication lasts for the life of the database - connection, or until :meth:`logout` is called. + unicode)``. Authentication lasts for the life of the underlying + :class:`~pymongo.connection.Connection`, or until :meth:`logout` + is called. The "admin" database is special. Authenticating on "admin" gives access to *all* databases. Effectively, "admin" access means root access to the database. - .. note:: Currently, authentication is per - :class:`~socket.socket`. This means that there are a couple - of situations in which re-authentication is necessary: - - - On failover (when an - :class:`~pymongo.errors.AutoReconnect` exception is - raised). - - - After a call to - :meth:`~pymongo.connection.Connection.disconnect` or - :meth:`~pymongo.connection.Connection.end_request`. + .. note:: This method authenticates the current connection, and + will also cause all new :class:`~socket.socket` connections + in the underlying :class:`~pymongo.connection.Connection` to + be authenticated automatically. - When sharing a :class:`~pymongo.connection.Connection` - between multiple threads, each thread will need to - authenticate separately. + between multiple threads, all threads will share the + authentication. If you need different authentication profiles + for different purposes (e.g. admin users) you must use + distinct instances of :class:`~pymongo.connection.Connection`. + + - To get authentication to apply immediately to all + existing sockets you may need to reset this Connection's + sockets using :meth:`~pymongo.connection.Connection.disconnect`. .. warning:: Currently, calls to :meth:`~pymongo.connection.Connection.end_request` will @@ -608,16 +608,24 @@ class Database(common.BaseObject): try: self.command("authenticate", user=unicode(name), nonce=nonce, key=key) + self.connection._cache_credentials(self.name, + unicode(name), + unicode(password)) return True except OperationFailure: return False def logout(self): - """Deauthorize use of this database for this connection. + """Deauthorize use of this database for this connection + and future connections. - Note that other databases may still be authorized. + .. note:: Other databases may still be authenticated, and other + existing :class:`~socket.socket` connections may remain + authenticated for this database unless you reset all sockets + with :meth:`~pymongo.connection.Connection.disconnect`. """ self.command("logout") + self.connection._purge_credentials(self.name) def dereference(self, dbref): """Dereference a :class:`~bson.dbref.DBRef`, getting the diff --git a/test/test_connection.py b/test/test_connection.py index f0b760aa1..9e73d1c90 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -245,6 +245,7 @@ class TestConnection(unittest.TestCase): c.pymongo_test.system.users.remove({}) c.admin.add_user("admin", "pass") + c.admin.authenticate("admin", "pass") c.pymongo_test.add_user("user", "pass") self.assertRaises(ConfigurationError, Connection, @@ -269,6 +270,8 @@ class TestConnection(unittest.TestCase): slave_okay=True).slave_okay) self.assert_(Connection("mongodb://%s:%s/?slaveok=true;w=2" % (self.host, self.port)).slave_okay) + c.admin.system.users.remove({}) + c.pymongo_test.system.users.remove({}) def test_fork(self): """Test using a connection before and after a fork. diff --git a/test/test_pooling.py b/test/test_pooling.py index a0ead2760..db78a7353 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -216,8 +216,10 @@ class TestPooling(unittest.TestCase): b_sock = b._Connection__pool.sockets[0] b.test.test.find_one() a.test.test.find_one() - self.assertEqual(b_sock, b._Connection__pool.get_socket(b.host, b.port)) - self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port)) + self.assertEqual(b_sock, + b._Connection__pool.get_socket(b.host, b.port)[0]) + self.assertEqual(a_sock, + a._Connection__pool.get_socket(a.host, a.port)[0]) def test_pool_with_fork(self): if sys.platform == "win32": @@ -269,7 +271,8 @@ class TestPooling(unittest.TestCase): self.assert_(a_sock.getsockname() != b_sock) self.assert_(a_sock.getsockname() != c_sock) self.assert_(b_sock != c_sock) - self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port)) + self.assertEqual(a_sock, + a._Connection__pool.get_socket(a.host, a.port)[0]) def test_max_pool_size(self): c = get_connection(max_pool_size=4) diff --git a/test/test_threads.py b/test/test_threads.py index 42ba52c71..bc0e78fc6 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -20,7 +20,26 @@ import threading from nose.plugins.skip import SkipTest from test_connection import get_connection -from pymongo.errors import AutoReconnect +from pymongo.errors import (AutoReconnect, + OperationFailure, + DuplicateKeyError) + + +class AutoAuthenticateThreads(threading.Thread): + + def __init__(self, collection, num): + threading.Thread.__init__(self) + self.coll = collection + self.num = num + self.success = True + + def run(self): + try: + for i in xrange(self.num): + self.coll.insert({'num':i}, safe=True) + self.coll.find_one({'num':i}) + except Exception: + self.success = False class SaveAndFind(threading.Thread): @@ -177,5 +196,64 @@ class TestThreads(unittest.TestCase): t.join() +class TestThreadsAuth(unittest.TestCase): + + def setUp(self): + self.conn = get_connection() + + # Setup auth users + self.conn.admin.system.users.remove({}) + self.conn.admin.add_user('admin-user', 'password') + try: + self.conn.admin.system.users.find_one() + # If we reach here mongod was likely started + # without --auth. Skip this test since it's + # pointless without auth enabled. + self.tearDown() + raise SkipTest() + except OperationFailure: + pass + self.conn.admin.authenticate("admin-user", "password") + self.conn.auth_test.system.users.remove({}) + self.conn.auth_test.add_user("test-user", "password") + + def tearDown(self): + # Remove auth users from databases + self.conn.admin.authenticate("admin-user", "password") + self.conn.admin.system.users.remove({}) + self.conn.auth_test.system.users.remove({}) + self.conn.drop_database('auth_test') + + def test_auto_auth_login(self): + conn = get_connection() + self.assertRaises(OperationFailure, conn.auth_test.test.find_one) + + # Admin auth + conn = get_connection() + conn.admin.authenticate("admin-user", "password") + + threads = [] + for _ in xrange(10): + t = AutoAuthenticateThreads(conn.auth_test.test, 100) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertTrue(t.success) + + # Database-specific auth + conn = get_connection() + conn.auth_test.authenticate("test-user", "password") + + threads = [] + for _ in xrange(10): + t = AutoAuthenticateThreads(conn.auth_test.test, 100) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertTrue(t.success) + + if __name__ == "__main__": unittest.main()