mongo-python-driver/test/test_pooling.py
behackett fcb7ffb25f Auth improvements w/pooling and threads PYTHON-4
Also PYTHON-162 and PYTHON-189.

Credit goes to James Murty for most of the patch.

With this change we cache auth credentials (user,
password) in the driver so that each new socket
can be automatically authenticated. This solves
the problem of each new spawned thread having to
re-authenticate.
2011-07-11 19:27:20 -07:00

295 lines
8.4 KiB
Python

# Copyright 2009-2010 10gen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test built in connection-pooling."""
import os
import random
import sys
import threading
import time
import unittest
sys.path[0:0] = [""]
from nose.plugins.skip import SkipTest
from pymongo.connection import Connection, _Pool
from pymongo.errors import ConfigurationError
from test_connection import get_connection
# Number of iterations each worker thread performs; SaveAndFind also uses it
# as the upper bound of the random values it stores.
N = 50
# Name of the database the pooling tests work in (dropped in TestPooling.setUp).
DB = "pymongo-pooling-tests"
class MongoThread(threading.Thread):
    """Base class for worker threads that share the test case's connection.

    Attributes set by ``__init__``:
      - ``ut``: the test case handed to the constructor; subclasses call its
        assertion methods.
      - ``connection``: the ``c`` attribute of that test case.
      - ``db``: ``connection[DB]``, the database the workers operate on.
    """

    def __init__(self, test_case):
        super(MongoThread, self).__init__()
        self.ut = test_case
        self.connection = self.ut.c
        self.db = self.connection[DB]
class SaveAndFind(MongoThread):
    """Worker that saves a random document and reads it back, ``N`` times."""

    def run(self):
        """Save ``{"x": value}``, look it up by the returned id and compare.

        ``value`` is drawn with ``random.randint(0, N)``. The request is ended
        after every round trip.
        """
        for _ in xrange(N):
            value = random.randint(0, N)
            doc_id = self.db.sf.save({"x": value}, safe=True)
            stored = self.db.sf.find_one(doc_id)
            self.ut.assertEqual(value, stored["x"])
            self.connection.end_request()
class Unique(MongoThread):
    """Worker that inserts documents which should not produce an error."""

    def run(self):
        """Insert an empty document ``N`` times, expecting no error each time.

        After every insert the result of ``db.error()`` must equal ``None``;
        the request is then ended.
        """
        for _ in xrange(N):
            self.db.unique.insert({})
            last_error = self.db.error()
            self.ut.assertEqual(None, last_error)
            self.connection.end_request()
class NonUnique(MongoThread):
    """Worker that repeatedly inserts the same ``_id`` and expects an error."""

    def run(self):
        """Insert ``{"_id": "mike"}`` ``N`` times, expecting an error each time.

        TestPooling.setUp already inserts a document with this ``_id``, so
        after every insert ``db.error()`` must not equal ``None``; the request
        is then ended.
        """
        for _ in xrange(N):
            self.db.unique.insert({"_id": "mike"})
            last_error = self.db.error()
            self.ut.assertNotEqual(None, last_error)
            self.connection.end_request()
class Disconnect(MongoThread):
    """Worker that keeps disconnecting the shared connection."""

    def run(self):
        """Call ``disconnect()`` on the shared connection ``N`` times."""
        for _attempt in xrange(N):
            self.connection.disconnect()
class NoRequest(MongoThread):
    """Worker that never calls ``end_request()`` between operations."""

    def run(self):
        """Insert a duplicate ``_id`` ``N`` times and require an error each time.

        Counts the iterations in which ``db.error()`` returned ``None`` after
        the insert and asserts, once the loop is done, that the count is 0.
        """
        unreported = 0
        for _ in xrange(N):
            self.db.unique.insert({"_id": "mike"})
            if self.db.error() is not None:
                continue
            unreported += 1

        self.ut.assertEqual(0, unreported)
