PYTHON-785 Don't use requests in tests.

This commit is contained in:
A. Jesse Jiryu Davis 2014-11-20 16:48:03 -05:00
parent daad4446f8
commit e8be121a89
11 changed files with 232 additions and 1472 deletions

View File

@ -31,6 +31,7 @@ from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
from pymongo.errors import OperationFailure
from pymongo.read_preferences import ReadPreference
from test import client_context, host, port, SkipTest, unittest, Version
from test.utils import delay
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS.
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
@ -46,18 +47,22 @@ SASL_DB = os.environ.get('SASL_DB', '$external')
class AutoAuthenticateThread(threading.Thread):
"""Used in testing threaded authentication.
This does collection.find_one() with a 1-second delay to ensure it must
check out and authenticate multiple sockets from the pool concurrently.
:Parameters:
`collection`: An auth-protected collection containing one document.
"""
def __init__(self, database):
def __init__(self, collection):
super(AutoAuthenticateThread, self).__init__()
self.database = database
self.success = True
self.collection = collection
self.success = False
def run(self):
try:
self.database.command('dbstats')
except OperationFailure:
self.success = False
assert self.collection.find_one({'$where': delay(1)}) is not None
self.success = True
class TestGSSAPI(unittest.TestCase):
@ -148,14 +153,20 @@ class TestGSSAPI(unittest.TestCase):
def test_gssapi_threaded(self):
client = MongoClient(GSSAPI_HOST)
# Make sure each thread uses a different socket.
client.start_request()
self.assertTrue(client.test.authenticate(PRINCIPAL,
mechanism='GSSAPI'))
# Need one document in the collection. AutoAuthenticateThread does
# collection.find_one with a 1-second delay, forcing it to check out
# multiple sockets from the pool concurrently, proving that
# auto-authentication works with GSSAPI.
collection = client.test.collection
collection.remove()
collection.insert({'_id': 1})
threads = []
for _ in range(4):
threads.append(AutoAuthenticateThread(client.test))
threads.append(AutoAuthenticateThread(collection))
for thread in threads:
thread.start()
for thread in threads:
@ -174,7 +185,7 @@ class TestGSSAPI(unittest.TestCase):
threads = []
for _ in range(4):
threads.append(AutoAuthenticateThread(client.foo))
threads.append(AutoAuthenticateThread(collection))
for thread in threads:
thread.start()
for thread in threads:

View File

@ -29,7 +29,7 @@ from test import (client_context,
port,
IntegrationTest,
SkipTest)
from test.utils import oid_generated_on_client, remove_all_users
from test.utils import oid_generated_on_client, remove_all_users, wait_until
class BulkTestBase(IntegrationTest):
@ -1115,9 +1115,9 @@ class TestBulkNoResults(BulkTestBase):
batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}})
batch.insert({'_id': 2})
batch.find({'_id': 1}).remove_one()
with self.client.start_request():
self.assertTrue(batch.execute({'w': 0}) is None)
self.assertEqual(2, self.coll.count())
self.assertTrue(batch.execute({'w': 0}) is None)
wait_until(lambda: 2 == self.coll.count(),
'insert 2 documents')
def test_no_results_ordered_failure(self):
@ -1127,9 +1127,9 @@ class TestBulkNoResults(BulkTestBase):
batch.insert({'_id': 2})
batch.insert({'_id': 1})
batch.find({'_id': 1}).remove_one()
with self.client.start_request():
self.assertTrue(batch.execute({'w': 0}) is None)
self.assertEqual(3, self.coll.count())
self.assertTrue(batch.execute({'w': 0}) is None)
wait_until(lambda: 3 == self.coll.count(),
'insert 3 documents')
def test_no_results_unordered_success(self):
@ -1138,9 +1138,9 @@ class TestBulkNoResults(BulkTestBase):
batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}})
batch.insert({'_id': 2})
batch.find({'_id': 1}).remove_one()
with self.client.start_request():
self.assertTrue(batch.execute({'w': 0}) is None)
self.assertEqual(2, self.coll.count())
self.assertTrue(batch.execute({'w': 0}) is None)
wait_until(lambda: 2 == self.coll.count(),
'insert 2 documents')
def test_no_results_unordered_failure(self):
@ -1150,10 +1150,10 @@ class TestBulkNoResults(BulkTestBase):
batch.insert({'_id': 2})
batch.insert({'_id': 1})
batch.find({'_id': 1}).remove_one()
with self.client.start_request():
self.assertTrue(batch.execute({'w': 0}) is None)
self.assertEqual(2, self.coll.count())
self.assertTrue(self.coll.find_one({'_id': 1}) is None)
self.assertTrue(batch.execute({'w': 0}) is None)
wait_until(lambda: 2 == self.coll.count(),
'insert 2 documents')
self.assertTrue(self.coll.find_one({'_id': 1}) is None)
class TestBulkAuthorization(BulkTestBase):

View File

@ -18,7 +18,6 @@ import contextlib
import datetime
import multiprocessing
import os
import threading
import socket
import sys
import time
@ -32,7 +31,6 @@ from bson.son import SON
from bson.tz_util import utc
from pymongo.mongo_client import MongoClient
from pymongo.database import Database
from pymongo.pool import SocketInfo
from pymongo import auth, message
from pymongo.errors import (AutoReconnect,
ConfigurationError,
@ -63,7 +61,6 @@ from test.utils import (assertRaisesExactly,
ignore_deprecations,
remove_all_users,
server_is_master_with_slave,
TestRequestMixin,
get_pool,
one,
connected,
@ -74,7 +71,7 @@ from test.utils import (assertRaisesExactly,
NTHREADS)
class ClientUnitTest(unittest.TestCase, TestRequestMixin):
class ClientUnitTest(unittest.TestCase):
"""MongoClient tests that don't require a server."""
@classmethod
@ -166,7 +163,7 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin):
self.assertEqual(Database(c, 'foo'), c.get_default_database())
class TestClient(IntegrationTest, TestRequestMixin):
class TestClient(IntegrationTest):
def test_constants(self):
# Set bad defaults.
@ -605,80 +602,6 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
self.assertEqual(0, len(get_pool(client).sockets))
def test_with_start_request(self):
pool = get_pool(self.client)
# No request started
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
# Start a request
request_context_mgr = self.client.start_request()
self.assertTrue(
isinstance(request_context_mgr, object)
)
self.assertNoSocketYet(pool)
self.assertSameSock(pool)
self.assertRequestSocket(pool)
# End request
request_context_mgr.__exit__(None, None, None)
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
# Test the 'with' statement
with self.client.start_request() as request:
self.assertEqual(self.client, request.connection)
self.assertNoSocketYet(pool)
self.assertSameSock(pool)
self.assertRequestSocket(pool)
# Request has ended
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
def test_request_threads(self):
client = self.client
pool = get_pool(client)
self.assertNotInRequestAndDifferentSock(client, pool)
started_request, ended_request = threading.Event(), threading.Event()
checked_request = threading.Event()
thread_done = [False]
# Starting a request in one thread doesn't put the other thread in a
# request
def f():
self.assertNotInRequestAndDifferentSock(client, pool)
client.start_request()
self.assertInRequestAndSameSock(client, pool)
started_request.set()
checked_request.wait()
checked_request.clear()
self.assertInRequestAndSameSock(client, pool)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pool)
ended_request.set()
checked_request.wait()
thread_done[0] = True
t = threading.Thread(target=f)
t.setDaemon(True)
t.start()
# It doesn't matter in what order the main thread or t initially get
# to started_request.set() / wait(); by waiting here we ensure that t
# has called client.start_request() before we assert on the next line.
started_request.wait()
self.assertNotInRequestAndDifferentSock(client, pool)
checked_request.set()
ended_request.wait()
self.assertNotInRequestAndDifferentSock(client, pool)
checked_request.set()
t.join()
self.assertNotInRequestAndDifferentSock(client, pool)
self.assertTrue(thread_done[0], "Thread didn't complete")
def test_interrupt_signal(self):
if sys.platform.startswith('java'):
# We can't figure out how to raise an exception on a thread that's
@ -724,7 +647,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
next(db.foo.find())
)
def test_operation_failure_without_request(self):
def test_operation_failure(self):
# Ensure MongoClient doesn't close socket after it gets an error
# response to getLastError. PYTHON-395.
pool = get_pool(self.client)
@ -741,27 +664,6 @@ class TestClient(IntegrationTest, TestRequestMixin):
new_sock_info = next(iter(pool.sockets))
self.assertEqual(old_sock_info, new_sock_info)
def test_operation_failure_with_request(self):
# Ensure MongoClient doesn't close socket after it gets an error
# response to getLastError. PYTHON-395.
c = rs_or_single_client()
c.start_request()
pool = get_pool(c)
# Pool reserves a socket for this thread.
c.pymongo_test.test.find_one()
self.assertTrue(isinstance(pool._get_request_state(), SocketInfo))
old_sock_info = pool._get_request_state()
c.pymongo_test.test.drop()
c.pymongo_test.test.insert({'_id': 'foo'})
self.assertRaises(
OperationFailure,
c.pymongo_test.test.insert, {'_id': 'foo'})
# OperationFailure doesn't affect the request socket
self.assertEqual(old_sock_info, pool._get_request_state())
def test_alive(self):
self.assertTrue(self.client.alive())

