From c5eae2f99fe09f34d28dad948d4171e95020ea0e Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Wed, 1 Oct 2014 23:12:33 -0400 Subject: [PATCH] Use replica set connection in tests wherever possible. Most tests now inherit from IntegrationTest and use self.client for all MongoDB operations. self.client is now a replica set connection if an RS is available, otherwise a connection to a standalone. --- test/__init__.py | 3 +- test/test_auth.py | 20 ++++++------- test/test_bulk.py | 52 ++++++++++----------------------- test/test_client.py | 5 ---- test/test_collection.py | 25 ++++++++-------- test/test_common.py | 5 ++-- test/test_cursor.py | 32 ++++++++++---------- test/test_cursor_manager.py | 2 +- test/test_database.py | 27 +++++++---------- test/test_grid_file.py | 5 ---- test/test_gridfs.py | 1 - test/test_json_util.py | 5 ---- test/test_read_preferences.py | 6 ++-- test/test_replica_set_client.py | 7 +++-- 14 files changed, 76 insertions(+), 119 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index beeaef193..dccbf00c3 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -310,7 +310,8 @@ class IntegrationTest(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): - pass + cls.client = client_context.rs_or_standalone_client + cls.db = cls.client.pymongo_test class MockClientTest(unittest.TestCase): diff --git a/test/test_auth.py b/test/test_auth.py index d66f926af..d5d4a4bc3 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -262,15 +262,11 @@ class TestSCRAMSHA1(unittest.TestCase): {}).get('authenticationMechanisms', ''): raise SkipTest('SCRAM-SHA-1 mechanism not enabled') - client = client_context.client - if self.set_name: - client.pymongo_test.add_user('user', 'pass', - roles=['userAdmin', 'readWrite'], - writeConcern={'w': client_context.w}) - else: - client.pymongo_test.add_user( - 'user', 'pass', roles=['userAdmin', 'readWrite']) - + client = client_context.rs_or_standalone_client + client.pymongo_test.add_user( + 'user', 'pass', + roles=['userAdmin', 'readWrite'], + writeConcern={'w': client_context.w}) def test_scram_sha1(self): client = MongoClient(host, port) @@ -298,7 +294,9 @@ class TestSCRAMSHA1(unittest.TestCase): client.pymongo_test.command('dbstats') def tearDown(self): - client_context.client.pymongo_test.remove_user('user') + client_context.rs_or_standalone_client.pymongo_test.remove_user( + 'user', + w=client_context.w) class TestAuthURIOptions(unittest.TestCase): @@ -387,7 +385,7 @@ class TestDelegatedAuth(unittest.TestCase): @client_context.require_version_max(2, 5, 3) @client_context.require_version_min(2, 4, 0) def setUp(self): - self.client = client_context.client + self.client = client_context.rs_or_standalone_client def tearDown(self): self.client.pymongo_test.remove_user('user') diff --git a/test/test_bulk.py b/test/test_bulk.py index 8deb6f7a2..4e44f9aab 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -22,19 +22,21 @@ from bson import InvalidDocument, SON from bson.py3compat import string_type from pymongo import MongoClient from pymongo.errors import BulkWriteError, InvalidOperation, OperationFailure -from test import client_context, unittest, host, port +from test import client_context, unittest, host, port, IntegrationTest from test.utils import oid_generated_on_client, remove_all_users -@client_context.require_connection -def setUpModule(): - pass +class BulkTestBase(IntegrationTest): - -class BulkTestBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(BulkTestBase, cls).setUpClass() + cls.coll = cls.db.test + cls.has_write_commands = (client_context.client.max_wire_version > 1) def setUp(self): - self.has_write_commands = (client_context.client.max_wire_version > 1) + super(BulkTestBase, self).setUp() + self.coll.remove() def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -111,14 +113,6 @@ class BulkTestBase(unittest.TestCase): class TestBulk(BulkTestBase): - @classmethod - def setUpClass(cls): - cls.coll = client_context.client.pymongo_test.test - - def setUp(self): - super(TestBulk, self).setUp() - self.coll.remove() - def test_empty(self): bulk = self.coll.initialize_ordered_bulk_op() self.assertRaises(InvalidOperation, bulk.execute) @@ -913,13 +907,8 @@ class TestBulkWriteConcern(BulkTestBase): @classmethod def setUpClass(cls): - cls.is_repl = ('setName' in client_context.ismaster) + super(TestBulkWriteConcern, cls).setUpClass() cls.w = client_context.w - cls.coll = client_context.client.pymongo_test.test - - def setUp(self): - super(TestBulkWriteConcern, self).setUp() - self.coll.remove() @client_context.require_version_min(1, 8, 2) def test_fsync_and_j(self): @@ -1083,14 +1072,6 @@ class TestBulkWriteConcern(BulkTestBase): class TestBulkNoResults(BulkTestBase): - @classmethod - def setUpClass(cls): - cls.coll = client_context.client.pymongo_test.test - - def setUp(self): - super(TestBulkNoResults, self).setUp() - self.coll.remove() - def test_no_results_ordered_success(self): batch = self.coll.initialize_ordered_bulk_op() @@ -1098,7 +1079,7 @@ class TestBulkNoResults(BulkTestBase): batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue(batch.execute({'w': 0}) is None) self.assertEqual(2, self.coll.count()) @@ -1110,7 +1091,7 @@ class TestBulkNoResults(BulkTestBase): batch.insert({'_id': 2}) batch.insert({'_id': 1}) batch.find({'_id': 1}).remove_one() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue(batch.execute({'w': 0}) is None) self.assertEqual(3, self.coll.count()) @@ -1121,7 +1102,7 @@ class TestBulkNoResults(BulkTestBase): batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue(batch.execute({'w': 0}) is None) self.assertEqual(2, self.coll.count()) @@ -1133,7 +1114,7 @@ class TestBulkNoResults(BulkTestBase): batch.insert({'_id': 2}) batch.insert({'_id': 1}) batch.find({'_id': 1}).remove_one() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue(batch.execute({'w': 0}) is None) self.assertEqual(2, self.coll.count()) self.assertTrue(self.coll.find_one({'_id': 1}) is None) @@ -1145,13 +1126,10 @@ class TestBulkAuthorization(BulkTestBase): @client_context.require_auth @client_context.require_version_min(2, 5, 3) def setUpClass(cls): - cls.db = client_context.client.pymongo_test - cls.coll = cls.db.test + super(TestBulkAuthorization, cls).setUpClass() def setUp(self): super(TestBulkAuthorization, self).setUp() - self.coll.remove() - self.db.add_user('readonly', 'pw', roles=['read']) self.db.command( 'createRole', 'noremove', diff --git a/test/test_client.py b/test/test_client.py index acc7a1294..0ed5341cb 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -169,11 +169,6 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin): class TestClient(IntegrationTest, TestRequestMixin): - @classmethod - def setUpClass(cls): - super(TestClient, cls).setUpClass() - cls.client = client_context.rs_or_standalone_client - def test_constants(self): # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" diff --git a/test/test_collection.py b/test/test_collection.py index fb7c6094a..768c06a71 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -109,7 +109,6 @@ class TestCollection(IntegrationTest): @classmethod def setUpClass(cls): super(TestCollection, cls).setUpClass() - cls.db = client_context.client.pymongo_test cls.w = client_context.w @classmethod @@ -410,7 +409,7 @@ class TestCollection(IntegrationTest): @client_context.require_version_min(2, 3, 2) @client_context.require_no_mongos def test_index_text(self): - enable_text_search(client_context.client) + enable_text_search(self.client) db = self.db db.test.drop_indexes() @@ -955,7 +954,7 @@ class TestCollection(IntegrationTest): def test_unique_index(self): db = self.db - with client_context.client.start_request(): + with self.client.start_request(): db.drop_collection("test") db.test.create_index("hello") @@ -1041,7 +1040,7 @@ class TestCollection(IntegrationTest): docs.append({"four": 4}) docs.append({"five": 5}) - with client_context.client.start_request(): + with self.client.start_request(): db.test.insert(docs, manipulate=False, w=0) self.assertEqual(11000, db.error()['code']) self.assertEqual(1, db.test.count()) @@ -1093,7 +1092,7 @@ class TestCollection(IntegrationTest): db.test.insert, {"hello": {"a": 4, "b": 10}}) def test_safe_insert(self): - with client_context.client.start_request(): + with self.client.start_request(): db = self.db db.drop_collection("test") @@ -1145,7 +1144,7 @@ class TestCollection(IntegrationTest): def test_update_nmodified(self): db = self.db db.drop_collection("test") - used_write_commands = (client_context.client.max_wire_version > 1) + used_write_commands = (self.client.max_wire_version > 1) db.test.insert({'_id': 1}) result = db.test.update({'_id': 1}, {'$set': {'x': 1}}) @@ -1190,7 +1189,7 @@ class TestCollection(IntegrationTest): self.assertEqual(2, db.test.find_one()["count"]) def test_safe_update(self): - with client_context.client.start_request(): + with self.client.start_request(): db = self.db v113minus = client_context.version.at_least(1, 1, 3, -1) v19 = client_context.version.at_least(1, 9) @@ -1264,7 +1263,7 @@ class TestCollection(IntegrationTest): {})['n']) def test_safe_save(self): - with client_context.client.start_request(): + with self.client.start_request(): db = self.db db.drop_collection("test") db.test.create_index("hello", unique=True) @@ -1844,7 +1843,7 @@ class TestCollection(IntegrationTest): self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 32)}) def test_insert_large_batch(self): - max_bson_size = client_context.client.max_bson_size + max_bson_size = self.client.max_bson_size if client_context.version.at_least(2, 5, 4, -1): # Write commands are limited to 16MB + 16k per batch big_string = 'x' * int(max_bson_size / 2) @@ -1868,7 +1867,7 @@ class TestCollection(IntegrationTest): # Test that inserts fail after first error, unacknowledged. self.db.test.drop() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue(self.db.test.insert(batch, w=0)) self.assertEqual(1, self.db.test.count()) @@ -1887,7 +1886,7 @@ class TestCollection(IntegrationTest): # 2 batches, 2 errors, unacknowledged, continue on error self.db.test.drop() - with client_context.client.start_request(): + with self.client.start_request(): self.assertTrue( self.db.test.insert(batch, continue_on_error=True, w=0)) # Only the first and third documents should be inserted. @@ -2002,7 +2001,7 @@ class TestCollection(IntegrationTest): self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) - client_context.client.drop_database('mrtestdb') + self.client.drop_database('mrtestdb') full_result = db.test.map_reduce(map, reduce, out='mrunittests', full_response=True) @@ -2306,7 +2305,7 @@ class TestCollection(IntegrationTest): son['foo'] += 2 return son - db = client_context.client.pymongo_test + db = self.client.pymongo_test db.add_son_manipulator(IncByTwo()) c = db.test c.drop() diff --git a/test/test_common.py b/test/test_common.py index f6c15d6fd..e8b1bffdc 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -25,7 +25,7 @@ from bson.objectid import ObjectId from bson.son import SON from pymongo.mongo_client import MongoClient from pymongo.errors import ConfigurationError, OperationFailure -from test import client_context, pair, unittest +from test import client_context, pair, unittest, IntegrationTest from test.utils import get_client, connected @@ -34,10 +34,9 @@ def setUpModule(): pass -class TestCommon(unittest.TestCase): +class TestCommon(IntegrationTest): def test_uuid_subtype(self): - self.db = client_context.client.pymongo_test coll = self.db.uuid coll.drop() diff --git a/test/test_cursor.py b/test/test_cursor.py index 5492ff192..821bfbee5 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -138,11 +138,6 @@ class TestCursorNoConnect(unittest.TestCase): class TestCursor(IntegrationTest): - @classmethod - def setUpClass(cls): - super(TestCursor, cls).setUpClass() - cls.db = client_context.client.pymongo_test - @client_context.require_version_min(2, 5, 3, -1) def test_max_time_ms(self): db = self.db @@ -168,7 +163,7 @@ class TestCursor(IntegrationTest): self.assertTrue(coll.find_one(max_time_ms=1000)) - client = client_context.client + client = self.client if "enableTestCommands=1" in client_context.cmd_line['argv']: # Cursor parses server timeout error in response to initial query. client.admin.command("configureFailPoint", @@ -199,9 +194,9 @@ class TestCursor(IntegrationTest): # Send initial query before turning on failpoint. next(cursor) - client_context.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + self.client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="alwaysOn") try: try: # Iterate up to first getmore. @@ -211,9 +206,9 @@ class TestCursor(IntegrationTest): else: self.fail("ExecutionTimeout not raised") finally: - client_context.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + self.client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="off") def test_explain(self): a = self.db.test.find() @@ -1101,8 +1096,14 @@ class TestCursor(IntegrationTest): def test_cursor_transfer(self): # This is just a test, don't try this at home... - self.db.test.remove({}) - self.db.test.insert({'_id': i} for i in range(200)) + + # set_cursor_manager is only allowed on direct connections, not on + # replica set connections. + client = client_context.client + db = client.pymongo_test + + db.test.remove({}) + db.test.insert({'_id': i} for i in range(200)) class CManager(CursorManager): def __init__(self, connection): @@ -1112,12 +1113,11 @@ class TestCursor(IntegrationTest): # Do absolutely nothing... pass - client = self.db.connection try: with ignore_deprecations(): client.set_cursor_manager(CManager) docs = [] - cursor = self.db.test.find().batch_size(10) + cursor = db.test.find().batch_size(10) docs.append(next(cursor)) cursor.close() docs.extend(cursor) diff --git a/test/test_cursor_manager.py b/test/test_cursor_manager.py index 842aa2cf2..8ea1815ff 100644 --- a/test/test_cursor_manager.py +++ b/test/test_cursor_manager.py @@ -33,7 +33,7 @@ class TestCursorManager(IntegrationTest): @classmethod def setUpClass(cls): super(TestCursorManager, cls).setUpClass() - cls.collection = client_context.client.pymongo_test.test + cls.collection = cls.db.test cls.collection.remove() # Ensure two batches. diff --git a/test/test_database.py b/test/test_database.py index 53ff06b77..7759a2cb4 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -130,11 +130,6 @@ class TestDatabaseNoConnect(unittest.TestCase): class TestDatabase(IntegrationTest): - @classmethod - def setUpClass(cls): - super(TestDatabase, cls).setUpClass() - cls.client = client_context.client - def test_repr(self): self.assertEqual(repr(Database(self.client, "pymongo_test")), "Database(%r, %s)" % (self.client, @@ -515,7 +510,7 @@ class TestDatabase(IntegrationTest): # "Non-admin" user db = auth_c.pymongo_test - client_context.client.pymongo_test.add_user('user', 'pass') + self.client.pymongo_test.add_user('user', 'pass') try: db.authenticate('user', 'pass') info = db.command('usersInfo', 'user')['users'][0] @@ -523,8 +518,8 @@ class TestDatabase(IntegrationTest): db.logout() # Read only "Non-admin" user - client_context.client.pymongo_test.add_user('ro-user', 'pass', - read_only=True) + self.client.pymongo_test.add_user('ro-user', 'pass', + read_only=True) db.authenticate('ro-user', 'pass') info = db.command('usersInfo', 'ro-user')['users'][0] self.assertEqual("read", info['roles'][0]['role']) @@ -540,8 +535,8 @@ class TestDatabase(IntegrationTest): def test_new_user_cmds(self): auth_c = MongoClient(pair) db = auth_c.pymongo_test - client_context.client.pymongo_test.add_user("amalia", "password", - roles=["userAdmin"]) + self.client.pymongo_test.add_user("amalia", "password", + roles=["userAdmin"]) db.authenticate("amalia", "password") try: @@ -562,7 +557,7 @@ class TestDatabase(IntegrationTest): auth_c = MongoClient(pair) db = auth_c.auth_test - client_context.client.auth_test.add_user( + self.client.auth_test.add_user( "bernie", "password", roles=["userAdmin", "dbAdmin", "readWrite"]) db.authenticate("bernie", "password") @@ -585,14 +580,14 @@ class TestDatabase(IntegrationTest): @client_context.require_auth def test_authenticate_multiple(self): # Setup - client_context.client.drop_database("pymongo_test") - client_context.client.drop_database("pymongo_test1") + self.client.drop_database("pymongo_test") + self.client.drop_database("pymongo_test1") auth_c = MongoClient(pair) users_db = auth_c.pymongo_test admin_db = auth_c.admin other_db = auth_c.pymongo_test1 - client_context.client.admin.add_user( + self.client.admin.add_user( 'admin', 'pass', roles=["userAdminAnyDatabase", "dbAdmin", "clusterAdmin", "readWrite"]) @@ -638,8 +633,8 @@ class TestDatabase(IntegrationTest): # Cleanup finally: remove_all_users(users_db) - client_context.client.admin.remove_user('ro-admin') - client_context.client.admin.remove_user('admin') + self.client.admin.remove_user('ro-admin') + self.client.admin.remove_user('admin') def test_id_ordering(self): # PyMongo attempts to have _id show up first diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 691615c23..b7fdd1d28 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -94,11 +94,6 @@ class TestGridFileNoConnect(unittest.TestCase): class TestGridFile(IntegrationTest): - @classmethod - def setUpClass(cls): - super(TestGridFile, cls).setUpClass() - cls.db = client_context.client.pymongo_test - def setUp(self): self.db.drop_collection('fs.files') self.db.drop_collection('fs.chunks') diff --git a/test/test_gridfs.py b/test/test_gridfs.py index b24057f4f..4b4e0df27 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -91,7 +91,6 @@ class TestGridfs(IntegrationTest): @classmethod def setUpClass(cls): super(TestGridfs, cls).setUpClass() - cls.db = client_context.client.pymongo_test cls.fs = gridfs.GridFS(cls.db) cls.alt = gridfs.GridFS(cls.db, "alt") diff --git a/test/test_json_util.py b/test/test_json_util.py index c5c1e5634..9015320be 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -238,11 +238,6 @@ class TestJsonUtil(unittest.TestCase): class TestJsonUtilRoundtrip(IntegrationTest): - @classmethod - def setUpClass(cls): - super(TestJsonUtilRoundtrip, cls).setUpClass() - cls.db = client_context.client.pymongo_test - def setUp(self): if not json_util.json_lib: raise SkipTest("No json or simplejson module") diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index a13a7817b..b02787a8e 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -51,13 +51,13 @@ class TestReadPreferencesBase(TestReplicaSetClientBase): def setUp(self): super(TestReadPreferencesBase, self).setUp() # Insert some data so we can use cursors in read_from_which_host - client_context.client.pymongo_test.test.drop() - client_context.client.pymongo_test.test.insert( + self.client.pymongo_test.test.drop() + self.client.pymongo_test.test.insert( [{'_id': i} for i in range(10)], w=self.w) def tearDown(self): super(TestReadPreferencesBase, self).tearDown() - client_context.client.pymongo_test.test.drop() + self.client.pymongo_test.test.drop() def read_from_which_host(self, client): """Do a find() on the client and return which host was used diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index f19506922..aa8b9b81b 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -35,6 +35,7 @@ from pymongo.read_preferences import ReadPreference, Secondary, Nearest from test import (client_context, client_knobs, connection_string, + IntegrationTest, pair, port, SkipTest, @@ -50,14 +51,16 @@ from test.utils import ( from test.version import Version -class TestReplicaSetClientBase(unittest.TestCase): +class TestReplicaSetClientBase(IntegrationTest): @classmethod @client_context.require_replica_set def setUpClass(cls): + super(TestReplicaSetClientBase, cls).setUpClass() cls.name = client_context.setname - ismaster = client_context.ismaster cls.w = client_context.w + + ismaster = client_context.ismaster cls.hosts = set(partition_node(h) for h in ismaster['hosts']) cls.arbiters = set(partition_node(h) for h in ismaster.get("arbiters", []))