def run_cases(ut, cases, threads_per_case=10):
    """Start several threads for every case and wait for all of them.

    :Parameters:
      - `ut`: object passed as the only argument to each case's constructor
        (the test case, for the thread classes in this module).
      - `cases`: iterable of callables returning objects with ``start()`` and
        ``join()`` methods, e.g. ``threading.Thread`` subclasses.
      - `threads_per_case` (optional): how many threads to create for each
        case. Defaults to 10.

    All threads are started before any of them is joined, so threads of
    different cases run concurrently. Returns ``None``.
    """
    threads = []
    for case in cases:
        for _ in range(threads_per_case):
            thread = case(ut)
            thread.start()
            threads.append(thread)

    for thread in threads:
        thread.join()
class OneOp(threading.Thread):
    """Thread that runs one query and checks the pool's socket list around it."""

    def __init__(self, connection):
        super(OneOp, self).__init__()
        self.c = connection

    def run(self):
        """Check ``len(pool.sockets)`` before a query, after it, and after
        ``end_request()``: the expected values are 1, 0 and 1.
        """
        def pooled():
            # Look the pool up each time instead of holding a reference to it.
            return len(self.c._Connection__pool.sockets)

        assert pooled() == 1

        self.c.test.test.find_one()
        assert pooled() == 0

        self.c.end_request()
        assert pooled() == 1
class CreateAndReleaseSocket(threading.Thread):
    """Thread that runs a query, waits a second, then ends its request."""

    def __init__(self, connection):
        super(CreateAndReleaseSocket, self).__init__()
        self.c = connection

    def run(self):
        """Query, sleep for one second, then call ``end_request()``."""
        conn = self.c
        conn.test.test.find_one()
        # Keep the request open for a while before ending it.
        time.sleep(1)
        conn.end_request()
