Auth improvements w/pooling and threads PYTHON-4
Also PYTHON-162 and PYTHON-189. Credit goes to James Murty for most of the patch. With this change we cache auth credentials (user, password) in the driver so that each new socket can be automatically authenticated. This solves the problem of each new spawned thread having to re-authenticate.
This commit is contained in:
parent
cb2d396471
commit
fcb7ffb25f
@ -137,14 +137,15 @@ class _Pool(threading.local):
|
||||
self.pid = pid
|
||||
|
||||
if self.sock is not None and self.sock[0] == pid:
|
||||
return self.sock[1]
|
||||
return (self.sock[1], True)
|
||||
|
||||
try:
|
||||
self.sock = (pid, self.sockets.pop())
|
||||
return (self.sock[1], True)
|
||||
except IndexError:
|
||||
self.sock = (pid, self.connect(host, port))
|
||||
return (self.sock[1], False)
|
||||
|
||||
return self.sock[1]
|
||||
|
||||
def return_socket(self):
|
||||
if self.sock is not None and self.sock[0] == os.getpid():
|
||||
@ -158,7 +159,7 @@ class _Pool(threading.local):
|
||||
self.sock = None
|
||||
|
||||
|
||||
class Connection(common.BaseObject): # TODO support auth for pooling
|
||||
class Connection(common.BaseObject):
|
||||
"""Connection to MongoDB.
|
||||
"""
|
||||
|
||||
@ -319,6 +320,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling
|
||||
|
||||
# cache of existing indexes used by ensure_index ops
|
||||
self.__index_cache = {}
|
||||
self.__auth_credentials = {}
|
||||
|
||||
if _connect:
|
||||
self.__find_node()
|
||||
@ -413,6 +415,38 @@ class Connection(common.BaseObject): # TODO support auth for pooling
|
||||
if index_name in self.__index_cache[database_name][collection_name]:
|
||||
del self.__index_cache[database_name][collection_name][index_name]
|
||||
|
||||
def _cache_credentials(self, db_name, username, password):
|
||||
"""Add credentials to the database authentication cache
|
||||
for automatic login when a socket is created.
|
||||
|
||||
If credentials are already cached for `db_name` they
|
||||
will be replaced.
|
||||
"""
|
||||
self.__auth_credentials[db_name] = (username, password)
|
||||
|
||||
def _purge_credentials(self, db_name=None):
|
||||
"""Purge credentials from the database authentication cache.
|
||||
|
||||
If `db_name` is None purge credentials for all databases.
|
||||
"""
|
||||
if db_name is None:
|
||||
self.__auth_credentials.clear()
|
||||
elif db_name in self.__auth_credentials:
|
||||
del self.__auth_credentials[db_name]
|
||||
|
||||
def __authenticate_socket(self):
|
||||
"""Authenticate using cached database credentials.
|
||||
|
||||
If credentials for the 'admin' database are available only
|
||||
this database is authenticated, since this gives global access.
|
||||
"""
|
||||
if "admin" in self.__auth_credentials:
|
||||
username, password = self.__auth_credentials["admin"]
|
||||
self.admin.authenticate(username, password)
|
||||
else:
|
||||
for db_name, (u, p) in self.__auth_credentials.iteritems():
|
||||
self[db_name].authenticate(u, p)
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
"""Current connected host.
|
||||
@ -570,7 +604,7 @@ class Connection(common.BaseObject): # TODO support auth for pooling
|
||||
host, port = self.__find_node()
|
||||
|
||||
try:
|
||||
sock = self.__pool.get_socket(host, port)
|
||||
sock, from_pool = self.__pool.get_socket(host, port)
|
||||
except socket.error:
|
||||
self.disconnect()
|
||||
raise AutoReconnect("could not connect to %s:%d" % (host, port))
|
||||
@ -578,8 +612,10 @@ class Connection(common.BaseObject): # TODO support auth for pooling
|
||||
if t - self.__last_checkout > 1:
|
||||
if _closed(sock):
|
||||
self.disconnect()
|
||||
sock = self.__pool.get_socket(host, port)
|
||||
sock, from_pool = self.__pool.get_socket(host, port)
|
||||
self.__last_checkout = t
|
||||
if self.__auth_credentials and not from_pool:
|
||||
self.__authenticate_socket()
|
||||
return sock
|
||||
|
||||
def disconnect(self):
|
||||
|
||||
@ -561,28 +561,28 @@ class Database(common.BaseObject):
|
||||
Once authenticated, the user has full read and write access to
|
||||
this database. Raises :class:`TypeError` if either `name` or
|
||||
`password` is not an instance of ``(str,
|
||||
unicode)``. Authentication lasts for the life of the database
|
||||
connection, or until :meth:`logout` is called.
|
||||
unicode)``. Authentication lasts for the life of the underlying
|
||||
:class:`~pymongo.connection.Connection`, or until :meth:`logout`
|
||||
is called.
|
||||
|
||||
The "admin" database is special. Authenticating on "admin"
|
||||
gives access to *all* databases. Effectively, "admin" access
|
||||
means root access to the database.
|
||||
|
||||
.. note:: Currently, authentication is per
|
||||
:class:`~socket.socket`. This means that there are a couple
|
||||
of situations in which re-authentication is necessary:
|
||||
|
||||
- On failover (when an
|
||||
:class:`~pymongo.errors.AutoReconnect` exception is
|
||||
raised).
|
||||
|
||||
- After a call to
|
||||
:meth:`~pymongo.connection.Connection.disconnect` or
|
||||
:meth:`~pymongo.connection.Connection.end_request`.
|
||||
.. note:: This method authenticates the current connection, and
|
||||
will also cause all new :class:`~socket.socket` connections
|
||||
in the underlying :class:`~pymongo.connection.Connection` to
|
||||
be authenticated automatically.
|
||||
|
||||
- When sharing a :class:`~pymongo.connection.Connection`
|
||||
between multiple threads, each thread will need to
|
||||
authenticate separately.
|
||||
between multiple threads, all threads will share the
|
||||
authentication. If you need different authentication profiles
|
||||
for different purposes (e.g. admin users) you must use
|
||||
distinct instances of :class:`~pymongo.connection.Connection`.
|
||||
|
||||
- To get authentication to apply immediately to all
|
||||
existing sockets you may need to reset this Connection's
|
||||
sockets using :meth:`~pymongo.connection.Connection.disconnect`.
|
||||
|
||||
.. warning:: Currently, calls to
|
||||
:meth:`~pymongo.connection.Connection.end_request` will
|
||||
@ -608,16 +608,24 @@ class Database(common.BaseObject):
|
||||
try:
|
||||
self.command("authenticate", user=unicode(name),
|
||||
nonce=nonce, key=key)
|
||||
self.connection._cache_credentials(self.name,
|
||||
unicode(name),
|
||||
unicode(password))
|
||||
return True
|
||||
except OperationFailure:
|
||||
return False
|
||||
|
||||
def logout(self):
|
||||
"""Deauthorize use of this database for this connection.
|
||||
"""Deauthorize use of this database for this connection
|
||||
and future connections.
|
||||
|
||||
Note that other databases may still be authorized.
|
||||
.. note:: Other databases may still be authenticated, and other
|
||||
existing :class:`~socket.socket` connections may remain
|
||||
authenticated for this database unless you reset all sockets
|
||||
with :meth:`~pymongo.connection.Connection.disconnect`.
|
||||
"""
|
||||
self.command("logout")
|
||||
self.connection._purge_credentials(self.name)
|
||||
|
||||
def dereference(self, dbref):
|
||||
"""Dereference a :class:`~bson.dbref.DBRef`, getting the
|
||||
|
||||
@ -245,6 +245,7 @@ class TestConnection(unittest.TestCase):
|
||||
c.pymongo_test.system.users.remove({})
|
||||
|
||||
c.admin.add_user("admin", "pass")
|
||||
c.admin.authenticate("admin", "pass")
|
||||
c.pymongo_test.add_user("user", "pass")
|
||||
|
||||
self.assertRaises(ConfigurationError, Connection,
|
||||
@ -269,6 +270,8 @@ class TestConnection(unittest.TestCase):
|
||||
slave_okay=True).slave_okay)
|
||||
self.assert_(Connection("mongodb://%s:%s/?slaveok=true;w=2" %
|
||||
(self.host, self.port)).slave_okay)
|
||||
c.admin.system.users.remove({})
|
||||
c.pymongo_test.system.users.remove({})
|
||||
|
||||
def test_fork(self):
|
||||
"""Test using a connection before and after a fork.
|
||||
|
||||
@ -216,8 +216,10 @@ class TestPooling(unittest.TestCase):
|
||||
b_sock = b._Connection__pool.sockets[0]
|
||||
b.test.test.find_one()
|
||||
a.test.test.find_one()
|
||||
self.assertEqual(b_sock, b._Connection__pool.get_socket(b.host, b.port))
|
||||
self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port))
|
||||
self.assertEqual(b_sock,
|
||||
b._Connection__pool.get_socket(b.host, b.port)[0])
|
||||
self.assertEqual(a_sock,
|
||||
a._Connection__pool.get_socket(a.host, a.port)[0])
|
||||
|
||||
def test_pool_with_fork(self):
|
||||
if sys.platform == "win32":
|
||||
@ -269,7 +271,8 @@ class TestPooling(unittest.TestCase):
|
||||
self.assert_(a_sock.getsockname() != b_sock)
|
||||
self.assert_(a_sock.getsockname() != c_sock)
|
||||
self.assert_(b_sock != c_sock)
|
||||
self.assertEqual(a_sock, a._Connection__pool.get_socket(a.host, a.port))
|
||||
self.assertEqual(a_sock,
|
||||
a._Connection__pool.get_socket(a.host, a.port)[0])
|
||||
|
||||
def test_max_pool_size(self):
|
||||
c = get_connection(max_pool_size=4)
|
||||
|
||||
@ -20,7 +20,26 @@ import threading
|
||||
from nose.plugins.skip import SkipTest
|
||||
|
||||
from test_connection import get_connection
|
||||
from pymongo.errors import AutoReconnect
|
||||
from pymongo.errors import (AutoReconnect,
|
||||
OperationFailure,
|
||||
DuplicateKeyError)
|
||||
|
||||
|
||||
class AutoAuthenticateThreads(threading.Thread):
|
||||
|
||||
def __init__(self, collection, num):
|
||||
threading.Thread.__init__(self)
|
||||
self.coll = collection
|
||||
self.num = num
|
||||
self.success = True
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
for i in xrange(self.num):
|
||||
self.coll.insert({'num':i}, safe=True)
|
||||
self.coll.find_one({'num':i})
|
||||
except Exception:
|
||||
self.success = False
|
||||
|
||||
|
||||
class SaveAndFind(threading.Thread):
|
||||
@ -177,5 +196,64 @@ class TestThreads(unittest.TestCase):
|
||||
t.join()
|
||||
|
||||
|
||||
class TestThreadsAuth(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.conn = get_connection()
|
||||
|
||||
# Setup auth users
|
||||
self.conn.admin.system.users.remove({})
|
||||
self.conn.admin.add_user('admin-user', 'password')
|
||||
try:
|
||||
self.conn.admin.system.users.find_one()
|
||||
# If we reach here mongod was likely started
|
||||
# without --auth. Skip this test since it's
|
||||
# pointless without auth enabled.
|
||||
self.tearDown()
|
||||
raise SkipTest()
|
||||
except OperationFailure:
|
||||
pass
|
||||
self.conn.admin.authenticate("admin-user", "password")
|
||||
self.conn.auth_test.system.users.remove({})
|
||||
self.conn.auth_test.add_user("test-user", "password")
|
||||
|
||||
def tearDown(self):
|
||||
# Remove auth users from databases
|
||||
self.conn.admin.authenticate("admin-user", "password")
|
||||
self.conn.admin.system.users.remove({})
|
||||
self.conn.auth_test.system.users.remove({})
|
||||
self.conn.drop_database('auth_test')
|
||||
|
||||
def test_auto_auth_login(self):
|
||||
conn = get_connection()
|
||||
self.assertRaises(OperationFailure, conn.auth_test.test.find_one)
|
||||
|
||||
# Admin auth
|
||||
conn = get_connection()
|
||||
conn.admin.authenticate("admin-user", "password")
|
||||
|
||||
threads = []
|
||||
for _ in xrange(10):
|
||||
t = AutoAuthenticateThreads(conn.auth_test.test, 100)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
self.assertTrue(t.success)
|
||||
|
||||
# Database-specific auth
|
||||
conn = get_connection()
|
||||
conn.auth_test.authenticate("test-user", "password")
|
||||
|
||||
threads = []
|
||||
for _ in xrange(10):
|
||||
t = AutoAuthenticateThreads(conn.auth_test.test, 100)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
self.assertTrue(t.success)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user