Auth improvements w/pooling and threads PYTHON-4

Also PYTHON-162 and PYTHON-189.

Credit goes to James Murty for most of the patch.

With this change we cache auth credentials (user,
password) in the driver so that each new socket
can be automatically authenticated. This solves
the problem of each new spawned thread having to
re-authenticate.
This commit is contained in:
behackett 2011-07-11 19:27:20 -07:00
parent cb2d396471
commit fcb7ffb25f
5 changed files with 154 additions and 26 deletions

View File

@ -137,14 +137,15 @@ class _Pool(threading.local):
self.pid = pid
if self.sock is not None and self.sock[0] == pid:
return self.sock[1]
return (self.sock[1], True)
try:
self.sock = (pid, self.sockets.pop())
return (self.sock[1], True)
except IndexError:
self.sock = (pid, self.connect(host, port))
return (self.sock[1], False)
return self.sock[1]
def return_socket(self):
if self.sock is not None and self.sock[0] == os.getpid():
@ -158,7 +159,7 @@ class _Pool(threading.local):
self.sock = None
class Connection(common.BaseObject): # TODO support auth for pooling
class Connection(common.BaseObject):
"""Connection to MongoDB.
"""
@ -319,6 +320,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling
# cache of existing indexes used by ensure_index ops
self.__index_cache = {}
self.__auth_credentials = {}
if _connect:
self.__find_node()
@ -413,6 +415,38 @@ class Connection(common.BaseObject): # TODO support auth for pooling
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
def _cache_credentials(self, db_name, username, password):
"""Add credentials to the database authentication cache
for automatic login when a socket is created.
If credentials are already cached for `db_name` they
will be replaced.
"""
self.__auth_credentials[db_name] = (username, password)
def _purge_credentials(self, db_name=None):
"""Purge credentials from the database authentication cache.
If `db_name` is None purge credentials for all databases.
"""
if db_name is None:
self.__auth_credentials.clear()
elif db_name in self.__auth_credentials:
del self.__auth_credentials[db_name]
def __authenticate_socket(self):
"""Authenticate using cached database credentials.
If credentials for the 'admin' database are available only
this database is authenticated, since this gives global access.
"""
if "admin" in self.__auth_credentials:
username, password = self.__auth_credentials["admin"]
self.admin.authenticate(username, password)
else:
for db_name, (u, p) in self.__auth_credentials.iteritems():
self[db_name].authenticate(u, p)
@property
def host(self):
"""Current connected host.
@ -570,7 +604,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling
host, port = self.__find_node()
try:
sock = self.__pool.get_socket(host, port)
sock, from_pool = self.__pool.get_socket(host, port)
except socket.error:
self.disconnect()
raise AutoReconnect("could not connect to %s:%d" % (host, port))
@ -578,8 +612,10 @@ class Connection(common.BaseObject): # TODO support auth for pooling
if t - self.__last_checkout > 1:
if _closed(sock):
self.disconnect()
sock = self.__pool.get_socket(host, port)
sock, from_pool = self.__pool.get_socket(host, port)
self.__last_checkout = t
if self.__auth_credentials and not from_pool:
self.__authenticate_socket()
return sock
def disconnect(self):

View File

@ -561,28 +561,28 @@ class Database(common.BaseObject):
Once authenticated, the user has full read and write access to
this database. Raises :class:`TypeError` if either `name` or
`password` is not an instance of ``(str,
unicode)``. Authentication lasts for the life of the database
connection, or until :meth:`logout` is called.
unicode)``. Authentication lasts for the life of the underlying
:class:`~pymongo.connection.Connection`, or until :meth:`logout`
is called.
The "admin" database is special. Authenticating on "admin"
gives access to *all* databases. Effectively, "admin" access
means root access to the database.
.. note:: Currently, authentication is per
:class:`~socket.socket`. This means that there are a couple
of situations in which re-authentication is necessary:
- On failover (when an
:class:`~pymongo.errors.AutoReconnect` exception is
raised).
- After a call to
:meth:`~pymongo.connection.Connection.disconnect` or
:meth:`~pymongo.connection.Connection.end_request`.
.. note:: This method authenticates the current connection, and
will also cause all new :class:`~socket.socket` connections
in the underlying :class:`~pymongo.connection.Connection` to
be authenticated automatically.
- When sharing a :class:`~pymongo.connection.Connection`
between multiple threads, each thread will need to
authenticate separately.
between multiple threads, all threads will share the
authentication. If you need different authentication profiles
for different purposes (e.g. admin users) you must use
distinct instances of :class:`~pymongo.connection.Connection`.
- To get authentication to apply immediately to all
existing sockets you may need to reset this Connection's
sockets using :meth:`~pymongo.connection.Connection.disconnect`.
.. warning:: Currently, calls to
:meth:`~pymongo.connection.Connection.end_request` will
@ -608,16 +608,24 @@ class Database(common.BaseObject):
try:
self.command("authenticate", user=unicode(name),
nonce=nonce, key=key)
self.connection._cache_credentials(self.name,
unicode(name),
unicode(password))
return True
except OperationFailure:
return False
def logout(self):
"""Deauthorize use of this database for this connection.
"""Deauthorize use of this database for this connection
and future connections.
Note that other databases may still be authorized.
.. note:: Other databases may still be authenticated, and other
existing :class:`~socket.socket` connections may remain
authenticated for this database unless you reset all sockets
with :meth:`~pymongo.connection.Connection.disconnect`.
"""
self.command("logout")
self.connection._purge_credentials(self.name)
def dereference(self, dbref):
"""Dereference a :class:`~bson.dbref.DBRef`, getting the

View File

@ -245,6 +245,7 @@ class TestConnection(unittest.TestCase):
c.pymongo_test.system.users.remove({})
c.admin.add_user("admin", "pass")
c.admin.authenticate("admin", "pass")
c.pymongo_test.add_user("user", "pass")
self.assertRaises(ConfigurationError, Connection,
@ -269,6 +270,8 @@ class TestConnection(unittest.TestCase):
slave_okay=True).slave_okay)
self.assert_(Connection("mongodb://%s:%s/?slaveok=true;w=2" %
(self.host, self.port)).slave_okay)
c.admin.system.users.remove({})
c.pymongo_test.system.users.remove({})
def test_fork(self):
"""Test using a connection before and after a fork.