View File

@ -50,7 +50,7 @@ from pymongo.errors import (DocumentTooLarge,
from test.test_client import IntegrationTest
from test.utils import (is_mongos, joinall, enable_text_search, get_pool,
oid_generated_on_client, one, ignore_deprecations,
rs_or_single_client)
rs_or_single_client, wait_until)
from test import client_context, host, port, qcheck, unittest
@ -793,14 +793,15 @@ class TestCollection(IntegrationTest):
db.drop_collection("test")
db.test.ensure_index([('i', ASCENDING)], unique=True)
with db.connection.start_request():
# No error
db.test.insert([{'i': i} for i in range(5, 10)], w=0)
db.test.remove()
# No error
db.test.insert([{'i': i} for i in range(5, 10)], w=0)
wait_until(lambda: 5 == db.test.count(), 'insert 5 documents')
# No error
db.test.insert([{'i': 1}] * 2, w=0)
self.assertEqual(1, db.test.count())
db.test.remove()
# No error
db.test.insert([{'i': 1}] * 2, w=0)
wait_until(lambda: 1 == db.test.count(), 'insert 1 document')
self.assertRaises(
DuplicateKeyError,
@ -811,20 +812,19 @@ class TestCollection(IntegrationTest):
wc = db.write_concern
db.write_concern = {"w": 0}
try:
with db.connection.start_request():
db.test.ensure_index([('i', ASCENDING)], unique=True)
db.test.ensure_index([('i', ASCENDING)], unique=True)
# No error
db.test.insert([{'i': 1}] * 2)
self.assertEqual(1, db.test.count())
# No error.
db.test.insert([{'i': 1}] * 2)
wait_until(lambda: 1 == db.test.count(), 'insert 1 document')
# Implied safe
# Implied acknowledged.
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, fsync=True),
)
# Explicit safe
# Explicit acknowledged.
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, w=1),
@ -902,7 +902,7 @@ class TestCollection(IntegrationTest):
self.db.test.find_one({'_id': 'explicit_id'})['hello']
)
# Safe mode
# Acknowledged mode.
self.db.test.create_index("hello", unique=True)
# No exception, even though we duplicate the first doc's "hello" value
self.db.test.save({'_id': 'explicit_id', 'hello': 'world'}, w=0)
@ -924,23 +924,19 @@ class TestCollection(IntegrationTest):
def test_unique_index(self):
db = self.db
db.drop_collection("test")
db.test.create_index("hello")
with self.client.start_request():
db.drop_collection("test")
db.test.create_index("hello")
# No error.
db.test.save({"hello": "world"})
db.test.save({"hello": "world"})
db.drop_collection("test")
db.test.create_index("hello", unique=True)
with self.assertRaises(DuplicateKeyError):
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"})
self.assertFalse(db.error())
db.drop_collection("test")
db.test.create_index("hello", unique=True)
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue(db.error())
def test_duplicate_key_error(self):
db = self.db
@ -949,41 +945,17 @@ class TestCollection(IntegrationTest):
db.test.create_index("x", unique=True)
db.test.insert({"_id": 1, "x": 1})
db.test.insert({"_id": 2, "x": 2})
with db.connection.start_request():
# No error
db.test.insert({"_id": 1, "x": 1}, w=0)
db.test.save({"_id": 1, "x": 1}, w=0)
db.test.insert({"_id": 2, "x": 2}, w=0)
db.test.save({"_id": 2, "x": 2}, w=0)
with self.assertRaises(DuplicateKeyError) as context:
db.test.insert({"x": 1})
# But all those statements didn't do anything
self.assertEqual(2, db.test.count())
self.assertIsNotNone(context.exception.details)
expected_error = OperationFailure
if client_context.version.at_least(1, 3):
expected_error = DuplicateKeyError
with self.assertRaises(DuplicateKeyError) as context:
db.test.save({"x": 1})
self.assertRaises(expected_error,
db.test.insert, {"_id": 1})
self.assertRaises(expected_error,
db.test.insert, {"x": 1})
self.assertRaises(expected_error,
db.test.save, {"x": 2})
self.assertRaises(expected_error,
db.test.update, {"x": 1},
{"$inc": {"x": 1}})
try:
db.test.insert({"_id": 1})
except expected_error as exc:
# Just check that we set the error document. Fields
# vary by MongoDB version.
self.assertTrue(exc.details is not None)
else:
self.fail("%s was not raised" % (expected_error.__name__,))
self.assertIsNotNone(context.exception.details)
self.assertEqual(1, db.test.count())
def test_wtimeout(self):
# Ensure setting wtimeout doesn't disable write concern altogether.
@ -1006,33 +978,33 @@ class TestCollection(IntegrationTest):
self.assertEqual(1, db.test.count())
docs = []
docs.append({"_id": oid, "two": 2})
docs.append({"_id": oid, "two": 2}) # Duplicate _id.
docs.append({"three": 3})
docs.append({"four": 4})
docs.append({"five": 5})
with self.client.start_request():
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(1, db.test.count())
with self.assertRaises(DuplicateKeyError):
db.test.insert(docs, manipulate=False)
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(4, db.test.count())
self.assertEqual(1, db.test.count())
db.drop_collection("test")
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
self.assertEqual(1, db.test.count())
docs[0].pop("_id")
docs[2]["_id"] = oid
with self.assertRaises(DuplicateKeyError):
db.test.insert(docs, manipulate=False, continue_on_error=True)
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(3, db.test.count())
self.assertEqual(4, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(6, db.test.count())
db.drop_collection("test")
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
self.assertEqual(1, db.test.count())
docs[0].pop("_id")
docs[2]["_id"] = oid
with self.assertRaises(DuplicateKeyError):
db.test.insert(docs, manipulate=False)
self.assertEqual(3, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(6, db.test.count())
def test_error_code(self):
try:
@ -1062,17 +1034,14 @@ class TestCollection(IntegrationTest):
self.assertRaises(DuplicateKeyError,
db.test.insert, {"hello": {"a": 4, "b": 10}})
def test_safe_insert(self):
with self.client.start_request():
db = self.db
db.drop_collection("test")
def test_acknowledged_insert(self):
db = self.db
db.drop_collection("test")
a = {"hello": "world"}
db.test.insert(a)
db.test.insert(a, w=0)
self.assertTrue("E11000" in db.error()["err"])
self.assertRaises(OperationFailure, db.test.insert, a)
a = {"hello": "world"}
db.test.insert(a)
db.test.insert(a, w=0)
self.assertRaises(OperationFailure, db.test.insert, a)
def test_update(self):
db = self.db
@ -1159,36 +1128,25 @@ class TestCollection(IntegrationTest):
self.assertEqual(1, db.test.count())
self.assertEqual(2, db.test.find_one()["count"])
def test_safe_update(self):
with self.client.start_request():
db = self.db
v113minus = client_context.version.at_least(1, 1, 3, -1)
v19 = client_context.version.at_least(1, 9)
def test_acknowledged_update(self):
db = self.db
db.drop_collection("test")
db.test.create_index("x", unique=True)
db.drop_collection("test")
db.test.create_index("x", unique=True)
db.test.insert({"x": 5})
id = db.test.insert({"x": 4})
db.test.insert({"x": 5})
id = db.test.insert({"x": 4})
self.assertEqual(
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
self.assertEqual(
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
self.assertRaises(DuplicateKeyError, db.test.update,
{"_id": id}, {"$inc": {"x": 1}})
if v19:
self.assertTrue("E11000" in db.error()["err"])
elif v113minus:
self.assertTrue(db.error()["err"].startswith("E11001"))
else:
self.assertTrue(db.error()["err"].startswith("E12011"))
self.assertEqual(1, db.test.update({"_id": id},
{"$inc": {"x": 2}})["n"])
self.assertRaises(OperationFailure, db.test.update,
{"_id": id}, {"$inc": {"x": 1}})
self.assertEqual(1, db.test.update({"_id": id},
{"$inc": {"x": 2}})["n"])
self.assertEqual(0, db.test.update({"_id": "foo"},
{"$inc": {"x": 2}})["n"])
self.assertEqual(0, db.test.update({"_id": "foo"},
{"$inc": {"x": 2}})["n"])
def test_update_with_invalid_keys(self):
self.db.drop_collection("test")
@ -1233,20 +1191,17 @@ class TestCollection(IntegrationTest):
self.assertNotEqual(0, self.db.test.update({"hello": "world"},
{})['n'])
def test_safe_save(self):
with self.client.start_request():
db = self.db
db.drop_collection("test")
db.test.create_index("hello", unique=True)
def test_acknowledged_save(self):
db = self.db
db.drop_collection("test")
db.test.create_index("hello", unique=True)
db.test.save({"hello": "world"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue("E11000" in db.error()["err"])
db.test.save({"hello": "world"})
db.test.save({"hello": "world"}, w=0)
self.assertRaises(DuplicateKeyError, db.test.save,
{"hello": "world"})
self.assertRaises(OperationFailure, db.test.save,
{"hello": "world"})
def test_safe_remove(self):
def test_acknowledged_remove(self):
db = self.db
db.drop_collection("test")
db.create_collection("test", capped=True, size=1000)
@ -1254,16 +1209,8 @@ class TestCollection(IntegrationTest):
db.test.insert({"x": 1})
self.assertEqual(1, db.test.count())
with db.connection.start_request():
self.assertEqual(None, db.test.remove({"x": 1}, w=0))
self.assertEqual(1, db.test.count())
if client_context.version.at_least(1, 1, 3, -1):
self.assertRaises(OperationFailure, db.test.remove,
{"x": 1})
else: # Just test that it doesn't blow up
db.test.remove({"x": 1})
# Can't remove from capped collection.
self.assertRaises(OperationFailure, db.test.remove, {"x": 1})
db.drop_collection("test")
db.test.insert({"x": 1})
db.test.insert({"x": 1})
@ -1836,16 +1783,16 @@ class TestCollection(IntegrationTest):
batch[1]['_id'] = batch[0]['_id']
# Test that inserts fail after first error, acknowledged.
# Test that inserts fail after first error.
self.db.test.drop()
self.assertRaises(DuplicateKeyError, self.db.test.insert, batch, w=1)
self.assertRaises(DuplicateKeyError, self.db.test.insert, batch)
self.assertEqual(1, self.db.test.count())
# Test that inserts fail after first error, unacknowledged.
# 2 batches, 2 errors, continue on error.
self.db.test.drop()
with self.client.start_request():
self.assertTrue(self.db.test.insert(batch, w=0))
self.assertEqual(1, self.db.test.count())
self.assertTrue(self.db.test.insert(batch, w=0))
wait_until(lambda: 1 == self.db.test.count(),
'insert 1 document')
# 2 batches, 2 errors, acknowledged, continue on error
self.db.test.drop()
@ -1862,11 +1809,11 @@ class TestCollection(IntegrationTest):
# 2 batches, 2 errors, unacknowledged, continue on error
self.db.test.drop()
with self.client.start_request():
self.assertTrue(
self.db.test.insert(batch, continue_on_error=True, w=0))
# Only the first and third documents should be inserted.
self.assertEqual(2, self.db.test.count())
self.assertTrue(
self.db.test.insert(batch, continue_on_error=True, w=0))
# Only the first and third documents should be inserted.
wait_until(lambda: 2 == self.db.test.count(),
'insert 2 documents')
def test_numerous_inserts(self):
# Ensure we don't exceed server's 1000-document batch size limit.

View File

@ -57,6 +57,7 @@ from test.utils import (
ignore_deprecations,
remove_all_users,
rs_or_single_client_noauth,
rs_or_single_client,
server_started_with_auth)
@ -285,38 +286,37 @@ class TestDatabase(IntegrationTest):
@client_context.require_no_mongos
def test_errors(self):
db = self.client.pymongo_test
# We must call getlasterror, etc. on same socket as the last operation.
db = rs_or_single_client(max_pool_size=1).pymongo_test
db.reset_error_history()
self.assertEqual(None, db.error())
self.assertEqual(None, db.previous_error())
with self.client.start_request():
db.reset_error_history()
self.assertEqual(None, db.error())
self.assertEqual(None, db.previous_error())
db.command("forceerror", check=False)
self.assertTrue(db.error())
self.assertTrue(db.previous_error())
db.command("forceerror", check=False)
self.assertTrue(db.error())
self.assertTrue(db.previous_error())
db.command("forceerror", check=False)
self.assertTrue(db.error())
prev_error = db.previous_error()
self.assertEqual(prev_error["nPrev"], 1)
del prev_error["nPrev"]
prev_error.pop("lastOp", None)
error = db.error()
error.pop("lastOp", None)
# getLastError includes "connectionId" in recent
# server versions, getPrevError does not.
error.pop("connectionId", None)
self.assertEqual(error, prev_error)
db.command("forceerror", check=False)
self.assertTrue(db.error())
prev_error = db.previous_error()
self.assertEqual(prev_error["nPrev"], 1)
del prev_error["nPrev"]
prev_error.pop("lastOp", None)
error = db.error()
error.pop("lastOp", None)
# getLastError includes "connectionId" in recent
# server versions, getPrevError does not.
error.pop("connectionId", None)
self.assertEqual(error, prev_error)
db.test.find_one()
self.assertEqual(None, db.error())
self.assertTrue(db.previous_error())
self.assertEqual(db.previous_error()["nPrev"], 2)
db.test.find_one()
self.assertEqual(None, db.error())
self.assertTrue(db.previous_error())
self.assertEqual(db.previous_error()["nPrev"], 2)
db.reset_error_history()
self.assertEqual(None, db.error())
self.assertEqual(None, db.previous_error())
db.reset_error_history()
self.assertEqual(None, db.error())
self.assertEqual(None, db.previous_error())
def test_command(self):
db = self.client.admin
@ -338,17 +338,16 @@ class TestDatabase(IntegrationTest):
self.assertTrue(isinstance(result['result'][0]['r'], Regex))
def test_last_status(self):
db = self.client.pymongo_test
# We must call getlasterror on the same socket as the last operation.
db = rs_or_single_client(max_pool_size=1).pymongo_test
db.test.remove({})
db.test.save({"i": 1})
with self.client.start_request():
db.test.remove({})
db.test.save({"i": 1})
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
self.assertTrue(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
self.assertTrue(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
self.assertFalse(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
self.assertFalse(db.last_status()["updatedExisting"])
def test_password_digest(self):
self.assertRaises(TypeError, auth._password_digest, 5)

View File

@ -369,28 +369,6 @@ class TestGridfs(IntegrationTest):
self.assertTrue(iterate_file(f))
def test_request(self):
c = self.db.connection
c.start_request()
n = 5
for i in range(n):
file = self.fs.new_file(filename="test")
file.write(b"hello")
file.close()
c.end_request()
self.assertEqual(
n,
self.db.fs.files.find({'filename':'test'}).count()
)
def test_gridfs_request(self):
self.assertFalse(self.db.connection.in_request())
self.fs.put(b"hello world")
# Request started and ended by put(), we're back to original state
self.assertFalse(self.db.connection.in_request())
def test_gridfs_lazy_connect(self):
with client_knobs(server_wait_time=0.01):
client = MongoClient('badhost', connect=False)

View File

@ -22,20 +22,17 @@ import threading
import time
from pymongo import MongoClient
from pymongo.errors import ConfigurationError, ConnectionFailure, \
ExceededMaxWaiters
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
DuplicateKeyError,
ExceededMaxWaiters)
from pymongo.server_selectors import writable_server_selector
sys.path[0:0] = [""]
from bson.py3compat import thread
from pymongo.pool import (NO_REQUEST,
NO_SOCKET_YET,
Pool,
PoolOptions,
SocketInfo, _closed)
from pymongo.pool import Pool, PoolOptions, _closed
from test import host, port, SkipTest, unittest, client_context
from test.utils import get_pool, one, rs_or_single_client
from test.utils import get_pool, one, rs_or_single_client, connected, joinall
@client_context.require_connection
@ -86,18 +83,18 @@ class SaveAndFind(MongoThread):
class Unique(MongoThread):
def run_mongo_thread(self):
for _ in range(N):
self.client.start_request()
self.db.unique.insert({}) # no error
self.client.end_request()
class NonUnique(MongoThread):
def run_mongo_thread(self):
for _ in range(N):
self.client.start_request()
self.db.unique.insert({"_id": "jesse"}, w=0)
assert self.db.error() is not None
self.client.end_request()
try:
self.db.unique.insert({"_id": "jesse"})
except DuplicateKeyError:
pass
else:
raise AssertionError("Should have raised DuplicateKeyError")
class Disconnect(MongoThread):
@ -106,19 +103,6 @@ class Disconnect(MongoThread):
self.client.disconnect()
class NoRequest(MongoThread):
def run_mongo_thread(self):
self.client.start_request()
errors = 0
for _ in range(N):
self.db.unique.insert({"_id": "jesse"}, w=0)
if not self.db.error():
errors += 1
self.client.end_request()
assert errors == 0
class SocketGetter(MongoThread):
"""Utility for _TestMaxOpenSockets and _TestWaitQueueMultiple"""
def __init__(self, client, pool):
@ -150,73 +134,6 @@ def run_cases(client, cases):
assert t.passed, "%s.run() threw an exception" % repr(t)
class CreateAndReleaseSocket(MongoThread):
"""Gets a socket, waits for all threads to reach rendezvous, and quits."""
class Rendezvous(object):
def __init__(self, nthreads):
self.nthreads = nthreads
self.nthreads_run = 0
self.lock = threading.Lock()
self.reset_ready()
def reset_ready(self):
self.ready = threading.Event()
def __init__(self, client, start_request, end_request, rendezvous):
super(CreateAndReleaseSocket, self).__init__(client)
self.start_request = start_request
self.end_request = end_request
self.rendezvous = rendezvous
def run_mongo_thread(self):
# Do an operation that requires a socket.
# TestPoolMaxSize uses this to spin up lots of threads requiring
# lots of simultaneous connections, to ensure that Pool obeys its
# max_size configuration and closes extra sockets as they're returned.
for i in range(self.start_request):
self.client.start_request()
# Use a socket
self.client[DB].test.find_one()
# Don't finish until all threads reach this point
r = self.rendezvous
r.lock.acquire()
r.nthreads_run += 1
if r.nthreads_run == r.nthreads:
# Everyone's here, let them finish.
r.ready.set()
r.lock.release()
else:
r.lock.release()
r.ready.wait(30) # Wait thirty seconds....
assert r.ready.isSet(), "Rendezvous timed out"
for i in range(self.end_request):
self.client.end_request()
class CreateAndReleaseSocketNoRendezvous(MongoThread):
"""Gets a socket and quits. No synchronization with other threads."""
def __init__(self, client, start_request, end_request):
super(CreateAndReleaseSocketNoRendezvous, self).__init__(client)
self.start_request = start_request
self.end_request = end_request
def run_mongo_thread(self):
# Do an operation that requires a socket.
# TestPoolMaxSize uses this to spin up lots of threads requiring
# lots of simultaneous connections, to ensure that Pool obeys its
# max_size configuration and closes extra sockets as they're returned.
for i in range(self.start_request):
self.client.start_request()
# Use a socket
self.client[DB].test.find_one()
for i in range(self.end_request):
self.client.end_request()
class _TestPoolingBase(unittest.TestCase):
"""Base class for all connection-pool tests."""
@ -231,40 +148,6 @@ class _TestPoolingBase(unittest.TestCase):
def create_pool(self, pair=(host, port), *args, **kwargs):
return Pool(pair, PoolOptions(*args, **kwargs))
def assert_no_request(self):
try:
server = self.c._topology.select_server(
writable_server_selector,
server_wait_time=0)
self.assertEqual(NO_REQUEST, server.pool._get_request_state())
except ConnectionFailure:
# Success: we're asserting that we're not in a request, but there's
# no pool at all so the assertion is true.
pass
def assert_request_without_socket(self):
self.assertEqual(NO_SOCKET_YET, get_pool(self.c)._get_request_state())
def assert_request_with_socket(self):
self.assertTrue(isinstance(
get_pool(self.c)._get_request_state(), SocketInfo))
def assert_pool_size(self, pool_size):
if pool_size == 0:
try:
server = self.c._topology.select_server(
writable_server_selector,
server_wait_time=0)
self.assertEqual(0, len(server.pool.sockets))
except ConnectionFailure:
# Success: we're asserting that pool size is 0, and there's no
# pool at all so the assertion is true.
pass
else:
self.assertEqual(pool_size, len(get_pool(self.c).sockets))
class TestPooling(_TestPoolingBase):
def test_max_pool_size_validation(self):
@ -280,95 +163,11 @@ class TestPooling(_TestPoolingBase):
self.assertEqual(c.max_pool_size, 100)
def test_no_disconnect(self):
run_cases(self.c, [NoRequest, NonUnique, Unique, SaveAndFind])
def test_simple_disconnect(self):
# MongoClient just created, expect 1 free socket.
self.assert_pool_size(1)
self.assert_no_request()
self.c.start_request()
self.assert_request_without_socket()
cursor = self.c[DB].stuff.find()
# Cursor hasn't actually caused a request yet, so there's still 1 free
# socket.
self.assert_pool_size(1)
self.assert_request_without_socket()
# Actually make a request to server, triggering a socket to be
# allocated to the request.
list(cursor)
self.assert_pool_size(0)
self.assert_request_with_socket()
# Pool returns to its original state
self.c.end_request()
self.assert_no_request()
self.assert_pool_size(1)
self.c.disconnect()
self.assert_pool_size(0)
self.assert_no_request()
run_cases(self.c, [NonUnique, Unique, SaveAndFind])
def test_disconnect(self):
run_cases(self.c, [SaveAndFind, Disconnect, Unique])
def test_request(self):
# Check that Pool gives two different sockets in two calls to
# get_socket().
cx_pool = self.create_pool(
pair=(host, port),
max_pool_size=10,
connect_timeout=1000,
socket_timeout=1000)
sock0 = cx_pool.get_socket({}, 0, 0)
sock1 = cx_pool.get_socket({}, 0, 0)
self.assertNotEqual(sock0, sock1)
# Now in a request, we'll get the same socket both times
cx_pool.start_request()
sock2 = cx_pool.get_socket({}, 0, 0)
sock3 = cx_pool.get_socket({}, 0, 0)
self.assertEqual(sock2, sock3)
# Pool didn't keep reference to sock0 or sock1; sock2 and 3 are new
self.assertNotEqual(sock0, sock2)
self.assertNotEqual(sock1, sock2)
# Return the request sock to pool
cx_pool.end_request()
sock4 = cx_pool.get_socket({}, 0, 0)
sock5 = cx_pool.get_socket({}, 0, 0)
# Not in a request any more, we get different sockets
self.assertNotEqual(sock4, sock5)
# end_request() returned sock2 to pool
self.assertEqual(sock4, sock2)
for s in [sock0, sock1, sock2, sock3, sock4, sock5]:
s.close()
def test_reset_and_request(self):
# reset() is called after a fork, or after a socket error. Ensure that
# a new request is begun if a request was in progress when the reset()
# occurred, otherwise no request is begun.
p = self.create_pool(max_pool_size=10)
self.assertFalse(p.in_request())
p.start_request()
self.assertTrue(p.in_request())
p.reset()
self.assertTrue(p.in_request())
p.end_request()
self.assertFalse(p.in_request())
p.reset()
self.assertFalse(p.in_request())
def test_pool_reuses_open_socket(self):
# Test Pool's _check_closed() method doesn't close a healthy socket.
cx_pool = self.create_pool(max_pool_size=10)
@ -399,236 +198,6 @@ class TestPooling(_TestPoolingBase):
self.assertEqual(1, len(cx_pool.sockets))
def test_pool_removes_dead_request_socket_after_check(self):
# Test that Pool keeps request going even if a socket dies in request.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
cx_pool.start_request()
# Get the request socket.
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
sock_info.sock.close()
# Although the request socket died, we're still in a request with a
# new socket.
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(0, len(cx_pool.sockets))
cx_pool.end_request()
self.assertEqual(1, len(cx_pool.sockets))
def test_pool_removes_dead_request_socket(self):
# Test that Pool keeps request going even if a socket dies in request.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool.start_request()
# Get the request socket
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
# Unlike in test_pool_removes_dead_request_socket_after_check, we
# set sock_info.closed and *don't* wait for it to be checked.
sock_info.close()
# Although the request socket died, we're still in a request with a
# new socket
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(0, len(cx_pool.sockets))
cx_pool.end_request()
self.assertEqual(1, len(cx_pool.sockets))
def test_pool_removes_dead_socket_after_request(self):
# Test that Pool handles a socket dying that *used* to be the request
# socket.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
cx_pool.start_request()
# Get the request socket.
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(sock_info, cx_pool._get_request_state())
# End request.
cx_pool.end_request()
self.assertEqual(1, len(cx_pool.sockets))
# Kill old request socket.
sock_info.sock.close()
# Dead socket detected and removed.
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertFalse(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(0, len(cx_pool.sockets))
self.assertFalse(_closed(new_sock_info.sock))
self.assertEqual(1, len(cx_pool.sockets))
def test_dead_request_socket_with_max_size(self):
# When a pool replaces a dead request socket, the semaphore it uses
# to enforce max_size should remain unaffected.
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
cx_pool._check_interval_seconds = 0 # Always check.
cx_pool.start_request()
# Get and close the request socket.
with cx_pool.get_socket({}, 0, 0) as request_sock_info:
request_sock_info.sock.close()
# Detects closed socket and creates new one, semaphore value still 0.
with cx_pool.get_socket({}, 0, 0) as request_sock_info_2:
self.assertNotEqual(request_sock_info, request_sock_info_2)
cx_pool.end_request()
# Semaphore value now 1; we can get a socket.
sock = cx_pool.get_socket({}, 0, 0)
sock.close()
def test_socket_reclamation(self):
if sys.platform.startswith('java'):
raise SkipTest("Jython can't do socket reclamation")
# Check that if a thread starts a request and dies without ending
# the request, that the socket is reclaimed into the pool.
cx_pool = self.create_pool(
max_pool_size=10,
connect_timeout=1000,
socket_timeout=1000)
self.assertEqual(0, len(cx_pool.sockets))
lock = None
the_sock = [None]
def leak_request():
self.assertEqual(NO_REQUEST, cx_pool._get_request_state())
cx_pool.start_request()
self.assertEqual(NO_SOCKET_YET, cx_pool._get_request_state())
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(sock_info, cx_pool._get_request_state())
the_sock[0] = id(sock_info.sock)
lock.release()
lock = thread.allocate_lock()
lock.acquire()
# Start a thread WITHOUT a threading.Thread - important to test that
# Pool can deal with primitive threads.
thread.start_new_thread(leak_request, ())
# Join thread
acquired = lock.acquire()
self.assertTrue(acquired, "Thread is hung")
# Make sure thread is really gone
time.sleep(1)
if 'PyPy' in sys.version:
gc.collect()
# Access the thread local from the main thread to trigger the
# ThreadVigil's delete callback, returning the request socket to
# the pool.
# In Python 2.7.0 and lesser, a dead thread's locals are deleted
# and those locals' weakref callbacks are fired only when another
# thread accesses the locals and finds the thread state is stale,
# see http://bugs.python.org/issue1868. Accessing the thread
# local from the main thread is a necessary part of this test, and
# realistic: in a multithreaded web server a new thread will access
# Pool._ident._local soon after an old thread has died.
cx_pool._ident.get()
# Pool reclaimed the socket
self.assertEqual(1, len(cx_pool.sockets))
self.assertEqual(the_sock[0], id(one(cx_pool.sockets).sock))
self.assertEqual(0, len(cx_pool._tid_to_sock))
def test_request_with_fork(self):
if sys.platform == "win32":
raise SkipTest("Can't test forking on Windows")
try:
from multiprocessing import Process, Pipe
except ImportError:
raise SkipTest("No multiprocessing module")
coll = self.c.pymongo_test.test
coll.remove()
coll.insert({'_id': 1})
coll.find_one()
self.assert_pool_size(1)
self.c.start_request()
self.assert_pool_size(1)
coll.find_one()
self.assert_pool_size(0)
self.assert_request_with_socket()
def f(pipe):
# We can still query server without error
self.assertEqual({'_id':1}, coll.find_one())
# Pool has detected that we forked, but resumed the request
self.assert_request_with_socket()
self.assert_pool_size(0)
pipe.send("success")
parent_conn, child_conn = Pipe()
p = Process(target=f, args=(child_conn,))
p.start()
p.join(30)
p.terminate()
child_conn.close()
self.assertEqual("success", parent_conn.recv())
def test_primitive_thread(self):
p = self.create_pool(max_pool_size=10)
# Test that start/end_request work with a thread begun from thread
# module, rather than threading module
lock = thread.allocate_lock()
lock.acquire()
sock_ids = []
def run_in_request():
p.start_request()
sock0 = p.get_socket({}, 0, 0)
sock1 = p.get_socket({}, 0, 0)
sock_ids.extend([id(sock0), id(sock1)])
p.maybe_return_socket(sock0)
p.maybe_return_socket(sock1)
p.end_request()
lock.release()
thread.start_new_thread(run_in_request, ())
# Join thread
acquired = False
for i in range(30):
time.sleep(0.5)
acquired = lock.acquire(0)
if acquired:
break
self.assertTrue(acquired, "Thread is hung")
self.assertEqual(sock_ids[0], sock_ids[1])
def test_pool_with_fork(self):
# Test that separate MongoClients have separate Pools, and that the
# driver can create a new MongoClient after forking
@ -770,204 +339,61 @@ class TestPooling(_TestPoolingBase):
class TestPoolMaxSize(_TestPoolingBase):
"""Keep right number of sockets in various start/end_request sequences.
"""
def _test_max_pool_size(
self, start_request, end_request, max_pool_size=4, nthreads=10):
"""Start `nthreads` threads. Each calls start_request `start_request`
times, then find_one and waits at a barrier; once all reach the barrier
each calls end_request `end_request` times. The test asserts that the
pool ends with min(max_pool_size, nthreads) sockets or, if
start_request wasn't called, at least one socket.
This tests both max_pool_size enforcement and that leaked request
sockets are eventually returned to the pool when their threads end.
You may need to increase ulimit -n on Mac.
"""
if start_request:
if max_pool_size is not None and max_pool_size < nthreads:
raise AssertionError("Deadlock")
c = rs_or_single_client(max_pool_size=max_pool_size)
rendezvous = CreateAndReleaseSocket.Rendezvous(nthreads)
threads = []
for i in range(nthreads):
t = CreateAndReleaseSocket(
c, start_request, end_request, rendezvous)
threads.append(t)
for t in threads:
t.start()
if 'PyPy' in sys.version:
# With PyPy we need to kick off the gc whenever the threads hit the
# rendezvous since nthreads > max_pool_size.
gc_collect_until_done(threads)
else:
for t in threads:
t.join()
# join() returns before the thread state is cleared; give it time.
time.sleep(1)
for t in threads:
self.assertTrue(t.passed)
# Socket-reclamation doesn't work in Jython
if not sys.platform.startswith('java'):
cx_pool = get_pool(c)
# Socket-reclamation depends on timely garbage-collection
if 'PyPy' in sys.version:
gc.collect()
if start_request:
# Trigger final cleanup in Python <= 2.7.0.
cx_pool._ident.get()
expected_idle = min(max_pool_size, nthreads)
message = (
'%d idle sockets (expected %d) and %d request sockets'
' (expected 0)' % (
len(cx_pool.sockets), expected_idle,
len(cx_pool._tid_to_sock)))
self.assertEqual(
expected_idle, len(cx_pool.sockets), message)
else:
# Without calling start_request(), threads can safely share
# sockets; the number running concurrently, and hence the
# number of sockets needed, is between 1 and 10, depending
# on thread-scheduling.
self.assertTrue(len(cx_pool.sockets) >= 1)
# thread.join completes slightly *before* thread locals are
# cleaned up, so wait up to 5 seconds for them.
time.sleep(0.1)
cx_pool._ident.get()
start = time.time()
while (
not cx_pool.sockets
and cx_pool._socket_semaphore.counter < max_pool_size
and (time.time() - start) < 5
):
time.sleep(0.1)
cx_pool._ident.get()
if max_pool_size is not None:
self.assertEqual(
max_pool_size,
cx_pool._socket_semaphore.counter)
self.assertEqual(0, len(cx_pool._tid_to_sock))
def _test_max_pool_size_no_rendezvous(self, start_request, end_request):
max_pool_size = 5
c = rs_or_single_client(max_pool_size=max_pool_size)
def test_max_pool_size(self):
max_pool_size = 4
c = connected(rs_or_single_client(max_pool_size=max_pool_size))
cx_pool = get_pool(c)
# nthreads had better be much larger than max_pool_size to ensure that
# max_pool_size sockets are actually required at some point in this
# test's execution.
nthreads = 10
if (sys.platform.startswith('java')
and start_request > end_request
and nthreads > max_pool_size):
# Since Jython can't reclaim the socket and release the semaphore
# after a thread leaks a request, we'll exhaust the semaphore and
# deadlock.
raise SkipTest("Jython can't do socket reclamation")
threads = []
for i in range(nthreads):
t = CreateAndReleaseSocketNoRendezvous(
c, start_request, end_request)
threads.append(t)
lock = threading.Lock()
self.n_passed = 0
for t in threads:
def f():
for _ in range(N):
c[DB].test.find_one()
assert len(cx_pool.sockets) <= max_pool_size
with lock:
self.n_passed += 1
for i in range(nthreads):
t = threading.Thread(target=f)
threads.append(t)
t.start()
if 'PyPy' in sys.version:
# With PyPy we need to kick off the gc whenever the threads hit the
# rendezvous since nthreads > max_pool_size.
gc_collect_until_done(threads)
else:
for t in threads:
t.join()
for t in threads:
self.assertTrue(t.passed)
cx_pool = get_pool(c)
# Socket-reclamation depends on timely garbage-collection
if 'PyPy' in sys.version:
gc.collect()
# thread.join completes slightly *before* thread locals are
# cleaned up, so wait up to 5 seconds for them.
time.sleep(0.1)
cx_pool._ident.get()
start = time.time()
while (
not cx_pool.sockets
and cx_pool._socket_semaphore.counter < max_pool_size
and (time.time() - start) < 5
):
time.sleep(0.1)
cx_pool._ident.get()
self.assertTrue(len(cx_pool.sockets) >= 1)
joinall(threads)
self.assertEqual(nthreads, self.n_passed)
self.assertTrue(len(cx_pool.sockets) > 1)
self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter)
def test_max_pool_size(self):
self._test_max_pool_size(
start_request=0, end_request=0, nthreads=10, max_pool_size=4)
def test_max_pool_size_none(self):
self._test_max_pool_size(
start_request=0, end_request=0, nthreads=10, max_pool_size=None)
c = connected(rs_or_single_client(max_pool_size=None))
cx_pool = get_pool(c)
def test_max_pool_size_with_request(self):
self._test_max_pool_size(
start_request=1, end_request=1, nthreads=10, max_pool_size=10)
nthreads = 10
threads = []
lock = threading.Lock()
self.n_passed = 0
def test_max_pool_size_with_multiple_request(self):
self._test_max_pool_size(
start_request=10, end_request=10, nthreads=10, max_pool_size=10)
def f():
for _ in range(N):
c[DB].test.find_one()
def test_max_pool_size_with_redundant_request(self):
self._test_max_pool_size(
start_request=2, end_request=1, nthreads=10, max_pool_size=10)
with lock:
self.n_passed += 1
def test_max_pool_size_with_redundant_request2(self):
self._test_max_pool_size(
start_request=20, end_request=1, nthreads=10, max_pool_size=10)
for i in range(nthreads):
t = threading.Thread(target=f)
threads.append(t)
t.start()
def test_max_pool_size_with_redundant_request_no_rendezvous(self):
self._test_max_pool_size_no_rendezvous(2, 1)
def test_max_pool_size_with_redundant_request_no_rendezvous2(self):
self._test_max_pool_size_no_rendezvous(20, 1)
def test_max_pool_size_with_leaked_request(self):
# Call start_request() but not end_request() -- when threads die, they
# should return their request sockets to the pool.
self._test_max_pool_size(
start_request=1, end_request=0, nthreads=10, max_pool_size=10)
def test_max_pool_size_with_leaked_request_no_rendezvous(self):
self._test_max_pool_size_no_rendezvous(1, 0)
def test_max_pool_size_with_end_request_only(self):
# Call end_request() but not start_request()
self._test_max_pool_size(0, 1)
joinall(threads)
self.assertEqual(nthreads, self.n_passed)
self.assertTrue(len(cx_pool.sockets) >= 1)
def test_max_pool_size_with_connection_failure(self):
# The pool acquires its semaphore before attempting to connect; ensure

View File

@ -38,17 +38,19 @@ from test import (client_context,
IntegrationTest,
pair,
port,
SkipTest,
unittest,
db_pwd,
db_user,
MockClientTest)
from test.pymongo_mocks import MockClient
from test.utils import (
delay, assertReadFrom, assertReadFromAll, ignore_deprecations,
read_from_which_host, assertRaisesExactly, TestRequestMixin, get_pools,
connected, wait_until, single_client, rs_or_single_client, one,
rs_client)
from test.utils import (assertRaisesExactly,
connected,
delay,
ignore_deprecations,
one,
rs_client,
single_client,
wait_until)
from test.version import Version
@ -79,7 +81,7 @@ class TestReplicaSetClientBase(IntegrationTest):
if m['stateStr'] == 'SECONDARY')
class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
class TestReplicaSetClient(TestReplicaSetClientBase):
def test_deprecated(self):
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
@ -288,88 +290,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
def test_kill_cursor_explicit_secondary(self):
self._test_kill_cursor_explicit(ReadPreference.SECONDARY)
def test_nested_request(self):
client = rs_or_single_client()
connected(client)
client.start_request()
try:
pools = get_pools(client)
self.assertTrue(client.in_request())
# Start and end request - we're still in "outer" original request
client.start_request()
self.assertInRequestAndSameSock(client, pools)
client.end_request()
self.assertInRequestAndSameSock(client, pools)
# Double-nesting
client.start_request()
client.start_request()
for pool in pools:
# Client only called start_request() once per pool, although
# its own counter is 3.
self.assertEqual(1, pool._request_counter.get())
client.end_request()
client.end_request()
self.assertInRequestAndSameSock(client, pools)
for pool in pools:
self.assertEqual(1, pool._request_counter.get())
# Finally, end original request
client.end_request()
for pool in pools:
self.assertFalse(pool.in_request())
self.assertNotInRequestAndDifferentSock(client, pools)
finally:
client.close()
def test_pinned_member(self):
raise SkipTest("Secondary pinning not implemented in PyMongo 3")
latency = 1000 * 1000
client = rs_client(secondaryacceptablelatencyms=latency)
host = read_from_which_host(client, ReadPreference.SECONDARY)
self.assertTrue(host in client.secondaries)
# No pinning since we're not in a request
assertReadFromAll(
self, client, client.secondaries,
ReadPreference.SECONDARY, None)
assertReadFromAll(
self, client, list(client.secondaries) + [client.primary],
ReadPreference.NEAREST, None)
client.start_request()
host = read_from_which_host(client, ReadPreference.SECONDARY)
self.assertTrue(host in client.secondaries)
assertReadFrom(self, client, host, ReadPreference.SECONDARY)
# Changing any part of read preference (mode, tag_sets)
# unpins the current host and pins to a new one
primary = client.primary
assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED)
host = read_from_which_host(client, ReadPreference.NEAREST)
assertReadFrom(self, client, host, ReadPreference.NEAREST)
assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED)
host = read_from_which_host(client, ReadPreference.SECONDARY_PREFERRED)
self.assertTrue(host in client.secondaries)
assertReadFrom(self, client, host, ReadPreference.SECONDARY_PREFERRED)
# Unpin
client.end_request()
assertReadFromAll(
self, client, list(client.secondaries) + [client.primary],
ReadPreference.NEAREST, None)
def test_not_master_error(self):
secondary_address = one(self.secondaries)
direct_client = single_client(*secondary_address)

View File

@ -1,158 +0,0 @@
# Copyright 2012-2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the thread_util module."""
import gc
import sys
import threading
import time
from functools import partial
sys.path[0:0] = [""]
from pymongo import thread_util
from test import SkipTest, unittest
from test.utils import RendezvousThread
class TestIdent(unittest.TestCase):
def test_thread_ident(self):
# 1. Store main thread's id.
# 2. Start 2 child threads.
# 3. Store their values for Ident.get().
# 4. Children reach rendezvous point.
# 5. Children call Ident.watch().
# 6. One of the children calls Ident.unwatch().
# 7. Children terminate.
# 8. Assert that children got different ids from each other and from
# main, and assert watched child's callback was executed, and that
# unwatched child's callback was not.
if 'java' in sys.platform:
raise SkipTest("Can't rely on weakref callbacks in Jython")
ident = thread_util.ThreadIdent()
ids = set([ident.get()])
unwatched_id = []
done = set([ident.get()]) # Start with main thread's id.
died = set()
class WatchedThread(RendezvousThread):
def __init__(self, ident, state):
super(WatchedThread, self).__init__(state)
self._my_ident = ident
def before_rendezvous(self):
self.my_id = self._my_ident.get()
ids.add(self.my_id)
def after_rendezvous(self):
assert not self._my_ident.watching()
self._my_ident.watch(lambda ref: died.add(self.my_id))
assert self._my_ident.watching()
done.add(self.my_id)
class UnwatchedThread(WatchedThread):
def before_rendezvous(self):
super(UnwatchedThread, self).before_rendezvous()
unwatched_id.append(self.my_id)
def after_rendezvous(self):
super(UnwatchedThread, self).after_rendezvous()
self._my_ident.unwatch(self.my_id)
assert not self._my_ident.watching()
state = RendezvousThread.create_shared_state(2)
t_watched = WatchedThread(ident, state)
t_watched.start()
t_unwatched = UnwatchedThread(ident, state)
t_unwatched.start()
RendezvousThread.wait_for_rendezvous(state)
RendezvousThread.resume_after_rendezvous(state)
t_watched.join()
t_unwatched.join()
self.assertTrue(t_watched.passed)
self.assertTrue(t_unwatched.passed)
# Remove references, let weakref callbacks run
del t_watched
del t_unwatched
# Trigger final cleanup in Python <= 2.7.0.
# http://bugs.python.org/issue1868
ident.get()
self.assertEqual(3, len(ids))
self.assertEqual(3, len(done))
# Make sure thread is really gone
slept = 0
while not died and slept < 10:
time.sleep(1)
gc.collect()
slept += 1
self.assertEqual(1, len(died))
self.assertFalse(unwatched_id[0] in died)
class TestCounter(unittest.TestCase):
def test_counter(self):
counter = thread_util.Counter()
self.assertEqual(0, counter.dec())
self.assertEqual(0, counter.get())
self.assertEqual(0, counter.dec())
self.assertEqual(0, counter.get())
done = set()
def f(n):
for i in range(n):
self.assertEqual(i, counter.get())
self.assertEqual(i + 1, counter.inc())
for i in range(n, 0, -1):
self.assertEqual(i, counter.get())
self.assertEqual(i - 1, counter.dec())
self.assertEqual(0, counter.get())
# Extra decrements have no effect
self.assertEqual(0, counter.dec())
self.assertEqual(0, counter.get())
self.assertEqual(0, counter.dec())
self.assertEqual(0, counter.get())
done.add(n)
threads = [threading.Thread(target=partial(f, i))
for i in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(10, len(done))
if __name__ == "__main__":
unittest.main()

View File

@ -17,14 +17,9 @@
import threading
from test import unittest, client_context, IntegrationTest, db_user, db_pwd
from test.utils import (frequent_thread_switches,
joinall,
RendezvousThread,
rs_or_single_client,
rs_or_single_client_noauth)
from test.utils import get_pool
from pymongo.pool import SocketInfo, _closed
from pymongo.errors import AutoReconnect, OperationFailure
from test.utils import rs_or_single_client_noauth
from test.utils import frequent_thread_switches, joinall
from pymongo.errors import OperationFailure
@client_context.require_connection
@ -130,57 +125,6 @@ class Disconnect(threading.Thread):
self.passed = True
class IgnoreAutoReconnect(threading.Thread):
def __init__(self, collection, n):
threading.Thread.__init__(self)
self.c = collection
self.n = n
self.setDaemon(True)
def run(self):
for _ in range(self.n):
try:
self.c.find_one()
except AutoReconnect:
pass
class FindPauseFind(RendezvousThread):
"""See test_server_disconnect() for details"""
def __init__(self, collection, state):
"""Params:
`collection`: A collection for testing
`state`: A shared state object from RendezvousThread.shared_state()
"""
super(FindPauseFind, self).__init__(state)
self.collection = collection
def before_rendezvous(self):
# acquire a socket
client = self.collection.database.connection
client.start_request()
list(self.collection.find())
pool = get_pool(client)
socket_info = pool._get_request_state()
assert isinstance(socket_info, SocketInfo)
self.request_sock = socket_info.sock
assert not _closed(self.request_sock)
def after_rendezvous(self):
# test_server_disconnect() has closed this socket, but that's ok
# because it's not our request socket anymore
assert _closed(self.request_sock)
# if disconnect() properly replaced the pool, then this won't raise
# AutoReconnect because it will acquire a new socket
list(self.collection.find())
assert self.collection.database.connection.in_request()
pool = get_pool(self.collection.database.connection)
assert self.request_sock != pool._get_request_state().sock
class TestThreads(IntegrationTest):
def setUp(self):
self.db = client_context.rs_or_standalone_client.pymongo_test
@ -236,71 +180,6 @@ class TestThreads(IntegrationTest):
error.join()
okay.join()
def test_server_disconnect(self):
# PYTHON-345, we need to make sure that threads' request sockets are
# closed by disconnect().
#
# 1. Create a client and start a request.
# 2. Start N threads and do a find() in each to get a request socket
# 3. Pause all threads
# 4. In the main thread close all sockets, including threads' request
# sockets
# 5. In main thread, do a find(), which raises AutoReconnect and resets
# pool
# 6. Resume all threads, do a find() in them
#
# If we've fixed PYTHON-345, then only one AutoReconnect is raised,
# and all the threads get new request sockets.
cx = rs_or_single_client()
cx.start_request()
collection = cx.db.pymongo_test
# acquire a request socket for the main thread
collection.find_one()
pool = get_pool(collection.database.connection)
socket_info = pool._get_request_state()
assert isinstance(socket_info, SocketInfo)
request_sock = socket_info.sock
state = FindPauseFind.create_shared_state(nthreads=10)
threads = [
FindPauseFind(collection, state)
for _ in range(state.nthreads)
]
# Each thread does a find(), thus acquiring a request socket
for t in threads:
t.start()
# Wait for the threads to reach the rendezvous
FindPauseFind.wait_for_rendezvous(state)
try:
# Simulate an event that closes all sockets, e.g. primary stepdown
for t in threads:
t.request_sock.close()
# Finally, ensure the main thread's socket's last_checkout is
# updated:
collection.find_one()
# ... and close it:
request_sock.close()
# Doing an operation on the client raises an AutoReconnect and
# resets the pool behind the scenes
self.assertRaises(AutoReconnect, collection.find_one)
finally:
# Let threads do a second find()
FindPauseFind.resume_after_rendezvous(state)
joinall(threads)
for t in threads:
self.assertTrue(t.passed, "%s threw exception" % t)
def test_client_disconnect(self):
self.db.drop_collection("test")
for i in range(1000):

View File

@ -26,7 +26,6 @@ from functools import partial
from pymongo import MongoClient
from pymongo.errors import AutoReconnect, OperationFailure
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
from pymongo.server_selectors import (any_server_selector,
writable_server_selector)
from test import (client_context,
@ -280,104 +279,6 @@ def ignore_deprecations():
yield
class RendezvousThread(threading.Thread):
"""A thread that starts and pauses at a rendezvous point before resuming.
To be used in tests that must ensure that N threads are all alive
simultaneously, regardless of thread-scheduling's vagaries.
1. Write a subclass of RendezvousThread and override before_rendezvous
and / or after_rendezvous.
2. Create a state with RendezvousThread.shared_state(N)
3. Start N of your subclassed RendezvousThreads, passing the state to each
one's __init__
4. In the main thread, call RendezvousThread.wait_for_rendezvous
5. Test whatever you need to test while threads are paused at rendezvous
point
6. In main thread, call RendezvousThread.resume_after_rendezvous
7. Join all threads from main thread
8. Assert that all threads' "passed" attribute is True
9. Test post-conditions
"""
class RendezvousState(object):
def __init__(self, nthreads):
# Number of threads total
self.nthreads = nthreads
# Number of threads that have arrived at rendezvous point
self.arrived_threads = 0
self.arrived_threads_lock = threading.Lock()
# Set when all threads reach rendezvous
self.ev_arrived = threading.Event()
# Set by resume_after_rendezvous() so threads can continue.
self.ev_resume = threading.Event()
@classmethod
def create_shared_state(cls, nthreads):
return RendezvousThread.RendezvousState(nthreads)
def before_rendezvous(self):
"""Overridable: Do this before the rendezvous"""
pass
def after_rendezvous(self):
"""Overridable: Do this after the rendezvous. If it throws no exception,
`passed` is set to True
"""
pass
@classmethod
def wait_for_rendezvous(cls, state):
"""Wait for all threads to reach rendezvous and pause there"""
state.ev_arrived.wait(10)
assert state.ev_arrived.isSet(), "Thread timeout"
assert state.nthreads == state.arrived_threads
@classmethod
def resume_after_rendezvous(cls, state):
"""Tell all the paused threads to continue"""
state.ev_resume.set()
def __init__(self, state):
"""Params:
`state`: A shared state object from RendezvousThread.shared_state()
"""
super(RendezvousThread, self).__init__()
self.state = state
self.passed = False
# If this thread fails to terminate, don't hang the whole program
self.setDaemon(True)
def _rendezvous(self):
"""Pause until all threads arrive here"""
s = self.state
s.arrived_threads_lock.acquire()
s.arrived_threads += 1
if s.arrived_threads == s.nthreads:
s.arrived_threads_lock.release()
s.ev_arrived.set()
else:
s.arrived_threads_lock.release()
s.ev_arrived.wait()
def run(self):
try:
self.before_rendezvous()
finally:
self._rendezvous()
# all threads have passed the rendezvous, wait for
# resume_after_rendezvous()
self.state.ev_resume.wait()
self.after_rendezvous()
self.passed = True
def read_from_which_host(
client,
pref,
@ -467,51 +368,6 @@ def get_pools(client):
client._get_topology().select_servers(any_server_selector)]
class TestRequestMixin(object):
"""Inherit from this class and from unittest.TestCase to get some
convenient methods for testing connection pools and requests
"""
def assertSameSock(self, pool):
sock_info0 = pool.get_socket({}, 0, 0)
sock_info1 = pool.get_socket({}, 0, 0)
self.assertEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)
def assertDifferentSock(self, pool):
sock_info0 = pool.get_socket({}, 0, 0)
sock_info1 = pool.get_socket({}, 0, 0)
self.assertNotEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)
def assertNoRequest(self, pool):
self.assertEqual(NO_REQUEST, pool._get_request_state())
def assertNoSocketYet(self, pool):
self.assertEqual(NO_SOCKET_YET, pool._get_request_state())
def assertRequestSocket(self, pool):
self.assertTrue(isinstance(pool._get_request_state(), SocketInfo))
def assertInRequestAndSameSock(self, client, pools):
self.assertTrue(client.in_request())
if not isinstance(pools, list):
pools = [pools]
for pool in pools:
self.assertTrue(pool.in_request())
self.assertSameSock(pool)
def assertNotInRequestAndDifferentSock(self, client, pools):
self.assertFalse(client.in_request())
if not isinstance(pools, list):
pools = [pools]
for pool in pools:
self.assertFalse(pool.in_request())
self.assertDifferentSock(pool)
# Constants for run_threads and lazy_client_trial.
NTRIALS = 5
NTHREADS = 10