From e8be121a892408fc886ec6a78ec25a15c50ea91d Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 20 Nov 2014 16:48:03 -0500 Subject: [PATCH] PYTHON-785 Don't use requests in tests. --- test/test_auth.py | 33 +- test/test_bulk.py | 28 +- test/test_client.py | 104 +---- test/test_collection.py | 241 +++++------- test/test_database.py | 73 ++-- test/test_gridfs.py | 22 -- test/test_pooling.py | 676 +++----------------------------- test/test_replica_set_client.py | 98 +---- test/test_thread_util.py | 158 -------- test/test_threads.py | 127 +----- test/utils.py | 144 ------- 11 files changed, 232 insertions(+), 1472 deletions(-) delete mode 100644 test/test_thread_util.py diff --git a/test/test_auth.py b/test/test_auth.py index af0747bec..d965e1f37 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -31,6 +31,7 @@ from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple from pymongo.errors import OperationFailure from pymongo.read_preferences import ReadPreference from test import client_context, host, port, SkipTest, unittest, Version +from test.utils import delay # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS. GSSAPI_HOST = os.environ.get('GSSAPI_HOST') @@ -46,18 +47,22 @@ SASL_DB = os.environ.get('SASL_DB', '$external') class AutoAuthenticateThread(threading.Thread): """Used in testing threaded authentication. + + This does collection.find_one() with a 1-second delay to ensure it must + check out and authenticate multiple sockets from the pool concurrently. + + :Parameters: + `collection`: An auth-protected collection containing one document. """ - def __init__(self, database): + def __init__(self, collection): super(AutoAuthenticateThread, self).__init__() - self.database = database - self.success = True + self.collection = collection + self.success = False def run(self): - try: - self.database.command('dbstats') - except OperationFailure: - self.success = False + assert self.collection.find_one({'$where': delay(1)}) is not None + self.success = True class TestGSSAPI(unittest.TestCase): @@ -148,14 +153,20 @@ class TestGSSAPI(unittest.TestCase): def test_gssapi_threaded(self): client = MongoClient(GSSAPI_HOST) - # Make sure each thread uses a different socket. - client.start_request() self.assertTrue(client.test.authenticate(PRINCIPAL, mechanism='GSSAPI')) + # Need one document in the collection. AutoAuthenticateThread does + # collection.find_one with a 1-second delay, forcing it to check out + # multiple sockets from the pool concurrently, proving that + # auto-authentication works with GSSAPI. + collection = client.test.collection + collection.remove() + collection.insert({'_id': 1}) + threads = [] for _ in range(4): - threads.append(AutoAuthenticateThread(client.test)) + threads.append(AutoAuthenticateThread(collection)) for thread in threads: thread.start() for thread in threads: @@ -174,7 +185,7 @@ class TestGSSAPI(unittest.TestCase): threads = [] for _ in range(4): - threads.append(AutoAuthenticateThread(client.foo)) + threads.append(AutoAuthenticateThread(collection)) for thread in threads: thread.start() for thread in threads: diff --git a/test/test_bulk.py b/test/test_bulk.py index 8c896ac3c..b29f4285b 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -29,7 +29,7 @@ from test import (client_context, port, IntegrationTest, SkipTest) -from test.utils import oid_generated_on_client, remove_all_users +from test.utils import oid_generated_on_client, remove_all_users, wait_until class BulkTestBase(IntegrationTest): @@ -1115,9 +1115,9 @@ class TestBulkNoResults(BulkTestBase): batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() - with self.client.start_request(): - self.assertTrue(batch.execute({'w': 0}) is None) - self.assertEqual(2, self.coll.count()) + self.assertTrue(batch.execute({'w': 0}) is None) + wait_until(lambda: 2 == self.coll.count(), + 'insert 2 documents') def test_no_results_ordered_failure(self): @@ -1127,9 +1127,9 @@ class TestBulkNoResults(BulkTestBase): batch.insert({'_id': 2}) batch.insert({'_id': 1}) batch.find({'_id': 1}).remove_one() - with self.client.start_request(): - self.assertTrue(batch.execute({'w': 0}) is None) - self.assertEqual(3, self.coll.count()) + self.assertTrue(batch.execute({'w': 0}) is None) + wait_until(lambda: 3 == self.coll.count(), + 'insert 3 documents') def test_no_results_unordered_success(self): @@ -1138,9 +1138,9 @@ class TestBulkNoResults(BulkTestBase): batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() - with self.client.start_request(): - self.assertTrue(batch.execute({'w': 0}) is None) - self.assertEqual(2, self.coll.count()) + self.assertTrue(batch.execute({'w': 0}) is None) + wait_until(lambda: 2 == self.coll.count(), + 'insert 2 documents') def test_no_results_unordered_failure(self): @@ -1150,10 +1150,10 @@ class TestBulkNoResults(BulkTestBase): batch.insert({'_id': 2}) batch.insert({'_id': 1}) batch.find({'_id': 1}).remove_one() - with self.client.start_request(): - self.assertTrue(batch.execute({'w': 0}) is None) - self.assertEqual(2, self.coll.count()) - self.assertTrue(self.coll.find_one({'_id': 1}) is None) + self.assertTrue(batch.execute({'w': 0}) is None) + wait_until(lambda: 2 == self.coll.count(), + 'insert 2 documents') + self.assertTrue(self.coll.find_one({'_id': 1}) is None) class TestBulkAuthorization(BulkTestBase): diff --git a/test/test_client.py b/test/test_client.py index 2f3cbeb38..3f4697d04 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -18,7 +18,6 @@ import contextlib import datetime import multiprocessing import os -import threading import socket import sys import time @@ -32,7 +31,6 @@ from bson.son import SON from bson.tz_util import utc from pymongo.mongo_client import MongoClient from pymongo.database import Database -from pymongo.pool import SocketInfo from pymongo import auth, message from pymongo.errors import (AutoReconnect, ConfigurationError, @@ -63,7 +61,6 @@ from test.utils import (assertRaisesExactly, ignore_deprecations, remove_all_users, server_is_master_with_slave, - TestRequestMixin, get_pool, one, connected, @@ -74,7 +71,7 @@ from test.utils import (assertRaisesExactly, NTHREADS) -class ClientUnitTest(unittest.TestCase, TestRequestMixin): +class ClientUnitTest(unittest.TestCase): """MongoClient tests that don't require a server.""" @classmethod @@ -166,7 +163,7 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin): self.assertEqual(Database(c, 'foo'), c.get_default_database()) -class TestClient(IntegrationTest, TestRequestMixin): +class TestClient(IntegrationTest): def test_constants(self): # Set bad defaults. @@ -605,80 +602,6 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) self.assertEqual(0, len(get_pool(client).sockets)) - def test_with_start_request(self): - pool = get_pool(self.client) - - # No request started - self.assertNoRequest(pool) - self.assertDifferentSock(pool) - - # Start a request - request_context_mgr = self.client.start_request() - self.assertTrue( - isinstance(request_context_mgr, object) - ) - - self.assertNoSocketYet(pool) - self.assertSameSock(pool) - self.assertRequestSocket(pool) - - # End request - request_context_mgr.__exit__(None, None, None) - self.assertNoRequest(pool) - self.assertDifferentSock(pool) - - # Test the 'with' statement - with self.client.start_request() as request: - self.assertEqual(self.client, request.connection) - self.assertNoSocketYet(pool) - self.assertSameSock(pool) - self.assertRequestSocket(pool) - - # Request has ended - self.assertNoRequest(pool) - self.assertDifferentSock(pool) - - def test_request_threads(self): - client = self.client - pool = get_pool(client) - self.assertNotInRequestAndDifferentSock(client, pool) - - started_request, ended_request = threading.Event(), threading.Event() - checked_request = threading.Event() - thread_done = [False] - - # Starting a request in one thread doesn't put the other thread in a - # request - def f(): - self.assertNotInRequestAndDifferentSock(client, pool) - client.start_request() - self.assertInRequestAndSameSock(client, pool) - started_request.set() - checked_request.wait() - checked_request.clear() - self.assertInRequestAndSameSock(client, pool) - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pool) - ended_request.set() - checked_request.wait() - thread_done[0] = True - - t = threading.Thread(target=f) - t.setDaemon(True) - t.start() - # It doesn't matter in what order the main thread or t initially get - # to started_request.set() / wait(); by waiting here we ensure that t - # has called client.start_request() before we assert on the next line. - started_request.wait() - self.assertNotInRequestAndDifferentSock(client, pool) - checked_request.set() - ended_request.wait() - self.assertNotInRequestAndDifferentSock(client, pool) - checked_request.set() - t.join() - self.assertNotInRequestAndDifferentSock(client, pool) - self.assertTrue(thread_done[0], "Thread didn't complete") - def test_interrupt_signal(self): if sys.platform.startswith('java'): # We can't figure out how to raise an exception on a thread that's @@ -724,7 +647,7 @@ class TestClient(IntegrationTest, TestRequestMixin): next(db.foo.find()) ) - def test_operation_failure_without_request(self): + def test_operation_failure(self): # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. pool = get_pool(self.client) @@ -741,27 +664,6 @@ class TestClient(IntegrationTest, TestRequestMixin): new_sock_info = next(iter(pool.sockets)) self.assertEqual(old_sock_info, new_sock_info) - def test_operation_failure_with_request(self): - # Ensure MongoClient doesn't close socket after it gets an error - # response to getLastError. PYTHON-395. - c = rs_or_single_client() - c.start_request() - pool = get_pool(c) - - # Pool reserves a socket for this thread. - c.pymongo_test.test.find_one() - self.assertTrue(isinstance(pool._get_request_state(), SocketInfo)) - - old_sock_info = pool._get_request_state() - c.pymongo_test.test.drop() - c.pymongo_test.test.insert({'_id': 'foo'}) - self.assertRaises( - OperationFailure, - c.pymongo_test.test.insert, {'_id': 'foo'}) - - # OperationFailure doesn't affect the request socket - self.assertEqual(old_sock_info, pool._get_request_state()) - def test_alive(self): self.assertTrue(self.client.alive()) diff --git a/test/test_collection.py b/test/test_collection.py index 606a3872d..918efb5ea 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -50,7 +50,7 @@ from pymongo.errors import (DocumentTooLarge, from test.test_client import IntegrationTest from test.utils import (is_mongos, joinall, enable_text_search, get_pool, oid_generated_on_client, one, ignore_deprecations, - rs_or_single_client) + rs_or_single_client, wait_until) from test import client_context, host, port, qcheck, unittest @@ -793,14 +793,15 @@ class TestCollection(IntegrationTest): db.drop_collection("test") db.test.ensure_index([('i', ASCENDING)], unique=True) - with db.connection.start_request(): - # No error - db.test.insert([{'i': i} for i in range(5, 10)], w=0) - db.test.remove() + # No error + db.test.insert([{'i': i} for i in range(5, 10)], w=0) + wait_until(lambda: 5 == db.test.count(), 'insert 5 documents') - # No error - db.test.insert([{'i': 1}] * 2, w=0) - self.assertEqual(1, db.test.count()) + db.test.remove() + + # No error + db.test.insert([{'i': 1}] * 2, w=0) + wait_until(lambda: 1 == db.test.count(), 'insert 1 document') self.assertRaises( DuplicateKeyError, @@ -811,20 +812,19 @@ class TestCollection(IntegrationTest): wc = db.write_concern db.write_concern = {"w": 0} try: - with db.connection.start_request(): - db.test.ensure_index([('i', ASCENDING)], unique=True) + db.test.ensure_index([('i', ASCENDING)], unique=True) - # No error - db.test.insert([{'i': 1}] * 2) - self.assertEqual(1, db.test.count()) + # No error. + db.test.insert([{'i': 1}] * 2) + wait_until(lambda: 1 == db.test.count(), 'insert 1 document') - # Implied safe + # Implied acknowledged. self.assertRaises( DuplicateKeyError, lambda: db.test.insert([{'i': 2}] * 2, fsync=True), ) - # Explicit safe + # Explicit acknowledged. self.assertRaises( DuplicateKeyError, lambda: db.test.insert([{'i': 2}] * 2, w=1), @@ -902,7 +902,7 @@ class TestCollection(IntegrationTest): self.db.test.find_one({'_id': 'explicit_id'})['hello'] ) - # Safe mode + # Acknowledged mode. self.db.test.create_index("hello", unique=True) # No exception, even though we duplicate the first doc's "hello" value self.db.test.save({'_id': 'explicit_id', 'hello': 'world'}, w=0) @@ -924,23 +924,19 @@ class TestCollection(IntegrationTest): def test_unique_index(self): db = self.db + db.drop_collection("test") + db.test.create_index("hello") - with self.client.start_request(): - db.drop_collection("test") - db.test.create_index("hello") + # No error. + db.test.save({"hello": "world"}) + db.test.save({"hello": "world"}) + db.drop_collection("test") + db.test.create_index("hello", unique=True) + + with self.assertRaises(DuplicateKeyError): db.test.save({"hello": "world"}) - db.test.save({"hello": "mike"}) db.test.save({"hello": "world"}) - self.assertFalse(db.error()) - - db.drop_collection("test") - db.test.create_index("hello", unique=True) - - db.test.save({"hello": "world"}) - db.test.save({"hello": "mike"}) - db.test.save({"hello": "world"}, w=0) - self.assertTrue(db.error()) def test_duplicate_key_error(self): db = self.db @@ -949,41 +945,17 @@ class TestCollection(IntegrationTest): db.test.create_index("x", unique=True) db.test.insert({"_id": 1, "x": 1}) - db.test.insert({"_id": 2, "x": 2}) - with db.connection.start_request(): - # No error - db.test.insert({"_id": 1, "x": 1}, w=0) - db.test.save({"_id": 1, "x": 1}, w=0) - db.test.insert({"_id": 2, "x": 2}, w=0) - db.test.save({"_id": 2, "x": 2}, w=0) + with self.assertRaises(DuplicateKeyError) as context: + db.test.insert({"x": 1}) - # But all those statements didn't do anything - self.assertEqual(2, db.test.count()) + self.assertIsNotNone(context.exception.details) - expected_error = OperationFailure - if client_context.version.at_least(1, 3): - expected_error = DuplicateKeyError + with self.assertRaises(DuplicateKeyError) as context: + db.test.save({"x": 1}) - self.assertRaises(expected_error, - db.test.insert, {"_id": 1}) - self.assertRaises(expected_error, - db.test.insert, {"x": 1}) - - self.assertRaises(expected_error, - db.test.save, {"x": 2}) - self.assertRaises(expected_error, - db.test.update, {"x": 1}, - {"$inc": {"x": 1}}) - - try: - db.test.insert({"_id": 1}) - except expected_error as exc: - # Just check that we set the error document. Fields - # vary by MongoDB version. - self.assertTrue(exc.details is not None) - else: - self.fail("%s was not raised" % (expected_error.__name__,)) + self.assertIsNotNone(context.exception.details) + self.assertEqual(1, db.test.count()) def test_wtimeout(self): # Ensure setting wtimeout doesn't disable write concern altogether. @@ -1006,33 +978,33 @@ class TestCollection(IntegrationTest): self.assertEqual(1, db.test.count()) docs = [] - docs.append({"_id": oid, "two": 2}) + docs.append({"_id": oid, "two": 2}) # Duplicate _id. docs.append({"three": 3}) docs.append({"four": 4}) docs.append({"five": 5}) - with self.client.start_request(): - db.test.insert(docs, manipulate=False, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(1, db.test.count()) + with self.assertRaises(DuplicateKeyError): + db.test.insert(docs, manipulate=False) - db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(4, db.test.count()) + self.assertEqual(1, db.test.count()) - db.drop_collection("test") - oid = db.test.insert({"_id": oid, "one": 1}, w=0) - self.assertEqual(1, db.test.count()) - docs[0].pop("_id") - docs[2]["_id"] = oid + with self.assertRaises(DuplicateKeyError): + db.test.insert(docs, manipulate=False, continue_on_error=True) - db.test.insert(docs, manipulate=False, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(3, db.test.count()) + self.assertEqual(4, db.test.count()) - db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(6, db.test.count()) + db.drop_collection("test") + oid = db.test.insert({"_id": oid, "one": 1}, w=0) + self.assertEqual(1, db.test.count()) + docs[0].pop("_id") + docs[2]["_id"] = oid + + with self.assertRaises(DuplicateKeyError): + db.test.insert(docs, manipulate=False) + + self.assertEqual(3, db.test.count()) + db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) + self.assertEqual(6, db.test.count()) def test_error_code(self): try: @@ -1062,17 +1034,14 @@ class TestCollection(IntegrationTest): self.assertRaises(DuplicateKeyError, db.test.insert, {"hello": {"a": 4, "b": 10}}) - def test_safe_insert(self): - with self.client.start_request(): - db = self.db - db.drop_collection("test") + def test_acknowledged_insert(self): + db = self.db + db.drop_collection("test") - a = {"hello": "world"} - db.test.insert(a) - db.test.insert(a, w=0) - self.assertTrue("E11000" in db.error()["err"]) - - self.assertRaises(OperationFailure, db.test.insert, a) + a = {"hello": "world"} + db.test.insert(a) + db.test.insert(a, w=0) + self.assertRaises(OperationFailure, db.test.insert, a) def test_update(self): db = self.db @@ -1159,36 +1128,25 @@ class TestCollection(IntegrationTest): self.assertEqual(1, db.test.count()) self.assertEqual(2, db.test.find_one()["count"]) - def test_safe_update(self): - with self.client.start_request(): - db = self.db - v113minus = client_context.version.at_least(1, 1, 3, -1) - v19 = client_context.version.at_least(1, 9) + def test_acknowledged_update(self): + db = self.db + db.drop_collection("test") + db.test.create_index("x", unique=True) - db.drop_collection("test") - db.test.create_index("x", unique=True) + db.test.insert({"x": 5}) + id = db.test.insert({"x": 4}) - db.test.insert({"x": 5}) - id = db.test.insert({"x": 4}) + self.assertEqual( + None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0)) - self.assertEqual( - None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0)) + self.assertRaises(DuplicateKeyError, db.test.update, + {"_id": id}, {"$inc": {"x": 1}}) - if v19: - self.assertTrue("E11000" in db.error()["err"]) - elif v113minus: - self.assertTrue(db.error()["err"].startswith("E11001")) - else: - self.assertTrue(db.error()["err"].startswith("E12011")) + self.assertEqual(1, db.test.update({"_id": id}, + {"$inc": {"x": 2}})["n"]) - self.assertRaises(OperationFailure, db.test.update, - {"_id": id}, {"$inc": {"x": 1}}) - - self.assertEqual(1, db.test.update({"_id": id}, - {"$inc": {"x": 2}})["n"]) - - self.assertEqual(0, db.test.update({"_id": "foo"}, - {"$inc": {"x": 2}})["n"]) + self.assertEqual(0, db.test.update({"_id": "foo"}, + {"$inc": {"x": 2}})["n"]) def test_update_with_invalid_keys(self): self.db.drop_collection("test") @@ -1233,20 +1191,17 @@ class TestCollection(IntegrationTest): self.assertNotEqual(0, self.db.test.update({"hello": "world"}, {})['n']) - def test_safe_save(self): - with self.client.start_request(): - db = self.db - db.drop_collection("test") - db.test.create_index("hello", unique=True) + def test_acknowledged_save(self): + db = self.db + db.drop_collection("test") + db.test.create_index("hello", unique=True) - db.test.save({"hello": "world"}) - db.test.save({"hello": "world"}, w=0) - self.assertTrue("E11000" in db.error()["err"]) + db.test.save({"hello": "world"}) + db.test.save({"hello": "world"}, w=0) + self.assertRaises(DuplicateKeyError, db.test.save, + {"hello": "world"}) - self.assertRaises(OperationFailure, db.test.save, - {"hello": "world"}) - - def test_safe_remove(self): + def test_acknowledged_remove(self): db = self.db db.drop_collection("test") db.create_collection("test", capped=True, size=1000) @@ -1254,16 +1209,8 @@ class TestCollection(IntegrationTest): db.test.insert({"x": 1}) self.assertEqual(1, db.test.count()) - with db.connection.start_request(): - self.assertEqual(None, db.test.remove({"x": 1}, w=0)) - self.assertEqual(1, db.test.count()) - - if client_context.version.at_least(1, 1, 3, -1): - self.assertRaises(OperationFailure, db.test.remove, - {"x": 1}) - else: # Just test that it doesn't blow up - db.test.remove({"x": 1}) - + # Can't remove from capped collection. + self.assertRaises(OperationFailure, db.test.remove, {"x": 1}) db.drop_collection("test") db.test.insert({"x": 1}) db.test.insert({"x": 1}) @@ -1836,16 +1783,16 @@ class TestCollection(IntegrationTest): batch[1]['_id'] = batch[0]['_id'] - # Test that inserts fail after first error, acknowledged. + # Test that inserts fail after first error. self.db.test.drop() - self.assertRaises(DuplicateKeyError, self.db.test.insert, batch, w=1) + self.assertRaises(DuplicateKeyError, self.db.test.insert, batch) self.assertEqual(1, self.db.test.count()) - # Test that inserts fail after first error, unacknowledged. + # 2 batches, 2 errors, continue on error. self.db.test.drop() - with self.client.start_request(): - self.assertTrue(self.db.test.insert(batch, w=0)) - self.assertEqual(1, self.db.test.count()) + self.assertTrue(self.db.test.insert(batch, w=0)) + wait_until(lambda: 1 == self.db.test.count(), + 'insert 1 document') # 2 batches, 2 errors, acknowledged, continue on error self.db.test.drop() @@ -1862,11 +1809,11 @@ class TestCollection(IntegrationTest): # 2 batches, 2 errors, unacknowledged, continue on error self.db.test.drop() - with self.client.start_request(): - self.assertTrue( - self.db.test.insert(batch, continue_on_error=True, w=0)) - # Only the first and third documents should be inserted. - self.assertEqual(2, self.db.test.count()) + self.assertTrue( + self.db.test.insert(batch, continue_on_error=True, w=0)) + # Only the first and third documents should be inserted. + wait_until(lambda: 2 == self.db.test.count(), + 'insert 2 documents') def test_numerous_inserts(self): # Ensure we don't exceed server's 1000-document batch size limit. diff --git a/test/test_database.py b/test/test_database.py index 6039fde18..b9617fc53 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -57,6 +57,7 @@ from test.utils import ( ignore_deprecations, remove_all_users, rs_or_single_client_noauth, + rs_or_single_client, server_started_with_auth) @@ -285,38 +286,37 @@ class TestDatabase(IntegrationTest): @client_context.require_no_mongos def test_errors(self): - db = self.client.pymongo_test + # We must call getlasterror, etc. on same socket as the last operation. + db = rs_or_single_client(max_pool_size=1).pymongo_test + db.reset_error_history() + self.assertEqual(None, db.error()) + self.assertEqual(None, db.previous_error()) - with self.client.start_request(): - db.reset_error_history() - self.assertEqual(None, db.error()) - self.assertEqual(None, db.previous_error()) + db.command("forceerror", check=False) + self.assertTrue(db.error()) + self.assertTrue(db.previous_error()) - db.command("forceerror", check=False) - self.assertTrue(db.error()) - self.assertTrue(db.previous_error()) + db.command("forceerror", check=False) + self.assertTrue(db.error()) + prev_error = db.previous_error() + self.assertEqual(prev_error["nPrev"], 1) + del prev_error["nPrev"] + prev_error.pop("lastOp", None) + error = db.error() + error.pop("lastOp", None) + # getLastError includes "connectionId" in recent + # server versions, getPrevError does not. + error.pop("connectionId", None) + self.assertEqual(error, prev_error) - db.command("forceerror", check=False) - self.assertTrue(db.error()) - prev_error = db.previous_error() - self.assertEqual(prev_error["nPrev"], 1) - del prev_error["nPrev"] - prev_error.pop("lastOp", None) - error = db.error() - error.pop("lastOp", None) - # getLastError includes "connectionId" in recent - # server versions, getPrevError does not. - error.pop("connectionId", None) - self.assertEqual(error, prev_error) + db.test.find_one() + self.assertEqual(None, db.error()) + self.assertTrue(db.previous_error()) + self.assertEqual(db.previous_error()["nPrev"], 2) - db.test.find_one() - self.assertEqual(None, db.error()) - self.assertTrue(db.previous_error()) - self.assertEqual(db.previous_error()["nPrev"], 2) - - db.reset_error_history() - self.assertEqual(None, db.error()) - self.assertEqual(None, db.previous_error()) + db.reset_error_history() + self.assertEqual(None, db.error()) + self.assertEqual(None, db.previous_error()) def test_command(self): db = self.client.admin @@ -338,17 +338,16 @@ class TestDatabase(IntegrationTest): self.assertTrue(isinstance(result['result'][0]['r'], Regex)) def test_last_status(self): - db = self.client.pymongo_test + # We must call getlasterror on the same socket as the last operation. + db = rs_or_single_client(max_pool_size=1).pymongo_test + db.test.remove({}) + db.test.save({"i": 1}) - with self.client.start_request(): - db.test.remove({}) - db.test.save({"i": 1}) + db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0) + self.assertTrue(db.last_status()["updatedExisting"]) - db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0) - self.assertTrue(db.last_status()["updatedExisting"]) - - db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0) - self.assertFalse(db.last_status()["updatedExisting"]) + db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0) + self.assertFalse(db.last_status()["updatedExisting"]) def test_password_digest(self): self.assertRaises(TypeError, auth._password_digest, 5) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 18d4a4d11..b5d455bb3 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -369,28 +369,6 @@ class TestGridfs(IntegrationTest): self.assertTrue(iterate_file(f)) - def test_request(self): - c = self.db.connection - c.start_request() - n = 5 - for i in range(n): - file = self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - c.end_request() - - self.assertEqual( - n, - self.db.fs.files.find({'filename':'test'}).count() - ) - - def test_gridfs_request(self): - self.assertFalse(self.db.connection.in_request()) - self.fs.put(b"hello world") - # Request started and ended by put(), we're back to original state - self.assertFalse(self.db.connection.in_request()) - def test_gridfs_lazy_connect(self): with client_knobs(server_wait_time=0.01): client = MongoClient('badhost', connect=False) diff --git a/test/test_pooling.py b/test/test_pooling.py index edb1e56a2..46f220b83 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -22,20 +22,17 @@ import threading import time from pymongo import MongoClient -from pymongo.errors import ConfigurationError, ConnectionFailure, \ - ExceededMaxWaiters +from pymongo.errors import (ConfigurationError, + ConnectionFailure, + DuplicateKeyError, + ExceededMaxWaiters) from pymongo.server_selectors import writable_server_selector sys.path[0:0] = [""] -from bson.py3compat import thread -from pymongo.pool import (NO_REQUEST, - NO_SOCKET_YET, - Pool, - PoolOptions, - SocketInfo, _closed) +from pymongo.pool import Pool, PoolOptions, _closed from test import host, port, SkipTest, unittest, client_context -from test.utils import get_pool, one, rs_or_single_client +from test.utils import get_pool, one, rs_or_single_client, connected, joinall @client_context.require_connection @@ -86,18 +83,18 @@ class SaveAndFind(MongoThread): class Unique(MongoThread): def run_mongo_thread(self): for _ in range(N): - self.client.start_request() self.db.unique.insert({}) # no error - self.client.end_request() class NonUnique(MongoThread): def run_mongo_thread(self): for _ in range(N): - self.client.start_request() - self.db.unique.insert({"_id": "jesse"}, w=0) - assert self.db.error() is not None - self.client.end_request() + try: + self.db.unique.insert({"_id": "jesse"}) + except DuplicateKeyError: + pass + else: + raise AssertionError("Should have raised DuplicateKeyError") class Disconnect(MongoThread): @@ -106,19 +103,6 @@ class Disconnect(MongoThread): self.client.disconnect() -class NoRequest(MongoThread): - def run_mongo_thread(self): - self.client.start_request() - errors = 0 - for _ in range(N): - self.db.unique.insert({"_id": "jesse"}, w=0) - if not self.db.error(): - errors += 1 - - self.client.end_request() - assert errors == 0 - - class SocketGetter(MongoThread): """Utility for _TestMaxOpenSockets and _TestWaitQueueMultiple""" def __init__(self, client, pool): @@ -150,73 +134,6 @@ def run_cases(client, cases): assert t.passed, "%s.run() threw an exception" % repr(t) -class CreateAndReleaseSocket(MongoThread): - """Gets a socket, waits for all threads to reach rendezvous, and quits.""" - class Rendezvous(object): - def __init__(self, nthreads): - self.nthreads = nthreads - self.nthreads_run = 0 - self.lock = threading.Lock() - self.reset_ready() - - def reset_ready(self): - self.ready = threading.Event() - - def __init__(self, client, start_request, end_request, rendezvous): - super(CreateAndReleaseSocket, self).__init__(client) - self.start_request = start_request - self.end_request = end_request - self.rendezvous = rendezvous - - def run_mongo_thread(self): - # Do an operation that requires a socket. - # TestPoolMaxSize uses this to spin up lots of threads requiring - # lots of simultaneous connections, to ensure that Pool obeys its - # max_size configuration and closes extra sockets as they're returned. - for i in range(self.start_request): - self.client.start_request() - - # Use a socket - self.client[DB].test.find_one() - - # Don't finish until all threads reach this point - r = self.rendezvous - r.lock.acquire() - r.nthreads_run += 1 - if r.nthreads_run == r.nthreads: - # Everyone's here, let them finish. - r.ready.set() - r.lock.release() - else: - r.lock.release() - r.ready.wait(30) # Wait thirty seconds.... - assert r.ready.isSet(), "Rendezvous timed out" - - for i in range(self.end_request): - self.client.end_request() - - -class CreateAndReleaseSocketNoRendezvous(MongoThread): - """Gets a socket and quits. No synchronization with other threads.""" - def __init__(self, client, start_request, end_request): - super(CreateAndReleaseSocketNoRendezvous, self).__init__(client) - self.start_request = start_request - self.end_request = end_request - - def run_mongo_thread(self): - # Do an operation that requires a socket. - # TestPoolMaxSize uses this to spin up lots of threads requiring - # lots of simultaneous connections, to ensure that Pool obeys its - # max_size configuration and closes extra sockets as they're returned. - for i in range(self.start_request): - self.client.start_request() - - # Use a socket - self.client[DB].test.find_one() - for i in range(self.end_request): - self.client.end_request() - - class _TestPoolingBase(unittest.TestCase): """Base class for all connection-pool tests.""" @@ -231,40 +148,6 @@ class _TestPoolingBase(unittest.TestCase): def create_pool(self, pair=(host, port), *args, **kwargs): return Pool(pair, PoolOptions(*args, **kwargs)) - def assert_no_request(self): - try: - server = self.c._topology.select_server( - writable_server_selector, - server_wait_time=0) - - self.assertEqual(NO_REQUEST, server.pool._get_request_state()) - except ConnectionFailure: - # Success: we're asserting that we're not in a request, but there's - # no pool at all so the assertion is true. - pass - - def assert_request_without_socket(self): - self.assertEqual(NO_SOCKET_YET, get_pool(self.c)._get_request_state()) - - def assert_request_with_socket(self): - self.assertTrue(isinstance( - get_pool(self.c)._get_request_state(), SocketInfo)) - - def assert_pool_size(self, pool_size): - if pool_size == 0: - try: - server = self.c._topology.select_server( - writable_server_selector, - server_wait_time=0) - - self.assertEqual(0, len(server.pool.sockets)) - except ConnectionFailure: - # Success: we're asserting that pool size is 0, and there's no - # pool at all so the assertion is true. - pass - else: - self.assertEqual(pool_size, len(get_pool(self.c).sockets)) - class TestPooling(_TestPoolingBase): def test_max_pool_size_validation(self): @@ -280,95 +163,11 @@ class TestPooling(_TestPoolingBase): self.assertEqual(c.max_pool_size, 100) def test_no_disconnect(self): - run_cases(self.c, [NoRequest, NonUnique, Unique, SaveAndFind]) - - def test_simple_disconnect(self): - # MongoClient just created, expect 1 free socket. - self.assert_pool_size(1) - self.assert_no_request() - - self.c.start_request() - self.assert_request_without_socket() - cursor = self.c[DB].stuff.find() - - # Cursor hasn't actually caused a request yet, so there's still 1 free - # socket. - self.assert_pool_size(1) - self.assert_request_without_socket() - - # Actually make a request to server, triggering a socket to be - # allocated to the request. - list(cursor) - self.assert_pool_size(0) - self.assert_request_with_socket() - - # Pool returns to its original state - self.c.end_request() - self.assert_no_request() - self.assert_pool_size(1) - - self.c.disconnect() - self.assert_pool_size(0) - self.assert_no_request() + run_cases(self.c, [NonUnique, Unique, SaveAndFind]) def test_disconnect(self): run_cases(self.c, [SaveAndFind, Disconnect, Unique]) - def test_request(self): - # Check that Pool gives two different sockets in two calls to - # get_socket(). - cx_pool = self.create_pool( - pair=(host, port), - max_pool_size=10, - connect_timeout=1000, - socket_timeout=1000) - - sock0 = cx_pool.get_socket({}, 0, 0) - sock1 = cx_pool.get_socket({}, 0, 0) - - self.assertNotEqual(sock0, sock1) - - # Now in a request, we'll get the same socket both times - cx_pool.start_request() - - sock2 = cx_pool.get_socket({}, 0, 0) - sock3 = cx_pool.get_socket({}, 0, 0) - self.assertEqual(sock2, sock3) - - # Pool didn't keep reference to sock0 or sock1; sock2 and 3 are new - self.assertNotEqual(sock0, sock2) - self.assertNotEqual(sock1, sock2) - - # Return the request sock to pool - cx_pool.end_request() - - sock4 = cx_pool.get_socket({}, 0, 0) - sock5 = cx_pool.get_socket({}, 0, 0) - - # Not in a request any more, we get different sockets - self.assertNotEqual(sock4, sock5) - - # end_request() returned sock2 to pool - self.assertEqual(sock4, sock2) - - for s in [sock0, sock1, sock2, sock3, sock4, sock5]: - s.close() - - def test_reset_and_request(self): - # reset() is called after a fork, or after a socket error. Ensure that - # a new request is begun if a request was in progress when the reset() - # occurred, otherwise no request is begun. - p = self.create_pool(max_pool_size=10) - self.assertFalse(p.in_request()) - p.start_request() - self.assertTrue(p.in_request()) - p.reset() - self.assertTrue(p.in_request()) - p.end_request() - self.assertFalse(p.in_request()) - p.reset() - self.assertFalse(p.in_request()) - def test_pool_reuses_open_socket(self): # Test Pool's _check_closed() method doesn't close a healthy socket. cx_pool = self.create_pool(max_pool_size=10) @@ -399,236 +198,6 @@ class TestPooling(_TestPoolingBase): self.assertEqual(1, len(cx_pool.sockets)) - def test_pool_removes_dead_request_socket_after_check(self): - # Test that Pool keeps request going even if a socket dies in request. - cx_pool = self.create_pool(max_pool_size=10) - cx_pool._check_interval_seconds = 0 # Always check. - cx_pool.start_request() - - # Get the request socket. - with cx_pool.get_socket({}, 0, 0) as sock_info: - self.assertEqual(0, len(cx_pool.sockets)) - self.assertEqual(sock_info, cx_pool._get_request_state()) - sock_info.sock.close() - - # Although the request socket died, we're still in a request with a - # new socket. - with cx_pool.get_socket({}, 0, 0) as new_sock_info: - self.assertTrue(cx_pool.in_request()) - self.assertNotEqual(sock_info, new_sock_info) - self.assertEqual(new_sock_info, cx_pool._get_request_state()) - - self.assertEqual(new_sock_info, cx_pool._get_request_state()) - self.assertEqual(0, len(cx_pool.sockets)) - - cx_pool.end_request() - self.assertEqual(1, len(cx_pool.sockets)) - - def test_pool_removes_dead_request_socket(self): - # Test that Pool keeps request going even if a socket dies in request. - cx_pool = self.create_pool(max_pool_size=10) - cx_pool.start_request() - - # Get the request socket - with cx_pool.get_socket({}, 0, 0) as sock_info: - self.assertEqual(0, len(cx_pool.sockets)) - self.assertEqual(sock_info, cx_pool._get_request_state()) - - # Unlike in test_pool_removes_dead_request_socket_after_check, we - # set sock_info.closed and *don't* wait for it to be checked. - sock_info.close() - - # Although the request socket died, we're still in a request with a - # new socket - with cx_pool.get_socket({}, 0, 0) as new_sock_info: - self.assertTrue(cx_pool.in_request()) - self.assertNotEqual(sock_info, new_sock_info) - self.assertEqual(new_sock_info, cx_pool._get_request_state()) - - self.assertEqual(new_sock_info, cx_pool._get_request_state()) - self.assertEqual(0, len(cx_pool.sockets)) - - cx_pool.end_request() - self.assertEqual(1, len(cx_pool.sockets)) - - def test_pool_removes_dead_socket_after_request(self): - # Test that Pool handles a socket dying that *used* to be the request - # socket. - cx_pool = self.create_pool(max_pool_size=10) - cx_pool._check_interval_seconds = 0 # Always check. - cx_pool.start_request() - - # Get the request socket. - with cx_pool.get_socket({}, 0, 0) as sock_info: - self.assertEqual(sock_info, cx_pool._get_request_state()) - - # End request. - cx_pool.end_request() - self.assertEqual(1, len(cx_pool.sockets)) - - # Kill old request socket. - sock_info.sock.close() - - # Dead socket detected and removed. - with cx_pool.get_socket({}, 0, 0) as new_sock_info: - self.assertFalse(cx_pool.in_request()) - self.assertNotEqual(sock_info, new_sock_info) - self.assertEqual(0, len(cx_pool.sockets)) - self.assertFalse(_closed(new_sock_info.sock)) - - self.assertEqual(1, len(cx_pool.sockets)) - - def test_dead_request_socket_with_max_size(self): - # When a pool replaces a dead request socket, the semaphore it uses - # to enforce max_size should remain unaffected. - cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) - cx_pool._check_interval_seconds = 0 # Always check. - cx_pool.start_request() - - # Get and close the request socket. - with cx_pool.get_socket({}, 0, 0) as request_sock_info: - request_sock_info.sock.close() - - # Detects closed socket and creates new one, semaphore value still 0. - with cx_pool.get_socket({}, 0, 0) as request_sock_info_2: - self.assertNotEqual(request_sock_info, request_sock_info_2) - - cx_pool.end_request() - - # Semaphore value now 1; we can get a socket. - sock = cx_pool.get_socket({}, 0, 0) - sock.close() - - def test_socket_reclamation(self): - if sys.platform.startswith('java'): - raise SkipTest("Jython can't do socket reclamation") - - # Check that if a thread starts a request and dies without ending - # the request, that the socket is reclaimed into the pool. - cx_pool = self.create_pool( - max_pool_size=10, - connect_timeout=1000, - socket_timeout=1000) - - self.assertEqual(0, len(cx_pool.sockets)) - - lock = None - the_sock = [None] - - def leak_request(): - self.assertEqual(NO_REQUEST, cx_pool._get_request_state()) - cx_pool.start_request() - self.assertEqual(NO_SOCKET_YET, cx_pool._get_request_state()) - with cx_pool.get_socket({}, 0, 0) as sock_info: - self.assertEqual(sock_info, cx_pool._get_request_state()) - the_sock[0] = id(sock_info.sock) - - lock.release() - - lock = thread.allocate_lock() - lock.acquire() - - # Start a thread WITHOUT a threading.Thread - important to test that - # Pool can deal with primitive threads. - thread.start_new_thread(leak_request, ()) - - # Join thread - acquired = lock.acquire() - self.assertTrue(acquired, "Thread is hung") - - # Make sure thread is really gone - time.sleep(1) - - if 'PyPy' in sys.version: - gc.collect() - - # Access the thread local from the main thread to trigger the - # ThreadVigil's delete callback, returning the request socket to - # the pool. - # In Python 2.7.0 and lesser, a dead thread's locals are deleted - # and those locals' weakref callbacks are fired only when another - # thread accesses the locals and finds the thread state is stale, - # see http://bugs.python.org/issue1868. Accessing the thread - # local from the main thread is a necessary part of this test, and - # realistic: in a multithreaded web server a new thread will access - # Pool._ident._local soon after an old thread has died. - cx_pool._ident.get() - - # Pool reclaimed the socket - self.assertEqual(1, len(cx_pool.sockets)) - self.assertEqual(the_sock[0], id(one(cx_pool.sockets).sock)) - self.assertEqual(0, len(cx_pool._tid_to_sock)) - - def test_request_with_fork(self): - if sys.platform == "win32": - raise SkipTest("Can't test forking on Windows") - - try: - from multiprocessing import Process, Pipe - except ImportError: - raise SkipTest("No multiprocessing module") - - coll = self.c.pymongo_test.test - coll.remove() - coll.insert({'_id': 1}) - coll.find_one() - self.assert_pool_size(1) - self.c.start_request() - self.assert_pool_size(1) - coll.find_one() - self.assert_pool_size(0) - self.assert_request_with_socket() - - def f(pipe): - # We can still query server without error - self.assertEqual({'_id':1}, coll.find_one()) - - # Pool has detected that we forked, but resumed the request - self.assert_request_with_socket() - self.assert_pool_size(0) - pipe.send("success") - - parent_conn, child_conn = Pipe() - p = Process(target=f, args=(child_conn,)) - p.start() - p.join(30) - p.terminate() - child_conn.close() - self.assertEqual("success", parent_conn.recv()) - - def test_primitive_thread(self): - p = self.create_pool(max_pool_size=10) - - # Test that start/end_request work with a thread begun from thread - # module, rather than threading module - lock = thread.allocate_lock() - lock.acquire() - - sock_ids = [] - - def run_in_request(): - p.start_request() - sock0 = p.get_socket({}, 0, 0) - sock1 = p.get_socket({}, 0, 0) - sock_ids.extend([id(sock0), id(sock1)]) - p.maybe_return_socket(sock0) - p.maybe_return_socket(sock1) - p.end_request() - lock.release() - - thread.start_new_thread(run_in_request, ()) - - # Join thread - acquired = False - for i in range(30): - time.sleep(0.5) - acquired = lock.acquire(0) - if acquired: - break - - self.assertTrue(acquired, "Thread is hung") - self.assertEqual(sock_ids[0], sock_ids[1]) - def test_pool_with_fork(self): # Test that separate MongoClients have separate Pools, and that the # driver can create a new MongoClient after forking @@ -770,204 +339,61 @@ class TestPooling(_TestPoolingBase): class TestPoolMaxSize(_TestPoolingBase): - """Keep right number of sockets in various start/end_request sequences. - """ - def _test_max_pool_size( - self, start_request, end_request, max_pool_size=4, nthreads=10): - """Start `nthreads` threads. Each calls start_request `start_request` - times, then find_one and waits at a barrier; once all reach the barrier - each calls end_request `end_request` times. The test asserts that the - pool ends with min(max_pool_size, nthreads) sockets or, if - start_request wasn't called, at least one socket. - - This tests both max_pool_size enforcement and that leaked request - sockets are eventually returned to the pool when their threads end. - - You may need to increase ulimit -n on Mac. - """ - if start_request: - if max_pool_size is not None and max_pool_size < nthreads: - raise AssertionError("Deadlock") - - c = rs_or_single_client(max_pool_size=max_pool_size) - rendezvous = CreateAndReleaseSocket.Rendezvous(nthreads) - - threads = [] - for i in range(nthreads): - t = CreateAndReleaseSocket( - c, start_request, end_request, rendezvous) - - threads.append(t) - - for t in threads: - t.start() - - if 'PyPy' in sys.version: - # With PyPy we need to kick off the gc whenever the threads hit the - # rendezvous since nthreads > max_pool_size. - gc_collect_until_done(threads) - else: - for t in threads: - t.join() - - # join() returns before the thread state is cleared; give it time. - time.sleep(1) - - for t in threads: - self.assertTrue(t.passed) - - # Socket-reclamation doesn't work in Jython - if not sys.platform.startswith('java'): - cx_pool = get_pool(c) - - # Socket-reclamation depends on timely garbage-collection - if 'PyPy' in sys.version: - gc.collect() - - if start_request: - # Trigger final cleanup in Python <= 2.7.0. - cx_pool._ident.get() - expected_idle = min(max_pool_size, nthreads) - message = ( - '%d idle sockets (expected %d) and %d request sockets' - ' (expected 0)' % ( - len(cx_pool.sockets), expected_idle, - len(cx_pool._tid_to_sock))) - - self.assertEqual( - expected_idle, len(cx_pool.sockets), message) - else: - # Without calling start_request(), threads can safely share - # sockets; the number running concurrently, and hence the - # number of sockets needed, is between 1 and 10, depending - # on thread-scheduling. - self.assertTrue(len(cx_pool.sockets) >= 1) - - # thread.join completes slightly *before* thread locals are - # cleaned up, so wait up to 5 seconds for them. - time.sleep(0.1) - cx_pool._ident.get() - start = time.time() - - while ( - not cx_pool.sockets - and cx_pool._socket_semaphore.counter < max_pool_size - and (time.time() - start) < 5 - ): - time.sleep(0.1) - cx_pool._ident.get() - - if max_pool_size is not None: - self.assertEqual( - max_pool_size, - cx_pool._socket_semaphore.counter) - - self.assertEqual(0, len(cx_pool._tid_to_sock)) - - def _test_max_pool_size_no_rendezvous(self, start_request, end_request): - max_pool_size = 5 - c = rs_or_single_client(max_pool_size=max_pool_size) + def test_max_pool_size(self): + max_pool_size = 4 + c = connected(rs_or_single_client(max_pool_size=max_pool_size)) + cx_pool = get_pool(c) # nthreads had better be much larger than max_pool_size to ensure that # max_pool_size sockets are actually required at some point in this # test's execution. nthreads = 10 - - if (sys.platform.startswith('java') - and start_request > end_request - and nthreads > max_pool_size): - - # Since Jython can't reclaim the socket and release the semaphore - # after a thread leaks a request, we'll exhaust the semaphore and - # deadlock. - raise SkipTest("Jython can't do socket reclamation") - threads = [] - for i in range(nthreads): - t = CreateAndReleaseSocketNoRendezvous( - c, start_request, end_request) - - threads.append(t) + lock = threading.Lock() + self.n_passed = 0 - for t in threads: + def f(): + for _ in range(N): + c[DB].test.find_one() + assert len(cx_pool.sockets) <= max_pool_size + + with lock: + self.n_passed += 1 + + for i in range(nthreads): + t = threading.Thread(target=f) + threads.append(t) t.start() - if 'PyPy' in sys.version: - # With PyPy we need to kick off the gc whenever the threads hit the - # rendezvous since nthreads > max_pool_size. - gc_collect_until_done(threads) - else: - for t in threads: - t.join() - - for t in threads: - self.assertTrue(t.passed) - - cx_pool = get_pool(c) - - # Socket-reclamation depends on timely garbage-collection - if 'PyPy' in sys.version: - gc.collect() - - # thread.join completes slightly *before* thread locals are - # cleaned up, so wait up to 5 seconds for them. - time.sleep(0.1) - cx_pool._ident.get() - start = time.time() - - while ( - not cx_pool.sockets - and cx_pool._socket_semaphore.counter < max_pool_size - and (time.time() - start) < 5 - ): - time.sleep(0.1) - cx_pool._ident.get() - - self.assertTrue(len(cx_pool.sockets) >= 1) + joinall(threads) + self.assertEqual(nthreads, self.n_passed) + self.assertTrue(len(cx_pool.sockets) > 1) self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter) - def test_max_pool_size(self): - self._test_max_pool_size( - start_request=0, end_request=0, nthreads=10, max_pool_size=4) - def test_max_pool_size_none(self): - self._test_max_pool_size( - start_request=0, end_request=0, nthreads=10, max_pool_size=None) + c = connected(rs_or_single_client(max_pool_size=None)) + cx_pool = get_pool(c) - def test_max_pool_size_with_request(self): - self._test_max_pool_size( - start_request=1, end_request=1, nthreads=10, max_pool_size=10) + nthreads = 10 + threads = [] + lock = threading.Lock() + self.n_passed = 0 - def test_max_pool_size_with_multiple_request(self): - self._test_max_pool_size( - start_request=10, end_request=10, nthreads=10, max_pool_size=10) + def f(): + for _ in range(N): + c[DB].test.find_one() - def test_max_pool_size_with_redundant_request(self): - self._test_max_pool_size( - start_request=2, end_request=1, nthreads=10, max_pool_size=10) + with lock: + self.n_passed += 1 - def test_max_pool_size_with_redundant_request2(self): - self._test_max_pool_size( - start_request=20, end_request=1, nthreads=10, max_pool_size=10) + for i in range(nthreads): + t = threading.Thread(target=f) + threads.append(t) + t.start() - def test_max_pool_size_with_redundant_request_no_rendezvous(self): - self._test_max_pool_size_no_rendezvous(2, 1) - - def test_max_pool_size_with_redundant_request_no_rendezvous2(self): - self._test_max_pool_size_no_rendezvous(20, 1) - - def test_max_pool_size_with_leaked_request(self): - # Call start_request() but not end_request() -- when threads die, they - # should return their request sockets to the pool. - self._test_max_pool_size( - start_request=1, end_request=0, nthreads=10, max_pool_size=10) - - def test_max_pool_size_with_leaked_request_no_rendezvous(self): - self._test_max_pool_size_no_rendezvous(1, 0) - - def test_max_pool_size_with_end_request_only(self): - # Call end_request() but not start_request() - self._test_max_pool_size(0, 1) + joinall(threads) + self.assertEqual(nthreads, self.n_passed) + self.assertTrue(len(cx_pool.sockets) >= 1) def test_max_pool_size_with_connection_failure(self): # The pool acquires its semaphore before attempting to connect; ensure diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index bb5c0f67a..e6e0df6f6 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -38,17 +38,19 @@ from test import (client_context, IntegrationTest, pair, port, - SkipTest, unittest, db_pwd, db_user, MockClientTest) from test.pymongo_mocks import MockClient -from test.utils import ( - delay, assertReadFrom, assertReadFromAll, ignore_deprecations, - read_from_which_host, assertRaisesExactly, TestRequestMixin, get_pools, - connected, wait_until, single_client, rs_or_single_client, one, - rs_client) +from test.utils import (assertRaisesExactly, + connected, + delay, + ignore_deprecations, + one, + rs_client, + single_client, + wait_until) from test.version import Version @@ -79,7 +81,7 @@ class TestReplicaSetClientBase(IntegrationTest): if m['stateStr'] == 'SECONDARY') -class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): +class TestReplicaSetClient(TestReplicaSetClientBase): def test_deprecated(self): with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) @@ -288,88 +290,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): def test_kill_cursor_explicit_secondary(self): self._test_kill_cursor_explicit(ReadPreference.SECONDARY) - def test_nested_request(self): - client = rs_or_single_client() - connected(client) - client.start_request() - try: - pools = get_pools(client) - self.assertTrue(client.in_request()) - - # Start and end request - we're still in "outer" original request - client.start_request() - self.assertInRequestAndSameSock(client, pools) - client.end_request() - self.assertInRequestAndSameSock(client, pools) - - # Double-nesting - client.start_request() - client.start_request() - - for pool in pools: - # Client only called start_request() once per pool, although - # its own counter is 3. - self.assertEqual(1, pool._request_counter.get()) - - client.end_request() - client.end_request() - self.assertInRequestAndSameSock(client, pools) - - for pool in pools: - self.assertEqual(1, pool._request_counter.get()) - - # Finally, end original request - client.end_request() - for pool in pools: - self.assertFalse(pool.in_request()) - - self.assertNotInRequestAndDifferentSock(client, pools) - finally: - client.close() - - def test_pinned_member(self): - raise SkipTest("Secondary pinning not implemented in PyMongo 3") - - latency = 1000 * 1000 - client = rs_client(secondaryacceptablelatencyms=latency) - - host = read_from_which_host(client, ReadPreference.SECONDARY) - self.assertTrue(host in client.secondaries) - - # No pinning since we're not in a request - assertReadFromAll( - self, client, client.secondaries, - ReadPreference.SECONDARY, None) - - assertReadFromAll( - self, client, list(client.secondaries) + [client.primary], - ReadPreference.NEAREST, None) - - client.start_request() - host = read_from_which_host(client, ReadPreference.SECONDARY) - self.assertTrue(host in client.secondaries) - assertReadFrom(self, client, host, ReadPreference.SECONDARY) - - # Changing any part of read preference (mode, tag_sets) - # unpins the current host and pins to a new one - primary = client.primary - assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED) - - host = read_from_which_host(client, ReadPreference.NEAREST) - assertReadFrom(self, client, host, ReadPreference.NEAREST) - - assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED) - - host = read_from_which_host(client, ReadPreference.SECONDARY_PREFERRED) - self.assertTrue(host in client.secondaries) - assertReadFrom(self, client, host, ReadPreference.SECONDARY_PREFERRED) - - # Unpin - client.end_request() - assertReadFromAll( - self, client, list(client.secondaries) + [client.primary], - ReadPreference.NEAREST, None) - def test_not_master_error(self): secondary_address = one(self.secondaries) direct_client = single_client(*secondary_address) diff --git a/test/test_thread_util.py b/test/test_thread_util.py deleted file mode 100644 index fa02286e6..000000000 --- a/test/test_thread_util.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2012-2014 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test the thread_util module.""" - -import gc -import sys -import threading -import time -from functools import partial - -sys.path[0:0] = [""] - -from pymongo import thread_util -from test import SkipTest, unittest -from test.utils import RendezvousThread - - -class TestIdent(unittest.TestCase): - def test_thread_ident(self): - # 1. Store main thread's id. - # 2. Start 2 child threads. - # 3. Store their values for Ident.get(). - # 4. Children reach rendezvous point. - # 5. Children call Ident.watch(). - # 6. One of the children calls Ident.unwatch(). - # 7. Children terminate. - # 8. Assert that children got different ids from each other and from - # main, and assert watched child's callback was executed, and that - # unwatched child's callback was not. - - if 'java' in sys.platform: - raise SkipTest("Can't rely on weakref callbacks in Jython") - - ident = thread_util.ThreadIdent() - - ids = set([ident.get()]) - unwatched_id = [] - done = set([ident.get()]) # Start with main thread's id. - died = set() - - class WatchedThread(RendezvousThread): - def __init__(self, ident, state): - super(WatchedThread, self).__init__(state) - self._my_ident = ident - - def before_rendezvous(self): - self.my_id = self._my_ident.get() - ids.add(self.my_id) - - def after_rendezvous(self): - assert not self._my_ident.watching() - self._my_ident.watch(lambda ref: died.add(self.my_id)) - assert self._my_ident.watching() - done.add(self.my_id) - - class UnwatchedThread(WatchedThread): - def before_rendezvous(self): - super(UnwatchedThread, self).before_rendezvous() - unwatched_id.append(self.my_id) - - def after_rendezvous(self): - super(UnwatchedThread, self).after_rendezvous() - self._my_ident.unwatch(self.my_id) - assert not self._my_ident.watching() - - state = RendezvousThread.create_shared_state(2) - t_watched = WatchedThread(ident, state) - t_watched.start() - - t_unwatched = UnwatchedThread(ident, state) - t_unwatched.start() - - RendezvousThread.wait_for_rendezvous(state) - RendezvousThread.resume_after_rendezvous(state) - - t_watched.join() - t_unwatched.join() - - self.assertTrue(t_watched.passed) - self.assertTrue(t_unwatched.passed) - - # Remove references, let weakref callbacks run - del t_watched - del t_unwatched - - # Trigger final cleanup in Python <= 2.7.0. - # http://bugs.python.org/issue1868 - ident.get() - self.assertEqual(3, len(ids)) - self.assertEqual(3, len(done)) - - # Make sure thread is really gone - slept = 0 - while not died and slept < 10: - time.sleep(1) - gc.collect() - slept += 1 - - self.assertEqual(1, len(died)) - self.assertFalse(unwatched_id[0] in died) - - -class TestCounter(unittest.TestCase): - def test_counter(self): - counter = thread_util.Counter() - - self.assertEqual(0, counter.dec()) - self.assertEqual(0, counter.get()) - self.assertEqual(0, counter.dec()) - self.assertEqual(0, counter.get()) - - done = set() - - def f(n): - for i in range(n): - self.assertEqual(i, counter.get()) - self.assertEqual(i + 1, counter.inc()) - - for i in range(n, 0, -1): - self.assertEqual(i, counter.get()) - self.assertEqual(i - 1, counter.dec()) - - self.assertEqual(0, counter.get()) - - # Extra decrements have no effect - self.assertEqual(0, counter.dec()) - self.assertEqual(0, counter.get()) - self.assertEqual(0, counter.dec()) - self.assertEqual(0, counter.get()) - - done.add(n) - - threads = [threading.Thread(target=partial(f, i)) - for i in range(10)] - - for t in threads: - t.start() - - for t in threads: - t.join() - - self.assertEqual(10, len(done)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_threads.py b/test/test_threads.py index b5874db3f..afacd6a01 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -17,14 +17,9 @@ import threading from test import unittest, client_context, IntegrationTest, db_user, db_pwd -from test.utils import (frequent_thread_switches, - joinall, - RendezvousThread, - rs_or_single_client, - rs_or_single_client_noauth) -from test.utils import get_pool -from pymongo.pool import SocketInfo, _closed -from pymongo.errors import AutoReconnect, OperationFailure +from test.utils import rs_or_single_client_noauth +from test.utils import frequent_thread_switches, joinall +from pymongo.errors import OperationFailure @client_context.require_connection @@ -130,57 +125,6 @@ class Disconnect(threading.Thread): self.passed = True -class IgnoreAutoReconnect(threading.Thread): - - def __init__(self, collection, n): - threading.Thread.__init__(self) - self.c = collection - self.n = n - self.setDaemon(True) - - def run(self): - for _ in range(self.n): - try: - self.c.find_one() - except AutoReconnect: - pass - - -class FindPauseFind(RendezvousThread): - """See test_server_disconnect() for details""" - def __init__(self, collection, state): - """Params: - `collection`: A collection for testing - `state`: A shared state object from RendezvousThread.shared_state() - """ - super(FindPauseFind, self).__init__(state) - self.collection = collection - - def before_rendezvous(self): - # acquire a socket - client = self.collection.database.connection - client.start_request() - list(self.collection.find()) - - pool = get_pool(client) - socket_info = pool._get_request_state() - assert isinstance(socket_info, SocketInfo) - self.request_sock = socket_info.sock - assert not _closed(self.request_sock) - - def after_rendezvous(self): - # test_server_disconnect() has closed this socket, but that's ok - # because it's not our request socket anymore - assert _closed(self.request_sock) - - # if disconnect() properly replaced the pool, then this won't raise - # AutoReconnect because it will acquire a new socket - list(self.collection.find()) - assert self.collection.database.connection.in_request() - pool = get_pool(self.collection.database.connection) - assert self.request_sock != pool._get_request_state().sock - - class TestThreads(IntegrationTest): def setUp(self): self.db = client_context.rs_or_standalone_client.pymongo_test @@ -236,71 +180,6 @@ class TestThreads(IntegrationTest): error.join() okay.join() - def test_server_disconnect(self): - # PYTHON-345, we need to make sure that threads' request sockets are - # closed by disconnect(). - # - # 1. Create a client and start a request. - # 2. Start N threads and do a find() in each to get a request socket - # 3. Pause all threads - # 4. In the main thread close all sockets, including threads' request - # sockets - # 5. In main thread, do a find(), which raises AutoReconnect and resets - # pool - # 6. Resume all threads, do a find() in them - # - # If we've fixed PYTHON-345, then only one AutoReconnect is raised, - # and all the threads get new request sockets. - cx = rs_or_single_client() - cx.start_request() - collection = cx.db.pymongo_test - - # acquire a request socket for the main thread - collection.find_one() - pool = get_pool(collection.database.connection) - socket_info = pool._get_request_state() - assert isinstance(socket_info, SocketInfo) - request_sock = socket_info.sock - - state = FindPauseFind.create_shared_state(nthreads=10) - - threads = [ - FindPauseFind(collection, state) - for _ in range(state.nthreads) - ] - - # Each thread does a find(), thus acquiring a request socket - for t in threads: - t.start() - - # Wait for the threads to reach the rendezvous - FindPauseFind.wait_for_rendezvous(state) - - try: - # Simulate an event that closes all sockets, e.g. primary stepdown - for t in threads: - t.request_sock.close() - - # Finally, ensure the main thread's socket's last_checkout is - # updated: - collection.find_one() - - # ... and close it: - request_sock.close() - - # Doing an operation on the client raises an AutoReconnect and - # resets the pool behind the scenes - self.assertRaises(AutoReconnect, collection.find_one) - - finally: - # Let threads do a second find() - FindPauseFind.resume_after_rendezvous(state) - - joinall(threads) - - for t in threads: - self.assertTrue(t.passed, "%s threw exception" % t) - def test_client_disconnect(self): self.db.drop_collection("test") for i in range(1000): diff --git a/test/utils.py b/test/utils.py index 0fe4fb10e..8a20b3d11 100644 --- a/test/utils.py +++ b/test/utils.py @@ -26,7 +26,6 @@ from functools import partial from pymongo import MongoClient from pymongo.errors import AutoReconnect, OperationFailure -from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo from pymongo.server_selectors import (any_server_selector, writable_server_selector) from test import (client_context, @@ -280,104 +279,6 @@ def ignore_deprecations(): yield -class RendezvousThread(threading.Thread): - """A thread that starts and pauses at a rendezvous point before resuming. - To be used in tests that must ensure that N threads are all alive - simultaneously, regardless of thread-scheduling's vagaries. - - 1. Write a subclass of RendezvousThread and override before_rendezvous - and / or after_rendezvous. - 2. Create a state with RendezvousThread.shared_state(N) - 3. Start N of your subclassed RendezvousThreads, passing the state to each - one's __init__ - 4. In the main thread, call RendezvousThread.wait_for_rendezvous - 5. Test whatever you need to test while threads are paused at rendezvous - point - 6. In main thread, call RendezvousThread.resume_after_rendezvous - 7. Join all threads from main thread - 8. Assert that all threads' "passed" attribute is True - 9. Test post-conditions - """ - - class RendezvousState(object): - def __init__(self, nthreads): - # Number of threads total - self.nthreads = nthreads - - # Number of threads that have arrived at rendezvous point - self.arrived_threads = 0 - self.arrived_threads_lock = threading.Lock() - - # Set when all threads reach rendezvous - self.ev_arrived = threading.Event() - - # Set by resume_after_rendezvous() so threads can continue. - self.ev_resume = threading.Event() - - - @classmethod - def create_shared_state(cls, nthreads): - return RendezvousThread.RendezvousState(nthreads) - - def before_rendezvous(self): - """Overridable: Do this before the rendezvous""" - pass - - def after_rendezvous(self): - """Overridable: Do this after the rendezvous. If it throws no exception, - `passed` is set to True - """ - pass - - @classmethod - def wait_for_rendezvous(cls, state): - """Wait for all threads to reach rendezvous and pause there""" - state.ev_arrived.wait(10) - assert state.ev_arrived.isSet(), "Thread timeout" - assert state.nthreads == state.arrived_threads - - @classmethod - def resume_after_rendezvous(cls, state): - """Tell all the paused threads to continue""" - state.ev_resume.set() - - def __init__(self, state): - """Params: - `state`: A shared state object from RendezvousThread.shared_state() - """ - super(RendezvousThread, self).__init__() - self.state = state - self.passed = False - - # If this thread fails to terminate, don't hang the whole program - self.setDaemon(True) - - def _rendezvous(self): - """Pause until all threads arrive here""" - s = self.state - s.arrived_threads_lock.acquire() - s.arrived_threads += 1 - if s.arrived_threads == s.nthreads: - s.arrived_threads_lock.release() - s.ev_arrived.set() - else: - s.arrived_threads_lock.release() - s.ev_arrived.wait() - - def run(self): - try: - self.before_rendezvous() - finally: - self._rendezvous() - - # all threads have passed the rendezvous, wait for - # resume_after_rendezvous() - self.state.ev_resume.wait() - - self.after_rendezvous() - self.passed = True - - def read_from_which_host( client, pref, @@ -467,51 +368,6 @@ def get_pools(client): client._get_topology().select_servers(any_server_selector)] -class TestRequestMixin(object): - """Inherit from this class and from unittest.TestCase to get some - convenient methods for testing connection pools and requests - """ - - def assertSameSock(self, pool): - sock_info0 = pool.get_socket({}, 0, 0) - sock_info1 = pool.get_socket({}, 0, 0) - self.assertEqual(sock_info0, sock_info1) - pool.maybe_return_socket(sock_info0) - pool.maybe_return_socket(sock_info1) - - def assertDifferentSock(self, pool): - sock_info0 = pool.get_socket({}, 0, 0) - sock_info1 = pool.get_socket({}, 0, 0) - self.assertNotEqual(sock_info0, sock_info1) - pool.maybe_return_socket(sock_info0) - pool.maybe_return_socket(sock_info1) - - def assertNoRequest(self, pool): - self.assertEqual(NO_REQUEST, pool._get_request_state()) - - def assertNoSocketYet(self, pool): - self.assertEqual(NO_SOCKET_YET, pool._get_request_state()) - - def assertRequestSocket(self, pool): - self.assertTrue(isinstance(pool._get_request_state(), SocketInfo)) - - def assertInRequestAndSameSock(self, client, pools): - self.assertTrue(client.in_request()) - if not isinstance(pools, list): - pools = [pools] - for pool in pools: - self.assertTrue(pool.in_request()) - self.assertSameSock(pool) - - def assertNotInRequestAndDifferentSock(self, client, pools): - self.assertFalse(client.in_request()) - if not isinstance(pools, list): - pools = [pools] - for pool in pools: - self.assertFalse(pool.in_request()) - self.assertDifferentSock(pool) - - # Constants for run_threads and lazy_client_trial. NTRIALS = 5 NTHREADS = 10