View File

@ -216,8 +216,10 @@ class TestPooling(unittest.TestCase):
b_sock = b._Connection__pool.sockets[0]
b.test.test.find_one()
a.test.test.find_one()
self.assertEqual(b_sock, b._Connection__pool.get_socket(b.host, b.port))
self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port))
self.assertEqual(b_sock,
b._Connection__pool.get_socket(b.host, b.port)[0])
self.assertEqual(a_sock,
a._Connection__pool.get_socket(a.host, a.port)[0])
def test_pool_with_fork(self):
if sys.platform == "win32":
@ -269,7 +271,8 @@ class TestPooling(unittest.TestCase):
self.assert_(a_sock.getsockname() != b_sock)
self.assert_(a_sock.getsockname() != c_sock)
self.assert_(b_sock != c_sock)
self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port))
self.assertEqual(a_sock,
a._Connection__pool.get_socket(a.host, a.port)[0])
def test_max_pool_size(self):
c = get_connection(max_pool_size=4)

View File

@ -20,7 +20,26 @@ import threading
from nose.plugins.skip import SkipTest
from test_connection import get_connection
from pymongo.errors import AutoReconnect
from pymongo.errors import (AutoReconnect,
OperationFailure,
DuplicateKeyError)
class AutoAuthenticateThreads(threading.Thread):
def __init__(self, collection, num):
threading.Thread.__init__(self)
self.coll = collection
self.num = num
self.success = True
def run(self):
try:
for i in xrange(self.num):
self.coll.insert({'num':i}, safe=True)
self.coll.find_one({'num':i})
except Exception:
self.success = False
class SaveAndFind(threading.Thread):
@ -177,5 +196,64 @@ class TestThreads(unittest.TestCase):
t.join()
class TestThreadsAuth(unittest.TestCase):
def setUp(self):
self.conn = get_connection()
# Setup auth users
self.conn.admin.system.users.remove({})
self.conn.admin.add_user('admin-user', 'password')
try:
self.conn.admin.system.users.find_one()
# If we reach here mongod was likely started
# without --auth. Skip this test since it's
# pointless without auth enabled.
self.tearDown()
raise SkipTest()
except OperationFailure:
pass
self.conn.admin.authenticate("admin-user", "password")
self.conn.auth_test.system.users.remove({})
self.conn.auth_test.add_user("test-user", "password")
def tearDown(self):
# Remove auth users from databases
self.conn.admin.authenticate("admin-user", "password")
self.conn.admin.system.users.remove({})
self.conn.auth_test.system.users.remove({})
self.conn.drop_database('auth_test')
def test_auto_auth_login(self):
conn = get_connection()
self.assertRaises(OperationFailure, conn.auth_test.test.find_one)
# Admin auth
conn = get_connection()
conn.admin.authenticate("admin-user", "password")
threads = []
for _ in xrange(10):
t = AutoAuthenticateThreads(conn.auth_test.test, 100)
t.start()
threads.append(t)
for t in threads:
t.join()
self.assertTrue(t.success)
# Database-specific auth
conn = get_connection()
conn.auth_test.authenticate("test-user", "password")
threads = []
for _ in xrange(10):
t = AutoAuthenticateThreads(conn.auth_test.test, 100)
t.start()
threads.append(t)
for t in threads:
t.join()
self.assertTrue(t.success)
if __name__ == "__main__":
unittest.main()