class TestPooling(unittest.TestCase):
    """Tests for the connection pooling built into ``pymongo.connection``.

    Most tests look at the private pool of a ``Connection`` (the
    ``_Connection__pool`` attribute) and assert on the length of its
    ``sockets`` list after queries, ``end_request()`` and ``disconnect()``.

    The test methods carry ``#`` comments instead of docstrings so that the
    test descriptions shown by test runners stay the same.
    """

    @staticmethod
    def _idle(connection):
        """Return the length of the ``sockets`` list of `connection`'s pool.

        The pool is looked up on every call rather than cached.
        """
        return len(connection._Connection__pool.sockets)

    def setUp(self):
        self.c = get_connection()

        # reset the db
        self.c.drop_database(DB)
        database = self.c[DB]
        database.unique.insert({"_id": "mike"})
        database.unique.find_one()

    def tearDown(self):
        self.c = None

    def test_max_pool_size_validation(self):
        # Invalid keyword values.
        self.assertRaises(ValueError, Connection, max_pool_size=-1)
        self.assertRaises(AssertionError, Connection, max_pool_size='foo')

        # Invalid values given through the URI.
        bad_uris = (
            'mongodb://localhost/?maxPoolSize=-1',
            'mongodb://localhost/?maxPoolSize=foo',
            'mongodb://localhost/?maxPoolSize=5.5',
        )
        for uri in bad_uris:
            self.assertRaises(ConfigurationError, Connection, uri)

        # Valid values are exposed as ``max_pool_size``.
        from_uri = Connection('mongodb://localhost/?maxPoolSize=5')
        self.assertEqual(from_uri.max_pool_size, 5)
        from_kwarg = Connection(max_pool_size=100)
        self.assertEqual(from_kwarg.max_pool_size, 100)

    def test_no_disconnect(self):
        # Run the worker mixes that never call disconnect().
        workers = [NoRequest, NonUnique, Unique, SaveAndFind]
        run_cases(self, workers)

    def test_simple_disconnect(self):
        conn = self.c

        conn.test.stuff.find()
        self.assertEqual(0, self._idle(conn))
        self.assertNotEqual(None, conn._Connection__pool.sock)

        # Ending the request: one entry in ``sockets`` and ``sock`` is None.
        conn.end_request()
        self.assertEqual(1, self._idle(conn))
        self.assertEqual(None, conn._Connection__pool.sock)

        # After disconnect(): ``sockets`` is empty and ``sock`` is None.
        conn.disconnect()
        self.assertEqual(0, self._idle(conn))
        self.assertEqual(None, conn._Connection__pool.sock)

    def test_disconnect(self):
        # Same as test_no_disconnect, but with Disconnect workers in the mix.
        workers = [SaveAndFind, Disconnect, Unique]
        run_cases(self, workers)

    def test_independent_pools(self):
        # A pool created on its own starts empty and stays empty when a
        # request is ended on the test's connection.
        standalone = _Pool(10, 0)
        self.assertEqual([], standalone.sockets)
        self.c.end_request()
        self.assertEqual([], standalone.sockets)

    def test_dependent_pools(self):
        conn = get_connection()
        self.assertEqual(1, self._idle(conn))

        conn.test.test.find_one()
        self.assertEqual(0, self._idle(conn))
        conn.end_request()
        self.assertEqual(1, self._idle(conn))

        # Another thread does one operation on the same connection.
        worker = OneOp(conn)
        worker.start()
        worker.join()
        self.assertEqual(1, self._idle(conn))

        conn.test.test.find_one()
        self.assertEqual(0, self._idle(conn))

    def test_multiple_connections(self):
        a = get_connection()
        b = get_connection()
        self.assertEqual(1, self._idle(a))
        self.assertEqual(1, self._idle(b))

        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, self._idle(a))
        self.assertEqual(1, self._idle(b))
        a_sock = a._Connection__pool.sockets[0]

        b.end_request()
        self.assertEqual(1, self._idle(a))
        self.assertEqual(1, self._idle(b))

        b.test.test.find_one()
        self.assertEqual(1, self._idle(a))
        self.assertEqual(0, self._idle(b))

        b.end_request()
        b_sock = b._Connection__pool.sockets[0]
        b.test.test.find_one()
        a.test.test.find_one()

        # Each connection's pool must hand back the socket recorded above.
        b_current = b._Connection__pool.get_socket(b.host, b.port)[0]
        self.assertEqual(b_sock, b_current)
        a_current = a._Connection__pool.get_socket(a.host, a.port)[0]
        self.assertEqual(a_sock, a_current)

    def test_pool_with_fork(self):
        if sys.platform == "win32":
            raise SkipTest()

        try:
            from multiprocessing import Process, Pipe
        except ImportError:
            raise SkipTest()

        a = get_connection()
        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, self._idle(a))
        a_sock = a._Connection__pool.sockets[0]

        def loop(pipe):
            # Runs in a child process: open a connection, do one query and
            # report the local address of the pooled socket to the parent.
            c = get_connection()
            self.assertEqual(1, self._idle(c))
            c.test.test.find_one()
            self.assertEqual(0, self._idle(c))
            c.end_request()
            self.assertEqual(1, self._idle(c))
            pipe.send(c._Connection__pool.sockets[0].getsockname())

        cp1, cc1 = Pipe()
        cp2, cc2 = Pipe()
        children = [
            Process(target=loop, args=(cc1,)),
            Process(target=loop, args=(cc2,)),
        ]

        # Each step is applied to both children before the next step begins.
        for child in children:
            child.start()
        for child in children:
            child.join(1)
        for child in children:
            child.terminate()
        for child in children:
            child.join()

        for child_end in (cc1, cc2):
            child_end.close()

        first_addr = cp1.recv()
        second_addr = cp2.recv()

        # Parent and both children must each report a different address.
        self.assert_(a_sock.getsockname() != first_addr)
        self.assert_(a_sock.getsockname() != second_addr)
        self.assert_(first_addr != second_addr)

        parent_current = a._Connection__pool.get_socket(a.host, a.port)[0]
        self.assertEqual(a_sock, parent_current)

    def test_max_pool_size(self):
        conn = get_connection(max_pool_size=4)

        workers = []
        for _ in range(40):
            worker = CreateAndReleaseSocket(conn)
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()

        # There's a race condition, so be lenient
        self.assert_(abs(4 - self._idle(conn)) < 4)
# Run the tests when this module is executed directly.
if __name__ == "__main__":
    unittest.main()