From 26fb43cf787f4f2430760bbe33b5bba1d5a38266 Mon Sep 17 00:00:00 2001 From: Luke Lovett Date: Wed, 30 Apr 2014 18:50:50 +0000 Subject: [PATCH] PYTHON-681 Reuse MongoClient whenever possible in the tests --- test/__init__.py | 111 ++++++++ test/high_availability/test_ha.py | 5 +- test/test_auth.py | 18 +- test/test_binary.py | 20 +- test/test_bulk.py | 93 ++++--- test/test_client.py | 260 +++++++++---------- test/test_collection.py | 410 ++++++++++++++---------------- test/test_common.py | 54 ++-- test/test_cursor.py | 72 ++---- test/test_database.py | 146 ++++------- test/test_grid_file.py | 11 +- test/test_gridfs.py | 29 +-- test/test_json_util.py | 12 +- test/test_pooling_base.py | 5 +- test/test_pymongo.py | 8 +- test/test_read_preferences.py | 23 +- test/test_replica_set_client.py | 234 ++++++++--------- test/test_son_manipulator.py | 12 +- test/test_ssl.py | 41 +-- test/test_thread_util.py | 4 +- test/utils.py | 5 +- test/version.py | 68 ++--- 22 files changed, 800 insertions(+), 841 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 679ee520c..646b2e5d3 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -25,9 +25,12 @@ else: from unittest import SkipTest import warnings +from functools import wraps + import pymongo from bson.py3compat import _unicode +from test.version import Version # hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode # type in Python 2, so ensure these hostnames are unicodes, too. It makes tests @@ -43,6 +46,114 @@ host3 = _unicode(os.environ.get("DB_IP3", 'localhost')) port3 = int(os.environ.get("DB_PORT3", 27019)) +class ClientContext(object): + + def __init__(self): + """Create a client and grab essential information from the server.""" + try: + self.client = pymongo.MongoClient(host, port) + except pymongo.errors.ConnectionFailure: + self.client = None + else: + self.ismaster = self.client.admin.command('ismaster') + self.w = len(self.ismaster.get("hosts", [])) or 1 + self.setname = self.ismaster.get('setName', '') + self.rs_client = None + if self.setname: + self.rs_client = pymongo.MongoReplicaSetClient( + pair, replicaSet=self.setname) + self.cmd_line = self.client.admin.command('getCmdLineOpts') + self.version = Version.from_client(self.client) + self.auth_enabled = self._server_started_with_auth() + self.test_commands_enabled = ('testCommandsEnabled=1' + in self.cmd_line['argv']) + self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid') + + def _server_started_with_auth(self): + # MongoDB >= 2.0 + if 'parsed' in self.cmd_line: + parsed = self.cmd_line['parsed'] + # MongoDB >= 2.6 + if 'security' in parsed: + security = parsed['security'] + # >= rc3 + if 'authorization' in security: + return security['authorization'] == 'enabled' + # < rc3 + return (security.get('auth', False) or + bool(security.get('keyFile'))) + return parsed.get('auth', False) or bool(parsed.get('keyFile')) + # Legacy + argv = self.cmd_line['argv'] + return '--auth' in argv or '--keyFile' in argv + + def _require(self, condition, msg, func=None): + def make_wrapper(f): + @wraps(f) + def wrap(*args, **kwargs): + if condition: + return f(*args, **kwargs) + raise SkipTest(msg) + return wrap + + if func is None: + def decorate(f): + return make_wrapper(f) + return decorate + return make_wrapper(func) + + def require_version_min(self, *ver): + """Run a test only if the server version is at least ``version``.""" + other_version = Version(*ver) + return self._require(self.version >= other_version, + "Server version must be at least %s" + % str(other_version)) + + def require_version_max(self, *ver): + """Run a test only if the server version is at most ``version``.""" + other_version = Version(*ver) + return self._require(self.version <= other_version, + "Server version must be at most %s" + % str(other_version)) + + def require_auth(self, func): + """Run a test only if the server is running with auth enabled.""" + return self.check_auth_with_sharding( + self._require(self.auth_enabled, + "Authentication is not enabled on the server", + func=func)) + + def require_replica_set(self, func): + """Run a test only if the client is connected to a replica set.""" + return self._require(self.rs_client is not None, + "Not connected to a replica set", + func=func) + + def require_no_mongos(self, func): + """Run a test only if the client is not connected to a mongos.""" + return self._require(not self.is_mongos, + "Must be connected to a mongod, not a mongos", + func=func) + + def check_auth_with_sharding(self, func): + """Skip a test when connected to mongos < 2.0 and running with auth.""" + condition = not (self.auth_enabled and + self.is_mongos and self.version < (2,)) + return self._require(condition, + "Auth with sharding requires MongoDB >= 2.0.0", + func=func) + + def require_test_commands(self, func): + """Run a test only if the server has test commands enabled.""" + return self._require(self.test_commands_enabled, + "Test commands must be enabled", + func=func) + + +# Reusable client context +client_context = ClientContext() + + def setup(): warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/high_availability/test_ha.py b/test/high_availability/test_ha.py index 197d42a64..712c7bb49 100644 --- a/test/high_availability/test_ha.py +++ b/test/high_availability/test_ha.py @@ -34,8 +34,9 @@ from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.mongo_client import MongoClient, _partition_node from pymongo.read_preferences import ReadPreference -from test import SkipTest, unittest, utils, version +from test import SkipTest, unittest, utils from test.utils import one +from test.version import Version # May be imported from gevent, below. @@ -1013,7 +1014,7 @@ class TestLastErrorDefaults(HATestCase): use_greenlets=use_greenlets) def test_get_last_error_defaults(self): - if not version.at_least(self.c, (1, 9, 0)): + if not Version.from_client(self.c).at_least(1, 9, 0): raise SkipTest("Need MongoDB >= 1.9.0 to test getLastErrorDefaults") replset = self.c.local.system.replset.find_one() diff --git a/test/test_auth.py b/test/test_auth.py index b7d432b9a..0aa12cb7f 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -30,8 +30,7 @@ from pymongo import MongoClient, MongoReplicaSetClient from pymongo.auth import HAVE_KERBEROS from pymongo.errors import OperationFailure, ConfigurationError from pymongo.read_preferences import ReadPreference -from test import host, port, SkipTest, unittest, version -from test.utils import is_mongos, server_started_with_auth +from test import client_context, host, port, SkipTest, unittest # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS. GSSAPI_HOST = os.environ.get('GSSAPI_HOST') @@ -229,13 +228,9 @@ class TestSASL(unittest.TestCase): class TestAuthURIOptions(unittest.TestCase): + @client_context.require_auth def setUp(self): client = MongoClient(host, port) - # Sharded auth not supported before MongoDB 2.0 - if is_mongos(client) and not version.at_least(client, (2, 0, 0)): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - if not server_started_with_auth(client): - raise SkipTest('Authentication is not enabled on server') response = client.admin.command('ismaster') self.set_name = str(response.get('setName', '')) client.admin.add_user('admin', 'pass', roles=['userAdminAnyDatabase', @@ -313,14 +308,11 @@ class TestAuthURIOptions(unittest.TestCase): class TestDelegatedAuth(unittest.TestCase): + @client_context.require_auth + @client_context.require_version_max(2, 5, 3) + @client_context.require_version_min(2, 4, 0) def setUp(self): self.client = MongoClient(host, port) - if not version.at_least(self.client, (2, 4, 0)): - raise SkipTest('Delegated authentication requires MongoDB >= 2.4.0') - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - if version.at_least(self.client, (2, 5, 3, -1)): - raise SkipTest('Delegated auth does not exist in MongoDB >= 2.5.3') # Give admin all privileges. self.client.admin.add_user('admin', 'pass', roles=['readAnyDatabase', diff --git a/test/test_binary.py b/test/test_binary.py index 9e6470dc6..9c0a5db4b 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -27,8 +27,7 @@ import bson from bson.binary import * from bson.py3compat import u from bson.son import SON -from test import unittest -from test.test_client import get_client +from test import client_context, unittest from pymongo.mongo_client import MongoClient @@ -147,9 +146,8 @@ class TestBinary(unittest.TestCase): self.assertEqual(data, encoded) # Test insert and find - client = get_client() - client.pymongo_test.drop_collection('java_uuid') - coll = client.pymongo_test.java_uuid + client_context.client.pymongo_test.drop_collection('java_uuid') + coll = client_context.client.pymongo_test.java_uuid coll.uuid_subtype = JAVA_LEGACY coll.insert(docs) @@ -160,7 +158,7 @@ class TestBinary(unittest.TestCase): coll.uuid_subtype = OLD_UUID_SUBTYPE for d in coll.find(): self.assertNotEqual(d['newguid'], d['newguidstring']) - client.pymongo_test.drop_collection('java_uuid') + client_context.client.pymongo_test.drop_collection('java_uuid') def test_legacy_csharp_uuid(self): @@ -217,9 +215,8 @@ class TestBinary(unittest.TestCase): self.assertEqual(data, encoded) # Test insert and find - client = get_client() - client.pymongo_test.drop_collection('csharp_uuid') - coll = client.pymongo_test.csharp_uuid + client_context.client.pymongo_test.drop_collection('csharp_uuid') + coll = client_context.client.pymongo_test.csharp_uuid coll.uuid_subtype = CSHARP_LEGACY coll.insert(docs) @@ -230,7 +227,7 @@ class TestBinary(unittest.TestCase): coll.uuid_subtype = OLD_UUID_SUBTYPE for d in coll.find(): self.assertNotEqual(d['newguid'], d['newguidstring']) - client.pymongo_test.drop_collection('csharp_uuid') + client_context.client.pymongo_test.drop_collection('csharp_uuid') def test_uri_to_uuid(self): @@ -240,8 +237,7 @@ class TestBinary(unittest.TestCase): def test_uuid_queries(self): - c = get_client() - coll = c.pymongo_test.test + coll = client_context.client.pymongo_test.test coll.drop() uu = uuid.uuid4() diff --git a/test/test_bulk.py b/test/test_bulk.py index b048e5da7..3217a576d 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -21,19 +21,14 @@ sys.path[0:0] = [""] from bson import InvalidDocument, SON from bson.py3compat import string_type from pymongo.errors import BulkWriteError, InvalidOperation, OperationFailure -from test import SkipTest, unittest, version -from test.test_client import get_client -from test.utils import (oid_generated_on_client, - remove_all_users, - server_started_with_auth, - server_started_with_nojournal) +from test import client_context, unittest +from test.utils import oid_generated_on_client, remove_all_users class BulkTestBase(unittest.TestCase): def setUp(self): - client = get_client() - self.has_write_commands = (client.max_wire_version > 1) + self.has_write_commands = (client_context.client.max_wire_version > 1) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -110,9 +105,12 @@ class BulkTestBase(unittest.TestCase): class TestBulk(BulkTestBase): + @classmethod + def setUpClass(cls): + cls.coll = client_context.client.pymongo_test.test + def setUp(self): super(TestBulk, self).setUp() - self.coll = get_client().pymongo_test.test self.coll.remove() def test_empty(self): @@ -583,8 +581,7 @@ class TestBulk(BulkTestBase): self.assertEqual(self.coll.find({'x': 2}).count(), 0) def test_upsert_large(self): - client = self.coll.database.connection - big = 'a' * (client.max_bson_size - 37) + big = 'a' * (client_context.client.max_bson_size - 37) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({'x': 1}).upsert().update({'$set': {'s': big}}) result = bulk.execute() @@ -886,29 +883,26 @@ class TestBulk(BulkTestBase): class TestBulkWriteConcern(BulkTestBase): + @classmethod + def setUpClass(cls): + cls.is_repl = ('setName' in client_context.ismaster) + cls.w = client_context.w + cls.coll = client_context.client.pymongo_test.test + def setUp(self): super(TestBulkWriteConcern, self).setUp() - client = get_client() - ismaster = client.test.command('ismaster') - self.is_repl = bool(ismaster.get('setName')) - self.w = len(ismaster.get("hosts", [])) - self.client = client - self.coll = client.pymongo_test.test self.coll.remove() + @client_context.require_version_min(1, 8, 2) def test_fsync_and_j(self): - if not version.at_least(self.client, (1, 8, 2)): - raise SkipTest("Need at least MongoDB 1.8.2") batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1}) self.assertRaises( OperationFailure, batch.execute, {'fsync': True, 'j': True}) + @client_context.require_replica_set def test_write_concern_failure_ordered(self): - if not self.is_repl: - raise SkipTest("Need a replica set to test.") - # Ensure we don't raise on wnote. batch = self.coll.initialize_ordered_bulk_op() batch.find({"something": "that does no exist"}).remove() @@ -985,10 +979,8 @@ class TestBulkWriteConcern(BulkTestBase): finally: self.coll.drop_index([('a', 1)]) + @client_context.require_replica_set def test_write_concern_failure_unordered(self): - if not self.is_repl: - raise SkipTest("Need a replica set to test.") - # Ensure we don't raise on wnote. batch = self.coll.initialize_unordered_bulk_op() batch.find({"something": "that does no exist"}).remove() @@ -1063,9 +1055,12 @@ class TestBulkWriteConcern(BulkTestBase): class TestBulkNoResults(BulkTestBase): + @classmethod + def setUpClass(cls): + cls.coll = client_context.client.pymongo_test.test + def setUp(self): super(TestBulkNoResults, self).setUp() - self.coll = get_client().pymongo_test.test self.coll.remove() def test_no_results_ordered_success(self): @@ -1114,21 +1109,21 @@ class TestBulkNoResults(BulkTestBase): class TestBulkAuthorization(BulkTestBase): + @classmethod + @client_context.require_auth + @client_context.require_version_min(2, 5, 3) + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + cls.coll = cls.db.test + def setUp(self): super(TestBulkAuthorization, self).setUp() - self.client = client = get_client() - if (not server_started_with_auth(client) - or not version.at_least(client, (2, 5, 3))): - raise SkipTest('Need at least MongoDB 2.5.3 with auth') - - db = client.pymongo_test - self.coll = db.test self.coll.remove() - db.add_user('dbOwner', 'pw', roles=['dbOwner']) - db.authenticate('dbOwner', 'pw') - db.add_user('readonly', 'pw', roles=['read']) - db.command( + self.db.add_user('dbOwner', 'pw', roles=['dbOwner']) + self.db.authenticate('dbOwner', 'pw') + self.db.add_user('readonly', 'pw', roles=['read']) + self.db.command( 'createRole', 'noremove', privileges=[{ 'actions': ['insert', 'update', 'find'], @@ -1136,13 +1131,21 @@ class TestBulkAuthorization(BulkTestBase): }], roles=[]) - db.add_user('noremove', 'pw', roles=['noremove']) - db.logout() + self.db.add_user('noremove', 'pw', roles=['noremove']) + self.db.logout() + + def tearDown(self): + self.db.logout() + self.db.authenticate('dbOwner', 'pw') + self.db.command('dropRole', 'noremove') + remove_all_users(self.db) + self.db.logout() + self.db.connection.disconnect() def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - db = self.client.pymongo_test + db = client_context.client.pymongo_test db.authenticate('readonly', 'pw') bulk = self.coll.initialize_ordered_bulk_op() bulk.insert({'x': 1}) @@ -1151,7 +1154,7 @@ class TestBulkAuthorization(BulkTestBase): def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - db = self.client.pymongo_test + db = client_context.client.pymongo_test db.authenticate('noremove', 'pw') bulk = self.coll.initialize_ordered_bulk_op() bulk.insert({'x': 1}) @@ -1161,14 +1164,6 @@ class TestBulkAuthorization(BulkTestBase): self.assertRaises(OperationFailure, bulk.execute) self.assertEqual(set([1, 2]), set(self.coll.distinct('x'))) - def tearDown(self): - db = self.client.pymongo_test - db.logout() - db.authenticate('dbOwner', 'pw') - db.command('dropRole', 'noremove') - remove_all_users(db) - db.logout() - if __name__ == "__main__": unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 5db0e61e0..1065eb0b4 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -36,14 +36,12 @@ from pymongo.errors import (AutoReconnect, InvalidName, OperationFailure, PyMongoError) -from test import host, pair, port, SkipTest, unittest, version +from test import client_context, host, pair, port, SkipTest, unittest from test.pymongo_mocks import MockClient from test.utils import (assertRaisesExactly, delay, - is_mongos, remove_all_users, server_is_master_with_slave, - server_started_with_auth, TestRequestMixin, _TestLazyConnectMixin, lazy_client_trial, @@ -56,6 +54,10 @@ def get_client(*args, **kwargs): class TestClient(unittest.TestCase, TestRequestMixin): + @classmethod + def setUpClass(cls): + cls.client = client_context.client + def test_types(self): self.assertRaises(TypeError, MongoClient, 1) self.assertRaises(TypeError, MongoClient, 1.14) @@ -108,7 +110,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertEqual(host, c.host) self.assertEqual(port, c.port) - if version.at_least(c, (2, 5, 4, -1)): + if client_context.version.at_least(2, 5, 4, -1): self.assertTrue(c.max_wire_version > 0) else: self.assertEqual(c.max_wire_version, 0) @@ -136,10 +138,10 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertTrue(MongoClient(host, port)) def test_equality(self): - client = MongoClient(host, port) - self.assertEqual(client, MongoClient(host, port)) + # ClientContext.client is constructed as MongoClient(host, port) + self.assertEqual(self.client, MongoClient(host, port)) # Explicitly test inequality - self.assertFalse(client != MongoClient(host, port)) + self.assertFalse(self.client != MongoClient(host, port)) def test_host_w_port(self): self.assertTrue(MongoClient("%s:%d" % (host, port))) @@ -153,10 +155,9 @@ class TestClient(unittest.TestCase, TestRequestMixin): "MongoClient('%s', %d)" % (host, port)) def test_getters(self): - self.assertEqual(MongoClient(host, port).host, host) - self.assertEqual(MongoClient(host, port).port, port) - self.assertEqual(set([(host, port)]), - MongoClient(host, port).nodes) + self.assertEqual(self.client.host, host) + self.assertEqual(self.client.port, port) + self.assertEqual(set([(host, port)]), self.client.nodes) def test_use_greenlets(self): self.assertFalse(MongoClient(host, port).use_greenlets) @@ -166,62 +167,56 @@ class TestClient(unittest.TestCase, TestRequestMixin): host, port, use_greenlets=True).use_greenlets) def test_get_db(self): - client = MongoClient(host, port) - def make_db(base, name): return base[name] - self.assertRaises(InvalidName, make_db, client, "") - self.assertRaises(InvalidName, make_db, client, "te$t") - self.assertRaises(InvalidName, make_db, client, "te.t") - self.assertRaises(InvalidName, make_db, client, "te\\t") - self.assertRaises(InvalidName, make_db, client, "te/t") - self.assertRaises(InvalidName, make_db, client, "te st") + self.assertRaises(InvalidName, make_db, self.client, "") + self.assertRaises(InvalidName, make_db, self.client, "te$t") + self.assertRaises(InvalidName, make_db, self.client, "te.t") + self.assertRaises(InvalidName, make_db, self.client, "te\\t") + self.assertRaises(InvalidName, make_db, self.client, "te/t") + self.assertRaises(InvalidName, make_db, self.client, "te st") - self.assertTrue(isinstance(client.test, Database)) - self.assertEqual(client.test, client["test"]) - self.assertEqual(client.test, Database(client, "test")) + self.assertTrue(isinstance(self.client.test, Database)) + self.assertEqual(self.client.test, self.client["test"]) + self.assertEqual(self.client.test, Database(self.client, "test")) def test_database_names(self): - client = MongoClient(host, port) + self.client.pymongo_test.test.save({"dummy": u("object")}) + self.client.pymongo_test_mike.test.save({"dummy": u("object")}) - client.pymongo_test.test.save({"dummy": u("object")}) - client.pymongo_test_mike.test.save({"dummy": u("object")}) - - dbs = client.database_names() + dbs = self.client.database_names() self.assertTrue("pymongo_test" in dbs) self.assertTrue("pymongo_test_mike" in dbs) def test_drop_database(self): - client = MongoClient(host, port) - - self.assertRaises(TypeError, client.drop_database, 5) - self.assertRaises(TypeError, client.drop_database, None) + self.assertRaises(TypeError, self.client.drop_database, 5) + self.assertRaises(TypeError, self.client.drop_database, None) raise SkipTest("This test often fails due to SERVER-2329") - client.pymongo_test.test.save({"dummy": u("object")}) - dbs = client.database_names() + self.client.pymongo_test.test.save({"dummy": u("object")}) + dbs = self.client.database_names() self.assertTrue("pymongo_test" in dbs) - client.drop_database("pymongo_test") - dbs = client.database_names() + self.client.drop_database("pymongo_test") + dbs = self.client.database_names() self.assertTrue("pymongo_test" not in dbs) - client.pymongo_test.test.save({"dummy": u("object")}) - dbs = client.database_names() + self.client.pymongo_test.test.save({"dummy": u("object")}) + dbs = self.client.database_names() self.assertTrue("pymongo_test" in dbs) - client.drop_database(client.pymongo_test) - dbs = client.database_names() + self.client.drop_database(self.client.pymongo_test) + dbs = self.client.database_names() self.assertTrue("pymongo_test" not in dbs) def test_copy_db(self): - c = MongoClient(host, port) + c = self.client # Due to SERVER-2329, databases may not disappear # from a master in a master-slave pair. if server_is_master_with_slave(c): raise SkipTest("SERVER-2329") - if (not version.at_least(c, (2, 6, 0)) and - is_mongos(c) and server_started_with_auth(c)): + if (client_context.version.at_least(2, 6, 0) and + client_context.is_mongos and client_context.auth_enabled): raise SkipTest("Need mongos >= 2.6.0 to test with authentication") # We test copy twice; once starting in a request and once not. In # either case the copy should succeed (because it starts a request @@ -259,8 +254,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"]) # See SERVER-6427 for mongos - if not is_mongos(c) and server_started_with_auth(c): - + if not client_context.is_mongos and client_context.auth_enabled: c.drop_database("pymongo_test1") c.admin.add_user("admin", "password") @@ -286,34 +280,31 @@ class TestClient(unittest.TestCase, TestRequestMixin): # Cleanup remove_all_users(c.pymongo_test) c.admin.remove_user("admin") + c.admin.logout() c.disconnect() def test_iteration(self): - client = MongoClient(host, port) - def iterate(): - [a for a in client] + [a for a in self.client] self.assertRaises(TypeError, iterate) def test_disconnect(self): - c = MongoClient(host, port) - coll = c.pymongo_test.bar + coll = self.client.pymongo_test.bar - c.disconnect() - c.disconnect() + self.client.disconnect() + self.client.disconnect() coll.count() - c.disconnect() - c.disconnect() + self.client.disconnect() + self.client.disconnect() coll.count() def test_from_uri(self): - c = MongoClient(host, port) - - self.assertEqual(c, MongoClient("mongodb://%s:%d" % (host, port))) + self.assertEqual(self.client, + MongoClient("mongodb://%s:%d" % (host, port))) def test_get_default_database(self): c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False) @@ -330,18 +321,13 @@ class TestClient(unittest.TestCase, TestRequestMixin): c = MongoClient(uri, _connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) + @client_context.require_auth def test_auth_from_uri(self): - c = MongoClient(host, port) - # Sharded auth not supported before MongoDB 2.0 - if is_mongos(c) and not version.at_least(c, (2, 0, 0)): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - if not server_started_with_auth(c): - raise SkipTest('Authentication is not enabled on server') - - c.admin.add_user("admin", "pass") - c.admin.authenticate("admin", "pass") + self.client.admin.add_user("admin", "pass") + self.client.admin.authenticate("admin", "pass") try: - c.pymongo_test.add_user("user", "pass", roles=['userAdmin', 'readWrite']) + self.client.pymongo_test.add_user( + "user", "pass", roles=['userAdmin', 'readWrite']) self.assertRaises(ConfigurationError, MongoClient, "mongodb://foo:bar@%s:%d" % (host, port)) @@ -375,18 +361,13 @@ class TestClient(unittest.TestCase, TestRequestMixin): finally: # Clean up. - remove_all_users(c.pymongo_test) - remove_all_users(c.admin) + remove_all_users(self.client.pymongo_test) + remove_all_users(self.client.admin) + self.client.admin.logout() + self.client.disconnect() + @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): - # Check if we have the prerequisites to run this test. - c = MongoClient(host, port) - if not server_started_with_auth(c): - raise SkipTest('Authentication is not enabled on server') - - if is_mongos(c) and not version.at_least(c, (2, 0, 0)): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - lazy_client = MongoClient( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), _connect=False) @@ -397,8 +378,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): def test_unix_socket(self): if not hasattr(socket, "AF_UNIX"): raise SkipTest("UNIX-sockets are not supported on this system") - if (sys.platform == 'darwin' and - server_started_with_auth(MongoClient(host, port))): + if sys.platform == 'darwin' and client_context.auth_enabled: raise SkipTest("SERVER-8492") mongodb_socket = '/tmp/mongodb-27017.sock' @@ -428,7 +408,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): except ImportError: raise SkipTest("No multiprocessing module") - db = MongoClient(host, port).pymongo_test + db = self.client.pymongo_test # Failure occurs if the client is used before the fork db.test.find_one() @@ -478,7 +458,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): pass def test_document_class(self): - c = MongoClient(host, port) + c = self.client db = c.pymongo_test db.test.insert({"x": 1}) @@ -488,9 +468,12 @@ class TestClient(unittest.TestCase, TestRequestMixin): c.document_class = SON - self.assertEqual(SON, c.document_class) - self.assertTrue(isinstance(db.test.find_one(), SON)) - self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) + try: + self.assertEqual(SON, c.document_class) + self.assertTrue(isinstance(db.test.find_one(), SON)) + self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) + finally: + c.document_class = dict c = MongoClient(host, port, document_class=SON) db = c.pymongo_test @@ -531,7 +514,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): get_client, socketTimeoutMS='foo') def test_socket_timeout(self): - no_timeout = MongoClient(host, port) + no_timeout = self.client timeout_sec = 1 timeout = MongoClient( host, port, socketTimeoutMS=1000 * timeout_sec) @@ -562,7 +545,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertRaises(ConfigurationError, MongoClient, tz_aware='foo') aware = MongoClient(host, port, tz_aware=True) - naive = MongoClient(host, port) + naive = self.client aware.pymongo_test.drop_collection("test") now = datetime.datetime.utcnow() @@ -595,28 +578,26 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertTrue("pymongo_test" in dbs) self.assertTrue("pymongo_test_bernie" in dbs) + @client_context.require_no_mongos def test_fsync_lock_unlock(self): - c = get_client() - if is_mongos(c): - raise SkipTest('fsync/lock not supported by mongos') - if not version.at_least(c, (2, 0)) and server_started_with_auth(c): + if (not client_context.version.at_least(2, 0) and + client_context.auth_enabled): raise SkipTest('Requires server >= 2.0 to test with auth') - - res = c.admin.command('getCmdLineOpts') - if '--master' in res['argv'] and version.at_least(c, (2, 3, 0)): + if (server_is_master_with_slave(client_context.client) and + client_context.version.at_least(2, 3, 0)): raise SkipTest('SERVER-7714') - self.assertFalse(c.is_locked) + self.assertFalse(self.client.is_locked) # async flushing not supported on windows... if sys.platform not in ('cygwin', 'win32'): - c.fsync(async=True) - self.assertFalse(c.is_locked) - c.fsync(lock=True) - self.assertTrue(c.is_locked) + self.client.fsync(async=True) + self.assertFalse(self.client.is_locked) + self.client.fsync(lock=True) + self.assertTrue(self.client.is_locked) locked = True - c.unlock() + self.client.unlock() for _ in range(5): - locked = c.is_locked + locked = self.client.is_locked if not locked: break time.sleep(1) @@ -633,26 +614,23 @@ class TestClient(unittest.TestCase, TestRequestMixin): # pool self.assertEqual(1, len(get_pool(client).sockets)) - # We need exec here because if the Python version is less than 2.6 - # these with-statements won't even compile. with contextlib.closing(client): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) self.assertEqual(None, client._MongoClient__member) - with get_client() as client: + with self.client as client: self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) self.assertEqual(None, client._MongoClient__member) def test_with_start_request(self): - client = get_client() - pool = get_pool(client) + pool = get_pool(self.client) # No request started self.assertNoRequest(pool) self.assertDifferentSock(pool) # Start a request - request_context_mgr = client.start_request() + request_context_mgr = self.client.start_request() self.assertTrue( isinstance(request_context_mgr, object) ) @@ -667,8 +645,8 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertDifferentSock(pool) # Test the 'with' statement - with client.start_request() as request: - self.assertEqual(client, request.connection) + with self.client.start_request() as request: + self.assertEqual(self.client, request.connection) self.assertNoSocketYet(pool) self.assertSameSock(pool) self.assertRequestSocket(pool) @@ -685,8 +663,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): ) # auto_start_request should default to False - client = get_client() - self.assertFalse(client.auto_start_request) + self.assertFalse(self.client.auto_start_request) client = get_client(auto_start_request=True) self.assertTrue(client.auto_start_request) @@ -709,35 +686,34 @@ class TestClient(unittest.TestCase, TestRequestMixin): def test_nested_request(self): # auto_start_request is False - client = get_client() - pool = get_pool(client) - self.assertFalse(client.in_request()) + pool = get_pool(self.client) + self.assertFalse(self.client.in_request()) # Start and end request - client.start_request() - self.assertInRequestAndSameSock(client, pool) - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pool) + self.client.start_request() + self.assertInRequestAndSameSock(self.client, pool) + self.client.end_request() + self.assertNotInRequestAndDifferentSock(self.client, pool) # Double-nesting - client.start_request() - client.start_request() - client.end_request() - self.assertInRequestAndSameSock(client, pool) - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pool) + self.client.start_request() + self.client.start_request() + self.client.end_request() + self.assertInRequestAndSameSock(self.client, pool) + self.client.end_request() + self.assertNotInRequestAndDifferentSock(self.client, pool) # Extra end_request calls have no effect - count stays at zero - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pool) + self.client.end_request() + self.assertNotInRequestAndDifferentSock(self.client, pool) - client.start_request() - self.assertInRequestAndSameSock(client, pool) - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pool) + self.client.start_request() + self.assertInRequestAndSameSock(self.client, pool) + self.client.end_request() + self.assertNotInRequestAndDifferentSock(self.client, pool) def test_request_threads(self): - client = get_client(auto_start_request=False) + client = self.client pool = get_pool(client) self.assertNotInRequestAndDifferentSock(client, pool) @@ -787,8 +763,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): # Test fix for PYTHON-294 -- make sure MongoClient closes its # socket if it gets an interrupt while waiting to recv() from it. - c = get_client() - db = c.pymongo_test + db = self.client.pymongo_test # A $where clause which takes 1.5 sec to execute where = delay(1.5) @@ -826,17 +801,17 @@ class TestClient(unittest.TestCase, TestRequestMixin): def test_operation_failure_without_request(self): # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. - c = get_client() - pool = get_pool(c) - self.assertEqual(1, len(pool.sockets)) + pool = get_pool(self.client) + socket_count = len(pool.sockets) + self.assertGreaterEqual(socket_count, 1) old_sock_info = next(iter(pool.sockets)) - c.pymongo_test.test.drop() - c.pymongo_test.test.insert({'_id': 'foo'}) + self.client.pymongo_test.test.drop() + self.client.pymongo_test.test.insert({'_id': 'foo'}) self.assertRaises( OperationFailure, - c.pymongo_test.test.insert, {'_id': 'foo'}) + self.client.pymongo_test.test.insert, {'_id': 'foo'}) - self.assertEqual(1, len(pool.sockets)) + self.assertEqual(socket_count, len(pool.sockets)) new_sock_info = next(iter(pool.sockets)) self.assertEqual(old_sock_info, new_sock_info) @@ -861,7 +836,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertEqual(old_sock_info, pool._get_request_state()) def test_alive(self): - self.assertTrue(get_client().alive()) + self.assertTrue(self.client.alive()) client = MongoClient('doesnt exist', _connect=False) self.assertFalse(client.alive()) @@ -938,12 +913,9 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertEqual(expected_min, c.min_wire_version) self.assertEqual(expected_max, c.max_wire_version) + @client_context.require_replica_set def test_replica_set(self): - client = MongoClient(host, port) - name = client.pymongo_test.command('ismaster').get('setName') - if not name: - raise SkipTest('Not connected to a replica set') - + name = client_context.setname MongoClient(host, port, replicaSet=name) # No error. self.assertRaises( diff --git a/test/test_collection.py b/test/test_collection.py index e864b33b0..722083ec2 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -38,7 +38,7 @@ from pymongo import (ASCENDING, DESCENDING, GEO2D, from pymongo import message as message_module from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor -from pymongo.mongo_replica_set_client import MongoReplicaSetClient +from pymongo.database import Database from pymongo.read_preferences import ReadPreference from pymongo.son_manipulator import SONManipulator from pymongo.errors import (DocumentTooLarge, @@ -51,22 +51,22 @@ from pymongo.errors import (DocumentTooLarge, from test.test_client import get_client from test.utils import (is_mongos, joinall, enable_text_search, get_pool, oid_generated_on_client) -from test import qcheck, SkipTest, unittest, version +from test import (client_context, + qcheck, + SkipTest, + unittest, + version) class TestCollection(unittest.TestCase): - def setUp(self): - self.client = get_client() - self.db = self.client.pymongo_test - ismaster = self.db.command('ismaster') - self.setname = ismaster.get('setName') - self.w = len(ismaster.get('hosts', [])) or 1 + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + cls.w = client_context.w def tearDown(self): self.db.drop_collection("test_large_limit") - self.db = None - self.client = None def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -357,9 +357,8 @@ class TestCollection(unittest.TestCase): index_info = db.test.index_information()['loc_2d'] self.assertEqual([('loc', '2d')], index_info['key']) + @client_context.require_no_mongos def test_index_haystack(self): - if is_mongos(self.db.connection): - raise SkipTest("geoSearch is not supported by mongos") db = self.db db.test.drop_indexes() db.test.remove() @@ -393,14 +392,10 @@ class TestCollection(unittest.TestCase): "type": "restaurant" }, results[0]) + @client_context.require_version_min(2, 3, 2) + @client_context.require_no_mongos def test_index_text(self): - if not version.at_least(self.client, (2, 3, 2)): - raise SkipTest("Text search requires server >=2.3.2.") - - if is_mongos(self.client): - raise SkipTest("setParameter does not work through mongos") - - enable_text_search(self.client) + enable_text_search(client_context.client) db = self.db db.test.drop_indexes() @@ -408,7 +403,7 @@ class TestCollection(unittest.TestCase): index_info = db.test.index_information()["t_text"] self.assertTrue("weights" in index_info) - if version.at_least(self.client, (2, 5, 5)): + if client_context.version.at_least(2, 5, 5): db.test.insert([ {'t': 'spam eggs and spam'}, {'t': 'spam'}, @@ -426,10 +421,8 @@ class TestCollection(unittest.TestCase): db.test.drop_indexes() + @client_context.require_version_min(2, 3, 2) def test_index_2dsphere(self): - if not version.at_least(self.client, (2, 3, 2)): - raise SkipTest("2dsphere indexing requires server >=2.3.2.") - db = self.db db.test.drop_indexes() self.assertEqual("geo_2dsphere", @@ -444,10 +437,8 @@ class TestCollection(unittest.TestCase): db.test.drop_indexes() + @client_context.require_version_min(2, 3, 2) def test_index_hashed(self): - if not version.at_least(self.client, (2, 3, 2)): - raise SkipTest("hashed indexing requires server >=2.3.2.") - db = self.db db.test.drop_indexes() self.assertEqual("a_hashed", @@ -485,7 +476,7 @@ class TestCollection(unittest.TestCase): db = self.db self._drop_dups_setup(db) - if version.at_least(db.connection, (1, 9, 2)): + if client_context.version.at_least(1, 9, 2): # No error, just drop the duplicate db.test.create_index( [('i', ASCENDING)], @@ -589,14 +580,14 @@ class TestCollection(unittest.TestCase): db.drop_collection("test") db.test.save({}) expected = {} - if version.at_least(db.connection, (2, 7, 0)): + if client_context.version.at_least(2, 7, 0): # usePowerOf2Sizes server default expected["flags"] = 1 self.assertEqual(db.test.options(), expected) self.assertEqual(db.test.doesnotexist.options(), {}) db.drop_collection("test") - if version.at_least(db.connection, (1, 9)): + if client_context.version.at_least(1, 9): db.create_collection("test", capped=True, size=4096) result = db.test.options() # mongos 2.2.x adds an $auth field when auth is enabled. @@ -703,7 +694,7 @@ class TestCollection(unittest.TestCase): db.test.insert({"x": [1, 2, 3], "mike": "awesome"}) self.assertEqual([1, 2, 3], db.test.find_one()["x"]) - if version.at_least(db.connection, (1, 5, 1)): + if client_context.version.at_least(1, 5, 1): self.assertEqual([2, 3], db.test.find_one(fields={"x": {"$slice": -2}})["x"]) @@ -827,24 +818,28 @@ class TestCollection(unittest.TestCase): ) db.drop_collection("test") + wc = db.write_concern db.write_concern = {"w": 0} - db.test.ensure_index([('i', ASCENDING)], unique=True) + try: + db.test.ensure_index([('i', ASCENDING)], unique=True) - # No error - db.test.insert([{'i': 1}] * 2) - self.assertEqual(1, db.test.count()) + # No error + db.test.insert([{'i': 1}] * 2) + self.assertEqual(1, db.test.count()) - # Implied safe - self.assertRaises( - DuplicateKeyError, - lambda: db.test.insert([{'i': 2}] * 2, fsync=True), - ) + # Implied safe + self.assertRaises( + DuplicateKeyError, + lambda: db.test.insert([{'i': 2}] * 2, fsync=True), + ) - # Explicit safe - self.assertRaises( - DuplicateKeyError, - lambda: db.test.insert([{'i': 2}] * 2, w=1), - ) + # Explicit safe + self.assertRaises( + DuplicateKeyError, + lambda: db.test.insert([{'i': 2}] * 2, w=1), + ) + finally: + db.write_concern = wc def test_insert_iterables(self): db = self.db @@ -864,14 +859,12 @@ class TestCollection(unittest.TestCase): itertools.repeat(None, 10))) self.assertEqual(db.test.find().count(), 10) + @client_context.require_version_min(2, 0) def test_insert_manipulate_false(self): # Test three aspects of insert with manipulate=False: # 1. The return value is None or [None] as appropriate. # 2. _id is not set on the passed-in document object. # 3. _id is not sent to server. - if not version.at_least(self.db.connection, (2, 0)): - raise SkipTest('Need at least MongoDB 2.0') - collection = self.db.test_insert_manipulate_false collection.drop() oid = ObjectId() @@ -934,28 +927,29 @@ class TestCollection(unittest.TestCase): doc = self.db.test.find_one() doc['a.b'] = 'c' expected = InvalidDocument - if version.at_least(self.client, (2, 5, 4, -1)): + if client_context.version.at_least(2, 5, 4, -1): expected = OperationFailure self.assertRaises(expected, self.db.test.save, doc) def test_unique_index(self): db = self.db - db.drop_collection("test") - db.test.create_index("hello") + with client_context.client.start_request(): + db.drop_collection("test") + db.test.create_index("hello") - db.test.save({"hello": "world"}) - db.test.save({"hello": "mike"}) - db.test.save({"hello": "world"}) - self.assertFalse(db.error()) + db.test.save({"hello": "world"}) + db.test.save({"hello": "mike"}) + db.test.save({"hello": "world"}) + self.assertFalse(db.error()) - db.drop_collection("test") - db.test.create_index("hello", unique=True) + db.drop_collection("test") + db.test.create_index("hello", unique=True) - db.test.save({"hello": "world"}) - db.test.save({"hello": "mike"}) - db.test.save({"hello": "world"}, w=0) - self.assertTrue(db.error()) + db.test.save({"hello": "world"}) + db.test.save({"hello": "mike"}) + db.test.save({"hello": "world"}, w=0) + self.assertTrue(db.error()) def test_duplicate_key_error(self): db = self.db @@ -976,7 +970,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(2, db.test.count()) expected_error = OperationFailure - if version.at_least(db.connection, (1, 3)): + if client_context.version.at_least(1, 3): expected_error = DuplicateKeyError self.assertRaises(expected_error, @@ -1012,11 +1006,9 @@ class TestCollection(unittest.TestCase): collection.write_concern = {'wtimeout': 1000} self.assertRaises(DuplicateKeyError, collection.insert, {'_id': 1}) + @client_context.require_version_min(1, 9, 1) def test_continue_on_error(self): db = self.db - if not version.at_least(db.connection, (1, 9, 1)): - raise SkipTest("continue_on_error requires MongoDB >= 1.9.1") - db.drop_collection("test") oid = db.test.insert({"one": 1}) self.assertEqual(1, db.test.count()) @@ -1027,33 +1019,34 @@ class TestCollection(unittest.TestCase): docs.append({"four": 4}) docs.append({"five": 5}) - db.test.insert(docs, manipulate=False, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(1, db.test.count()) + with client_context.client.start_request(): + db.test.insert(docs, manipulate=False, w=0) + self.assertEqual(11000, db.error()['code']) + self.assertEqual(1, db.test.count()) - db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(4, db.test.count()) + db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) + self.assertEqual(11000, db.error()['code']) + self.assertEqual(4, db.test.count()) - db.drop_collection("test") - oid = db.test.insert({"_id": oid, "one": 1}, w=0) - self.assertEqual(1, db.test.count()) - docs[0].pop("_id") - docs[2]["_id"] = oid + db.drop_collection("test") + oid = db.test.insert({"_id": oid, "one": 1}, w=0) + self.assertEqual(1, db.test.count()) + docs[0].pop("_id") + docs[2]["_id"] = oid - db.test.insert(docs, manipulate=False, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(3, db.test.count()) + db.test.insert(docs, manipulate=False, w=0) + self.assertEqual(11000, db.error()['code']) + self.assertEqual(3, db.test.count()) - db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) - self.assertEqual(11000, db.error()['code']) - self.assertEqual(6, db.test.count()) + db.test.insert(docs, manipulate=False, continue_on_error=True, w=0) + self.assertEqual(11000, db.error()['code']) + self.assertEqual(6, db.test.count()) def test_error_code(self): try: self.db.test.update({}, {"$thismodifierdoesntexist": 1}) except OperationFailure as exc: - if version.at_least(self.db.connection, (1, 3)): + if client_context.version.at_least(1, 3): self.assertTrue(exc.code in (9, 10147, 16840, 17009)) # Just check that we set the error document. Fields # vary by MongoDB version. @@ -1078,15 +1071,16 @@ class TestCollection(unittest.TestCase): db.test.insert, {"hello": {"a": 4, "b": 10}}) def test_safe_insert(self): - db = self.db - db.drop_collection("test") + with client_context.client.start_request(): + db = self.db + db.drop_collection("test") - a = {"hello": "world"} - db.test.insert(a) - db.test.insert(a, w=0) - self.assertTrue("E11000" in db.error()["err"]) + a = {"hello": "world"} + db.test.insert(a) + db.test.insert(a, w=0) + self.assertTrue("E11000" in db.error()["err"]) - self.assertRaises(OperationFailure, db.test.insert, a) + self.assertRaises(OperationFailure, db.test.insert, a) def test_update(self): db = self.db @@ -1129,7 +1123,7 @@ class TestCollection(unittest.TestCase): def test_update_nmodified(self): db = self.db db.drop_collection("test") - used_write_commands = (self.client.max_wire_version > 1) + used_write_commands = (client_context.client.max_wire_version > 1) db.test.insert({'_id': 1}) result = db.test.update({'_id': 1}, {'$set': {'x': 1}}) @@ -1145,11 +1139,9 @@ class TestCollection(unittest.TestCase): else: self.assertFalse('nModified' in result) + @client_context.require_version_min(1, 1, 3, -1) def test_multi_update(self): db = self.db - if not version.at_least(db.connection, (1, 1, 3, -1)): - raise SkipTest("multi-update requires MongoDB >= 1.1.3") - db.drop_collection("test") db.test.save({"x": 4, "y": 3}) @@ -1176,34 +1168,35 @@ class TestCollection(unittest.TestCase): self.assertEqual(2, db.test.find_one()["count"]) def test_safe_update(self): - db = self.db - v113minus = version.at_least(db.connection, (1, 1, 3, -1)) - v19 = version.at_least(db.connection, (1, 9)) + with client_context.client.start_request(): + db = self.db + v113minus = client_context.version.at_least(1, 1, 3, -1) + v19 = client_context.version.at_least(1, 9) - db.drop_collection("test") - db.test.create_index("x", unique=True) + db.drop_collection("test") + db.test.create_index("x", unique=True) - db.test.insert({"x": 5}) - id = db.test.insert({"x": 4}) + db.test.insert({"x": 5}) + id = db.test.insert({"x": 4}) - self.assertEqual( - None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0)) + self.assertEqual( + None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0)) - if v19: - self.assertTrue("E11000" in db.error()["err"]) - elif v113minus: - self.assertTrue(db.error()["err"].startswith("E11001")) - else: - self.assertTrue(db.error()["err"].startswith("E12011")) + if v19: + self.assertTrue("E11000" in db.error()["err"]) + elif v113minus: + self.assertTrue(db.error()["err"].startswith("E11001")) + else: + self.assertTrue(db.error()["err"].startswith("E12011")) - self.assertRaises(OperationFailure, db.test.update, - {"_id": id}, {"$inc": {"x": 1}}) + self.assertRaises(OperationFailure, db.test.update, + {"_id": id}, {"$inc": {"x": 1}}) - self.assertEqual(1, db.test.update({"_id": id}, - {"$inc": {"x": 2}})["n"]) + self.assertEqual(1, db.test.update({"_id": id}, + {"$inc": {"x": 2}})["n"]) - self.assertEqual(0, db.test.update({"_id": "foo"}, - {"$inc": {"x": 2}})["n"]) + self.assertEqual(0, db.test.update({"_id": "foo"}, + {"$inc": {"x": 2}})["n"]) def test_update_with_invalid_keys(self): self.db.drop_collection("test") @@ -1212,7 +1205,7 @@ class TestCollection(unittest.TestCase): doc['a.b'] = 'c' expected = InvalidDocument - if version.at_least(self.client, (2, 5, 4, -1)): + if client_context.version.at_least(2, 5, 4, -1): expected = OperationFailure # Replace @@ -1249,16 +1242,17 @@ class TestCollection(unittest.TestCase): {})['n']) def test_safe_save(self): - db = self.db - db.drop_collection("test") - db.test.create_index("hello", unique=True) + with client_context.client.start_request(): + db = self.db + db.drop_collection("test") + db.test.create_index("hello", unique=True) - db.test.save({"hello": "world"}) - db.test.save({"hello": "world"}, w=0) - self.assertTrue("E11000" in db.error()["err"]) + db.test.save({"hello": "world"}) + db.test.save({"hello": "world"}, w=0) + self.assertTrue("E11000" in db.error()["err"]) - self.assertRaises(OperationFailure, db.test.save, - {"hello": "world"}) + self.assertRaises(OperationFailure, db.test.save, + {"hello": "world"}) def test_safe_remove(self): db = self.db @@ -1271,7 +1265,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(None, db.test.remove({"x": 1}, w=0)) self.assertEqual(1, db.test.count()) - if version.at_least(db.connection, (1, 1, 3, -1)): + if client_context.version.at_least(1, 1, 3, -1): self.assertRaises(OperationFailure, db.test.remove, {"x": 1}) else: # Just test that it doesn't blow up @@ -1283,18 +1277,16 @@ class TestCollection(unittest.TestCase): self.assertEqual(2, db.test.remove({})["n"]) self.assertEqual(0, db.test.remove({})["n"]) + @client_context.require_version_min(1, 5, 1) def test_last_error_options(self): - if not version.at_least(self.client, (1, 5, 1)): - raise SkipTest("getLastError options require MongoDB >= 1.5.1") - self.db.test.save({"x": 1}, w=1, wtimeout=1) self.db.test.insert({"x": 1}, w=1, wtimeout=1) self.db.test.remove({"x": 1}, w=1, wtimeout=1) self.db.test.update({"x": 1}, {"y": 2}, w=1, wtimeout=1) - ismaster = self.client.admin.command("ismaster") - if ismaster.get("setName"): - w = len(ismaster["hosts"]) + 1 + if client_context.setname: + # client_context.w is the number of hosts in the replica set + w = client_context.w + 1 self.assertRaises(WTimeoutError, self.db.test.save, {"x": 1}, w=w, wtimeout=1) self.assertRaises(WTimeoutError, self.db.test.insert, @@ -1314,7 +1306,7 @@ class TestCollection(unittest.TestCase): self.fail("WTimeoutError was not raised") # can't use fsync and j options together - if version.at_least(self.client, (1, 8, 2)): + if client_context.version.at_least(1, 8, 2): self.assertRaises(OperationFailure, self.db.test.insert, {"_id": 1}, j=True, fsync=True) @@ -1335,9 +1327,8 @@ class TestCollection(unittest.TestCase): self.assertEqual(db.test.find({'foo': 'bar'}).count(), 1) self.assertEqual(db.test.find({'foo': re.compile(r'ba.*')}).count(), 2) + @client_context.require_version_min(2, 1, 0) def test_aggregate(self): - if not version.at_least(self.db.connection, (2, 1, 0)): - raise SkipTest("The aggregate command requires MongoDB >= 2.1.0") db = self.db db.drop_collection("test") db.test.save({'foo': [1, 2]}) @@ -1353,81 +1344,74 @@ class TestCollection(unittest.TestCase): self.assertEqual(1.0, result['ok']) self.assertEqual([{'foo': [1, 2]}], result['result']) + @client_context.require_version_min(2, 3, 2) # See SERVER-6470. def test_aggregate_with_compile_re(self): - # See SERVER-6470. - if not version.at_least(self.db.connection, (2, 3, 2)): - raise SkipTest( - "Retrieving a regex with aggregation requires " - "MongoDB >= 2.3.2") + self.db.test.drop() + self.db.test.insert({'r': re.compile('.*')}) - db = self.client.pymongo_test - db.test.drop() - db.test.insert({'r': re.compile('.*')}) - - result = db.test.aggregate([]) + result = self.db.test.aggregate([]) self.assertTrue(isinstance(result['result'][0]['r'], RE_TYPE)) - result = db.test.aggregate([], compile_re=False) + result = self.db.test.aggregate([], compile_re=False) self.assertTrue(isinstance(result['result'][0]['r'], Regex)) + @client_context.require_version_min(2, 5, 1) def test_aggregation_cursor_validation(self): - if not version.at_least(self.db.connection, (2, 5, 1)): - raise SkipTest("Aggregation cursor requires MongoDB >= 2.5.1") db = self.db projection = {'$project': {'_id': '$_id'}} cursor = db.test.aggregate(projection, cursor={}) self.assertTrue(isinstance(cursor, CommandCursor)) + @client_context.require_version_min(2, 5, 1) def test_aggregation_cursor(self): - if not version.at_least(self.db.connection, (2, 5, 1)): - raise SkipTest("Aggregation cursor requires MongoDB >= 2.5.1") db = self.db - if self.setname: - db = MongoReplicaSetClient(host=self.client.host, - port=self.client.port, - replicaSet=self.setname)[db.name] + if client_context.setname: + db = client_context.rs_client[db.name] # Test that getMore messages are sent to the right server. db.read_preference = ReadPreference.SECONDARY - for collection_size in (10, 1000): - db.drop_collection("test") - db.test.insert([{'_id': i} for i in range(collection_size)], - w=self.w) - expected_sum = sum(range(collection_size)) - # Use batchSize to ensure multiple getMore messages - cursor = db.test.aggregate( - {'$project': {'_id': '$_id'}}, - cursor={'batchSize': 5}) + try: + for collection_size in (10, 1000): + db.drop_collection("test") + db.test.insert([{'_id': i} for i in range(collection_size)], + w=self.w) + expected_sum = sum(range(collection_size)) + # Use batchSize to ensure multiple getMore messages + cursor = db.test.aggregate( + {'$project': {'_id': '$_id'}}, + cursor={'batchSize': 5}) - self.assertEqual( - expected_sum, - sum(doc['_id'] for doc in cursor)) + self.assertEqual( + expected_sum, + sum(doc['_id'] for doc in cursor)) + finally: + db.read_preference = ReadPreference.PRIMARY + @client_context.require_version_min(2, 5, 5) + @client_context.require_no_mongos def test_parallel_scan(self): - if is_mongos(self.db.connection): - raise SkipTest("mongos does not support parallel_scan") - if not version.at_least(self.db.connection, (2, 5, 5)): - raise SkipTest("Requires MongoDB >= 2.5.5") db = self.db db.drop_collection("test") - if self.setname: - db = MongoReplicaSetClient(host=self.client.host, - port=self.client.port, - replicaSet=self.setname)[db.name] + if client_context.setname: + db = client_context.rs_client[db.name] # Test that getMore messages are sent to the right server. db.read_preference = ReadPreference.SECONDARY - coll = db.test - coll.insert(({'_id': i} for i in range(8000)), w=self.w) - docs = [] - threads = [threading.Thread(target=docs.extend, args=(cursor,)) - for cursor in coll.parallel_scan(3)] - for t in threads: - t.start() - for t in threads: - t.join() - self.assertEqual( - set(range(8000)), - set(doc['_id'] for doc in docs)) + try: + coll = db.test + coll.insert(({'_id': i} for i in range(8000)), w=self.w) + docs = [] + threads = [threading.Thread(target=docs.extend, args=(cursor,)) + for cursor in coll.parallel_scan(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual( + set(range(8000)), + set(doc['_id'] for doc in docs)) + finally: + db.read_preference = ReadPreference.PRIMARY def test_group(self): db = self.db @@ -1515,7 +1499,7 @@ class TestCollection(unittest.TestCase): Code(reduce_function, {"inc_value": 0.5}))[0]['count']) - if version.at_least(db.connection, (1, 1)): + if client_context.version.at_least(1, 1): self.assertEqual(2, db.test.group([], {}, {"count": 0}, Code(reduce_function, {"inc_value": 1}), @@ -1754,10 +1738,8 @@ class TestCollection(unittest.TestCase): # The socket should be discarded. self.assertEqual(0, len(socks)) + @client_context.require_version_min(1, 1) def test_distinct(self): - if not version.at_least(self.db.connection, (1, 1)): - raise SkipTest("distinct command requires MongoDB >= 1.1") - self.db.drop_collection("test") test = self.db.test @@ -1814,11 +1796,11 @@ class TestCollection(unittest.TestCase): def test_insert_large_document(self): max_size = self.db.connection.max_bson_size half_size = int(max_size / 2) - if version.at_least(self.db.connection, (1, 7, 4)): + if client_context.version.at_least(1, 7, 4): self.assertEqual(max_size, 16777216) expected = DocumentTooLarge - if version.at_least(self.client, (2, 5, 4, -1)): + if client_context.version.at_least(2, 5, 4, -1): # Document too large handled by the server expected = OperationFailure self.assertRaises(expected, self.db.test.insert, @@ -1838,8 +1820,8 @@ class TestCollection(unittest.TestCase): self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 32)}) def test_insert_large_batch(self): - max_bson_size = self.client.max_bson_size - if version.at_least(self.client, (2, 5, 4, -1)): + max_bson_size = client_context.client.max_bson_size + if client_context.version.at_least(2, 5, 4, -1): # Write commands are limited to 16MB + 16k per batch big_string = 'x' * int(max_bson_size / 2) else: @@ -1862,12 +1844,9 @@ class TestCollection(unittest.TestCase): # Test that inserts fail after first error, unacknowledged. self.db.test.drop() - self.client.start_request() - try: + with client_context.client.start_request(): self.assertTrue(self.db.test.insert(batch, w=0)) self.assertEqual(1, self.db.test.count()) - finally: - self.client.end_request() # 2 batches, 2 errors, acknowledged, continue on error self.db.test.drop() @@ -1884,13 +1863,11 @@ class TestCollection(unittest.TestCase): # 2 batches, 2 errors, unacknowledged, continue on error self.db.test.drop() - self.client.start_request() - try: - self.assertTrue(self.db.test.insert(batch, continue_on_error=True, w=0)) + with client_context.client.start_request(): + self.assertTrue( + self.db.test.insert(batch, continue_on_error=True, w=0)) # Only the first and third documents should be inserted. self.assertEqual(2, self.db.test.count()) - finally: - self.client.end_request() def test_numerous_inserts(self): # Ensure we don't exceed server's 1000-document batch size limit. @@ -1935,10 +1912,8 @@ class TestCollection(unittest.TestCase): coll.uuid_subtype = 6 self.assertEqual(doc, coll.find_one({'_id': 2})) + @client_context.require_version_min(1, 1, 1) def test_map_reduce(self): - if not version.at_least(self.db.connection, (1, 1, 1)): - raise SkipTest("mapReduce command requires MongoDB >= 1.1.1") - db = self.db db.drop_collection("test") @@ -1964,7 +1939,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) - if version.at_least(self.db.connection, (1, 7, 4)): + if client_context.version.at_least(1, 7, 4): db.test.insert({"id": 5, "tags": ["hampster"]}) result = db.test.map_reduce(map, reduce, out='mrunittests') self.assertEqual(1, result.find_one({"_id": "hampster"})["value"]) @@ -1992,8 +1967,8 @@ class TestCollection(unittest.TestCase): self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) - if (is_mongos(self.db.connection) - and not version.at_least(self.db.connection, (2, 1, 2))): + if (client_context.is_mongos and + not client_context.version.at_least(2, 1, 2)): pass else: result = db.test.map_reduce(map, reduce, @@ -2003,7 +1978,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) - self.client.drop_database('mrtestdb') + client_context.client.drop_database('mrtestdb') full_result = db.test.map_reduce(map, reduce, out='mrunittests', full_response=True) @@ -2014,7 +1989,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(1, result.find_one({"_id": "dog"})["value"]) self.assertEqual(None, result.find_one({"_id": "mouse"})) - if version.at_least(self.db.connection, (1, 7, 4)): + if client_context.version.at_least(1, 7, 4): result = db.test.map_reduce(map, reduce, out={'inline': 1}) self.assertTrue(isinstance(result, dict)) self.assertTrue('results' in result) @@ -2077,7 +2052,7 @@ class TestCollection(unittest.TestCase): # Starting with MongoDB 2.5.2 this is no longer possible # from insert, update, or findAndModify. - if not version.at_least(self.db.connection, (2, 5, 2)): + if not client_context.version.at_least(2, 5, 2): # Force insert of ref without $id. c.insert(ref_only, check_keys=False) self.assertEqual(DBRef('collection', id=None), @@ -2117,7 +2092,7 @@ class TestCollection(unittest.TestCase): # Test that we raise DuplicateKeyError when appropriate. # MongoDB doesn't have a code field for DuplicateKeyError # from commands before 2.2. - if version.at_least(self.db.connection, (2, 2)): + if client_context.version.at_least(2, 2): c.ensure_index('i', unique=True) self.assertRaises(DuplicateKeyError, c.find_and_modify, query={'i': 1, 'j': 1}, @@ -2139,7 +2114,7 @@ class TestCollection(unittest.TestCase): self.assertEqual(None, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}})) # The return value changed in 2.1.2. See SERVER-6226. - if version.at_least(self.db.connection, (2, 1, 2)): + if client_context.version.at_least(2, 1, 2): self.assertEqual(None, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, upsert=True)) @@ -2160,8 +2135,8 @@ class TestCollection(unittest.TestCase): # Test with full_response=True # No lastErrorObject from mongos until 2.0 - if (not is_mongos(self.db.connection) and - version.at_least(self.db.connection, (2, 0))): + if (not client_context.is_mongos and + client_context.version.at_least(2, 0)): result = c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, upsert=True, full_response=True, @@ -2251,9 +2226,8 @@ class TestCollection(unittest.TestCase): self.assertRaises(TypeError, c.find_and_modify, {}, {'$inc': {'i': 1}}, sort=sort) + @client_context.require_version_min(2, 0, 0) def test_find_with_nested(self): - if not version.at_least(self.db.connection, (2, 0, 0)): - raise SkipTest("nested $and and $or requires MongoDB >= 2.0") c = self.db.test c.drop() c.insert([{'i': i} for i in range(5)]) # [0, 1, 2, 3, 4] @@ -2309,7 +2283,7 @@ class TestCollection(unittest.TestCase): son['foo'] += 2 return son - db = self.client.pymongo_test + db = client_context.client.pymongo_test db.add_son_manipulator(IncByTwo()) c = db.test c.drop() @@ -2321,7 +2295,7 @@ class TestCollection(unittest.TestCase): c.remove({}) def test_compile_re(self): - c = self.client.pymongo_test.test + c = self.db.test c.drop() c.insert({'r': re.compile('.*')}) diff --git a/test/test_common.py b/test/test_common.py index 57f280de1..d4b55e65b 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -26,17 +26,14 @@ from bson.son import SON from pymongo.mongo_client import MongoClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.errors import ConfigurationError, OperationFailure -from test import pair, unittest, SkipTest, version -from test.utils import drop_collections +from test import client_context, pair, unittest class TestCommon(unittest.TestCase): def test_uuid_subtype(self): - - self.client = MongoClient(pair) - self.db = self.client.pymongo_test - coll = self.client.pymongo_test.uuid + self.db = client_context.client.pymongo_test + coll = self.db.uuid coll.drop() def change_subtype(collection, subtype): @@ -102,21 +99,20 @@ class TestCommon(unittest.TestCase): self.assertEqual(5, coll.find_one({'_id': uu})['i']) # Test command - db = self.client.pymongo_test no_obj_error = "No matching object found" - result = db.command('findAndModify', 'uuid', - allowable_errors=[no_obj_error], - uuid_subtype=UUID_SUBTYPE, - query={'_id': uu}, - update={'$set': {'i': 6}}) + result = self.db.command('findAndModify', 'uuid', + allowable_errors=[no_obj_error], + uuid_subtype=UUID_SUBTYPE, + query={'_id': uu}, + update={'$set': {'i': 6}}) self.assertEqual(None, result.get('value')) - self.assertEqual(5, db.command('findAndModify', 'uuid', - update={'$set': {'i': 6}}, - query={'_id': uu})['value']['i']) - self.assertEqual(6, db.command('findAndModify', 'uuid', - update={'$set': {'i': 7}}, - query={'_id': UUIDLegacy(uu)} - )['value']['i']) + self.assertEqual(5, self.db.command('findAndModify', 'uuid', + update={'$set': {'i': 6}}, + query={'_id': uu})['value']['i']) + self.assertEqual(6, self.db.command( + 'findAndModify', 'uuid', + update={'$set': {'i': 7}}, + query={'_id': UUIDLegacy(uu)})['value']['i']) # Test (inline)_map_reduce coll.drop() @@ -140,23 +136,23 @@ class TestCommon(unittest.TestCase): coll.uuid_subtype = UUID_SUBTYPE q = {"_id": uu} - if version.at_least(self.db.connection, (1, 7, 4)): + if client_context.version.at_least(1, 7, 4): result = coll.inline_map_reduce(map, reduce, query=q) self.assertEqual([], result) result = coll.map_reduce(map, reduce, "results", query=q) - self.assertEqual(0, db.results.count()) + self.assertEqual(0, self.db.results.count()) coll.uuid_subtype = OLD_UUID_SUBTYPE q = {"_id": uu} - if version.at_least(self.db.connection, (1, 7, 4)): + if client_context.version.at_least(1, 7, 4): result = coll.inline_map_reduce(map, reduce, query=q) self.assertEqual(2, len(result)) result = coll.map_reduce(map, reduce, "results", query=q) - self.assertEqual(2, db.results.count()) + self.assertEqual(2, self.db.results.count()) - db.drop_collection("result") + self.db.drop_collection("result") coll.drop() # Test group @@ -233,13 +229,9 @@ class TestCommon(unittest.TestCase): self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,))) self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,))) + @client_context.require_replica_set def test_mongo_replica_set_client(self): - c = MongoClient(pair) - ismaster = c.admin.command('ismaster') - if 'setName' in ismaster: - setname = str(ismaster.get('setName')) - else: - raise SkipTest("Not connected to a replica set.") + setname = client_context.setname m = MongoReplicaSetClient(pair, replicaSet=setname, w=0) coll = m.pymongo_test.write_concern_test coll.drop() @@ -249,7 +241,7 @@ class TestCommon(unittest.TestCase): self.assertTrue(coll.insert(doc)) self.assertRaises(OperationFailure, coll.insert, doc, w=1) - m = MongoReplicaSetClient(pair, replicaSet=setname) + m = client_context.rs_client coll = m.pymongo_test.write_concern_test self.assertTrue(coll.insert(doc, w=0)) self.assertRaises(OperationFailure, coll.insert, doc) diff --git a/test/test_cursor.py b/test/test_cursor.py index 9737817b4..5d210ed54 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -31,12 +31,10 @@ from pymongo import (ASCENDING, OFF) from pymongo.command_cursor import CommandCursor from pymongo.cursor_manager import CursorManager -from pymongo.database import Database from pymongo.errors import (InvalidOperation, OperationFailure, ExecutionTimeout) -from test import SkipTest, unittest, version -from test.test_client import get_client +from test import client_context, SkipTest, unittest from test.utils import is_mongos, get_command_line, server_started_with_auth if PY3: @@ -45,17 +43,12 @@ if PY3: class TestCursor(unittest.TestCase): - def setUp(self): - self.client = get_client() - self.db = Database(self.client, "pymongo_test") - - def tearDown(self): - self.db = None + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + @client_context.require_version_min(2, 5, 3, -1) def test_max_time_ms(self): - if not version.at_least(self.db.connection, (2, 5, 3, -1)): - raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3") - db = self.db db.pymongo_test.drop() coll = db.pymongo_test @@ -79,11 +72,12 @@ class TestCursor(unittest.TestCase): self.assertTrue(coll.find_one(max_time_ms=1000)) - if "enableTestCommands=1" in get_command_line(self.client)["argv"]: + client = client_context.client + if "enableTestCommands=1" in client_context.cmd_line['argv']: # Cursor parses server timeout error in response to initial query. - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="alwaysOn") try: cursor = coll.find().max_time_ms(1) try: @@ -95,27 +89,23 @@ class TestCursor(unittest.TestCase): self.assertRaises(ExecutionTimeout, coll.find_one, max_time_ms=1) finally: - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="off") + @client_context.require_version_min(2, 5, 3, -1) + @client_context.require_test_commands def test_max_time_ms_getmore(self): # Test that Cursor handles server timeout error in response to getmore. - if "enableTestCommands=1" not in get_command_line(self.client)["argv"]: - raise SkipTest("Need test commands enabled") - - if not version.at_least(self.db.connection, (2, 5, 3, -1)): - raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3") - coll = self.db.pymongo_test coll.insert({} for _ in range(200)) cursor = coll.find().max_time_ms(100) # Send initial query before turning on failpoint. next(cursor) - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + client_context.client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="alwaysOn") try: try: # Iterate up to first getmore. @@ -125,9 +115,9 @@ class TestCursor(unittest.TestCase): else: self.fail("ExecutionTimeout not raised") finally: - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + client_context.client.admin.command("configureFailPoint", + "maxTimeAlwaysTimeOut", + mode="off") def test_explain(self): a = self.db.test.find() @@ -779,7 +769,7 @@ class TestCursor(unittest.TestCase): self.db.test.drop() self.db.test.save({"x": 1}) - if not version.at_least(self.db.connection, (1, 1, 3, -1)): + if not client_context.version.at_least(1, 1, 3, -1): for _ in self.db.test.find({}, ["a"]): self.fail() @@ -875,10 +865,8 @@ class TestCursor(unittest.TestCase): self.assertRaises(IndexError, lambda x: self.db.test.find().skip(50)[x], 50) + @client_context.require_version_min(1, 1, 4, -1) def test_count_with_limit_and_skip(self): - if not version.at_least(self.db.connection, (1, 1, 4, -1)): - raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4") - self.assertRaises(TypeError, self.db.test.find().count, "foo") def check_len(cursor, length): @@ -961,10 +949,8 @@ class TestCursor(unittest.TestCase): finally: db.drop_collection("test") + @client_context.require_version_min(1, 1, 3, 1) def test_distinct(self): - if not version.at_least(self.db.connection, (1, 1, 3, 1)): - raise SkipTest("distinct with query requires MongoDB >= 1.1.3") - self.db.drop_collection("test") self.db.test.save({"a": 1}) @@ -990,10 +976,8 @@ class TestCursor(unittest.TestCase): self.assertEqual(["b", "c"], distinct) + @client_context.require_version_min(1, 5, 1) def test_max_scan(self): - if not version.at_least(self.db.connection, (1, 5, 1)): - raise SkipTest("maxScan requires MongoDB >= 1.5.1") - self.db.drop_collection("test") for _ in range(100): self.db.test.insert({}) @@ -1018,11 +1002,9 @@ class TestCursor(unittest.TestCase): self.assertFalse(c2.alive) self.assertTrue(c1.alive) + @client_context.require_version_min(2, 0) + @client_context.require_no_mongos def test_comment(self): - if is_mongos(self.client): - raise SkipTest("profile is not supported by mongos") - if not version.at_least(self.db.connection, (2, 0)): - raise SkipTest("Requires server >= 2.0") if server_started_with_auth(self.db.connection): raise SkipTest("SERVER-4754 - This test uses profiling.") diff --git a/test/test_database.py b/test/test_database.py index 49e44a1f9..cc17444f8 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -44,9 +44,8 @@ from pymongo.errors import (CollectionInvalid, from pymongo.son_manipulator import (AutoReference, NamespaceInjector, ObjectIdShuffler) -from test import SkipTest, unittest, version -from test.utils import (get_command_line, is_mongos, - remove_all_users, server_started_with_auth) +from test import client_context, SkipTest, unittest +from test.utils import remove_all_users, server_started_with_auth from test.test_client import get_client if PY3: @@ -55,11 +54,9 @@ if PY3: class TestDatabase(unittest.TestCase): - def setUp(self): - self.client = get_client() - - def tearDown(self): - self.client = None + @classmethod + def setUpClass(cls): + cls.client = client_context.client def test_name(self): self.assertRaises(TypeError, Database, self.client, 4) @@ -112,7 +109,7 @@ class TestDatabase(unittest.TestCase): db.create_collection("test.foo") self.assertTrue(u("test.foo") in db.collection_names()) expected = {} - if version.at_least(self.client, (2, 7, 0)): + if client_context.version.at_least(2, 7, 0): # usePowerOf2Sizes server default expected["flags"] = 1 result = db.test.foo.options() @@ -185,9 +182,8 @@ class TestDatabase(unittest.TestCase): self.assertTrue(db.validate_collection(db.test, scandata=True, full=True)) self.assertTrue(db.validate_collection(db.test, True, True)) + @client_context.require_no_mongos def test_profiling_levels(self): - if is_mongos(self.client): - raise SkipTest('profile is not supported by mongos') db = self.client.pymongo_test self.assertEqual(db.profiling_level(), OFF) # default @@ -215,9 +211,8 @@ class TestDatabase(unittest.TestCase): db.set_profiling_level(OFF, 100) # back to default self.assertEqual(100, db.command("profile", -1)['slowms']) + @client_context.require_no_mongos def test_profiling_info(self): - if is_mongos(self.client): - raise SkipTest('profile is not supported by mongos') db = self.client.pymongo_test db.set_profiling_level(ALL) @@ -236,7 +231,7 @@ class TestDatabase(unittest.TestCase): self.assertTrue(len(info) >= 1) # These basically clue us in to server changes. - if version.at_least(db.connection, (1, 9, 1, -1)): + if client_context.version.at_least(1, 9, 1, -1): self.assertTrue(isinstance(info[0]['responseLength'], int)) self.assertTrue(isinstance(info[0]['millis'], int)) self.assertTrue(isinstance(info[0]['client'], string_type)) @@ -256,9 +251,8 @@ class TestDatabase(unittest.TestCase): self.assertRaises(TypeError, iterate) + @client_context.require_no_mongos def test_errors(self): - if is_mongos(self.client): - raise SkipTest('getpreverror not supported by mongos') db = self.client.pymongo_test db.reset_error_history() @@ -296,15 +290,11 @@ class TestDatabase(unittest.TestCase): self.assertEqual(db.command("buildinfo"), db.command({"buildinfo": 1})) + # We use 'aggregate' as our example command, since it's an easy way to + # retrieve a BSON regex from a collection using a command. But until + # MongoDB 2.3.2, aggregation turned regexes into strings: SERVER-6470. + @client_context.require_version_min(2, 3, 2) def test_command_with_compile_re(self): - # We use 'aggregate' as our example command, since it's an easy way to - # retrieve a BSON regex from a collection using a command. But until - # MongoDB 2.3.2, aggregation turned regexes into strings: SERVER-6470. - if not version.at_least(self.client, (2, 3, 2)): - raise SkipTest( - "Retrieving a regex with aggregation requires " - "MongoDB >= 2.3.2") - db = self.client.pymongo_test db.test.drop() db.test.insert({'r': re.compile('.*')}) @@ -317,14 +307,15 @@ class TestDatabase(unittest.TestCase): def test_last_status(self): db = self.client.pymongo_test - db.test.remove({}) - db.test.save({"i": 1}) + with self.client.start_request(): + db.test.remove({}) + db.test.save({"i": 1}) - db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0) - self.assertTrue(db.last_status()["updatedExisting"]) + db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0) + self.assertTrue(db.last_status()["updatedExisting"]) - db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0) - self.assertFalse(db.last_status()["updatedExisting"]) + db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0) + self.assertFalse(db.last_status()["updatedExisting"]) def test_password_digest(self): self.assertRaises(TypeError, auth._password_digest, 5) @@ -340,13 +331,8 @@ class TestDatabase(unittest.TestCase): self.assertEqual(auth._password_digest("Gustave", u("Dor\xe9")), u("81e0e2364499209f466e75926a162d73")) + @client_context.require_auth def test_authenticate_add_remove_user(self): - if (is_mongos(self.client) and not - version.at_least(self.client, (2, 0, 0))): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - db = self.client.pymongo_test # Configuration errors @@ -357,7 +343,7 @@ class TestDatabase(unittest.TestCase): self.assertRaises(ConfigurationError, db.add_user, "user", 'password', True, roles=['read']) - if version.at_least(self.client, (2, 5, 3, -1)): + if client_context.version.at_least(2, 5, 3, -1): with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) self.assertRaises(DeprecationWarning, db.add_user, @@ -400,7 +386,7 @@ class TestDatabase(unittest.TestCase): db.authenticate, "Gustave", u("Dor\xe9")) self.assertTrue(db.authenticate("Gustave", u("password"))) - if not version.at_least(self.client, (2, 5, 3, -1)): + if not client_context.version.at_least(2, 5, 3, -1): # Add a readOnly user db.add_user("Ross", "password", read_only=True) db.logout() @@ -411,17 +397,13 @@ class TestDatabase(unittest.TestCase): # Cleanup finally: remove_all_users(db) + self.client.pymongo_test.logout() self.client.admin.remove_user("admin") self.client.admin.logout() + self.client.disconnect() + @client_context.require_auth def test_make_user_readonly(self): - if (is_mongos(self.client) - and not version.at_least(self.client, (2, 0, 0))): - raise SkipTest('Auth with sharding requires MongoDB >= 2.0.0') - - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - admin = self.client.admin admin.add_user('admin', 'pw') admin.authenticate('admin', 'pw') @@ -451,13 +433,12 @@ class TestDatabase(unittest.TestCase): remove_all_users(db) admin.remove_user("admin") admin.logout() + db.logout() + self.client.disconnect() + @client_context.require_version_min(2, 5, 3, -1) + @client_context.require_auth def test_default_roles(self): - if not version.at_least(self.client, (2, 5, 3, -1)): - raise SkipTest("Default roles only exist in MongoDB >= 2.5.3") - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - # "Admin" user db = self.client.admin db.add_user('admin', 'pass') @@ -503,14 +484,11 @@ class TestDatabase(unittest.TestCase): db.authenticate('user', 'pass') remove_all_users(db) db.logout() + self.client.disconnect() + @client_context.require_version_min(2, 5, 3, -1) + @client_context.require_auth def test_new_user_cmds(self): - if not version.at_least(self.client, (2, 5, 3, -1)): - raise SkipTest("User manipulation through commands " - "requires MongoDB >= 2.5.3") - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - db = self.client.pymongo_test db.add_user("amalia", "password", roles=["userAdmin"]) db.authenticate("amalia", "password") @@ -527,13 +505,10 @@ class TestDatabase(unittest.TestCase): finally: db.remove_user("amalia") db.logout() + self.client.disconnect() + @client_context.require_auth def test_authenticate_and_safe(self): - if (is_mongos(self.client) and not - version.at_least(self.client, (2, 0, 0))): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') db = self.client.auth_test db.add_user("bernie", "password", @@ -555,14 +530,10 @@ class TestDatabase(unittest.TestCase): finally: db.remove_user("bernie") db.logout() + self.client.disconnect() + @client_context.require_auth def test_authenticate_and_request(self): - if (is_mongos(self.client) and not - version.at_least(self.client, (2, 0, 0))): - raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") - if not server_started_with_auth(self.client): - raise SkipTest('Authentication is not enabled on server') - # Database.authenticate() needs to be in a request - check that it # always runs in a request, and that it restores the request state # (in or not in a request) properly when it's finished. @@ -584,19 +555,14 @@ class TestDatabase(unittest.TestCase): db.remove_user("mike") db.logout() request_db.logout() + self.client.disconnect() + @client_context.require_auth def test_authenticate_multiple(self): - client = get_client() - if (is_mongos(client) and not - version.at_least(self.client, (2, 2, 0))): - raise SkipTest("Need mongos >= 2.2.0") - if not server_started_with_auth(client): - raise SkipTest("Authentication is not enabled on server") - # Setup - users_db = client.pymongo_test - admin_db = client.admin - other_db = client.pymongo_test1 + users_db = self.client.pymongo_test + admin_db = self.client.admin + other_db = self.client.pymongo_test1 users_db.test.remove() other_db.test.remove() @@ -606,7 +572,7 @@ class TestDatabase(unittest.TestCase): try: self.assertTrue(admin_db.authenticate('admin', 'pass')) - if version.at_least(self.client, (2, 5, 3, -1)): + if client_context.version.at_least(2, 5, 3, -1): admin_db.add_user('ro-admin', 'pass', roles=["userAdmin", "readAnyDatabase"]) else: @@ -632,7 +598,7 @@ class TestDatabase(unittest.TestCase): other_db.test.insert, {}) # Force close all sockets - client.disconnect() + self.client.disconnect() # We should still be able to write to the regular user's db self.assertTrue(users_db.test.remove()) @@ -644,11 +610,11 @@ class TestDatabase(unittest.TestCase): # Cleanup finally: - admin_db.logout() - users_db.logout() - admin_db.authenticate('admin', 'pass') remove_all_users(users_db) remove_all_users(admin_db) + admin_db.logout() + users_db.logout() + self.client.disconnect() def test_id_ordering(self): # PyMongo attempts to have _id show up first @@ -718,7 +684,7 @@ class TestDatabase(unittest.TestCase): # TODO some of these tests belong in the collection level testing. def test_save_find_one(self): - db = Database(self.client, "pymongo_test") + db = self.client.pymongo_test db.test.remove({}) a_doc = SON({"hello": u("world")}) @@ -873,7 +839,7 @@ class TestDatabase(unittest.TestCase): del db.system_js['add'] self.assertEqual(0, db.system.js.count()) - if version.at_least(db.connection, (1, 3, 2, -1)): + if client_context.version.at_least(1, 3, 2, -1): self.assertRaises(OperationFailure, db.system_js.add, 1, 5) # TODO right now CodeWScope doesn't work w/ system js @@ -883,7 +849,7 @@ class TestDatabase(unittest.TestCase): self.assertRaises(OperationFailure, db.system_js.non_existant) # XXX: Broken in V8, works in SpiderMonkey - if not version.at_least(db.connection, (2, 3, 0)): + if not client_context.version.at_least(2, 3, 0): db.system_js.no_param = Code("return 5;") self.assertEqual(5, db.system_js.no_param()) @@ -939,17 +905,15 @@ class TestDatabase(unittest.TestCase): self.assertRaises(UserWarning, self.client.pymongo_test.command, 'ping', read_preference=ReadPreference.SECONDARY) try: - self.client.pymongo_test.command('dbStats', + self.client.pymongo_test.command( + 'dbStats', read_preference=ReadPreference.SECONDARY_PREFERRED) except UserWarning: self.fail("Shouldn't have raised UserWarning.") + @client_context.require_version_min(2, 5, 3, -1) + @client_context.require_test_commands def test_command_max_time_ms(self): - if not version.at_least(self.client, (2, 5, 3, -1)): - raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3") - if "enableTestCommands=1" not in get_command_line(self.client)["argv"]: - raise SkipTest("Test commands must be enabled.") - self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 3fa535abc..5779f30f6 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -35,20 +35,19 @@ from gridfs.errors import (NoFile, UnsupportedAPI) from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from test.test_client import get_client -from test import qcheck, unittest +from test import client_context, qcheck, unittest class TestGridFile(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + def setUp(self): - self.db = get_client().pymongo_test self.db.fs.files.remove({}) self.db.fs.chunks.remove({}) - def tearDown(self): - self.db = None - def test_basic(self): f = GridIn(self.db.fs, filename="test") f.write(b"hello world") diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 99888bf18..0ddcb9831 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -32,8 +32,7 @@ import gridfs from bson.py3compat import u, StringIO, string_type from gridfs.errors import (FileExists, NoFile) -from test import unittest -from test.test_client import get_client +from test import client_context, unittest from test.utils import joinall @@ -71,17 +70,17 @@ class JustRead(threading.Thread): class TestGridfs(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + cls.fs = gridfs.GridFS(cls.db) + cls.alt = gridfs.GridFS(cls.db, "alt") + def setUp(self): - self.db = get_client().pymongo_test self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") self.db.drop_collection("alt.files") self.db.drop_collection("alt.chunks") - self.fs = gridfs.GridFS(self.db) - self.alt = gridfs.GridFS(self.db, "alt") - - def tearDown(self): - self.db = self.fs = self.alt = None def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -278,17 +277,14 @@ class TestGridfs(unittest.TestCase): self.assertEqual(b"hello world", self.fs.get(oid).read()) def test_file_exists(self): - db = get_client(w=1).pymongo_test - fs = gridfs.GridFS(db) + oid = self.fs.put(b"hello") + self.assertRaises(FileExists, self.fs.put, b"world", _id=oid) - oid = fs.put(b"hello") - self.assertRaises(FileExists, fs.put, b"world", _id=oid) - - one = fs.new_file(_id=123) + one = self.fs.new_file(_id=123) one.write(b"some content") one.close() - two = fs.new_file(_id=123) + two = self.fs.new_file(_id=123) self.assertRaises(FileExists, two.write, b'x' * 262146) def test_exists(self): @@ -446,10 +442,9 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase): self.assertRaises(ConnectionFailure, fs.put, 'data') def tearDown(self): - rsc = self._get_client() + rsc = client_context.rs_client rsc.pymongo_test.drop_collection('fs.files') rsc.pymongo_test.drop_collection('fs.chunks') - rsc.close() if __name__ == "__main__": diff --git a/test/test_json_util.py b/test/test_json_util.py index ed022e773..5ce797b1f 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -33,23 +33,21 @@ from bson.son import RE_TYPE from bson.timestamp import Timestamp from bson.tz_util import utc -from test import SkipTest, unittest -from test.test_client import get_client +from test import client_context, SkipTest, unittest PY3 = sys.version_info[0] == 3 class TestJsonUtil(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test + def setUp(self): if not json_util.json_lib: raise SkipTest("No json or simplejson module") - self.db = get_client().pymongo_test - - def tearDown(self): - self.db = None - def round_tripped(self, doc): return json_util.loads(json_util.dumps(doc)) diff --git a/test/test_pooling_base.py b/test/test_pooling_base.py index ac8850471..cf4a5d01d 100644 --- a/test/test_pooling_base.py +++ b/test/test_pooling_base.py @@ -30,9 +30,10 @@ from pymongo.mongo_client import MongoClient from pymongo.pool import Pool, NO_REQUEST, NO_SOCKET_YET, SocketInfo from pymongo.errors import ConfigurationError, ConnectionFailure from pymongo.errors import ExceededMaxWaiters -from test import version, host, port, SkipTest +from test import host, port, SkipTest from test.test_client import get_client from test.utils import delay, is_mongos, one, get_pool +from test.version import Version N = 10 DB = "pymongo-pooling-tests" @@ -1191,7 +1192,7 @@ class _TestPoolSocketSharing(_TestPoolingBase): # Javascript function that pauses N seconds per document fn = delay(10) if (is_mongos(db.connection) or not - version.at_least(db.connection, (1, 7, 2))): + Version.from_client(db.connection).at_least(1, 7, 2)): # mongos doesn't support eval so we have to use $where # which is less reliable in this context. self.assertEqual(1, db.test.find({"$where": fn}).count()) diff --git a/test/test_pymongo.py b/test/test_pymongo.py index fe499eda6..279dc089f 100644 --- a/test/test_pymongo.py +++ b/test/test_pymongo.py @@ -14,20 +14,18 @@ """Test the pymongo module itself.""" -import os import sys sys.path[0:0] = [""] import pymongo -from test import host, port, unittest +from test import unittest class TestPyMongo(unittest.TestCase): def test_mongo_client_alias(self): # Testing that pymongo module imports mongo_client.MongoClient - c = pymongo.MongoClient(host, port) - self.assertEqual(c.host, host) - self.assertEqual(c.port, port) + self.assertEqual(pymongo.MongoClient, + pymongo.mongo_client.MongoClient) if __name__ == "__main__": diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 15261e04a..3ddc0ea67 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -32,21 +32,21 @@ from pymongo.errors import ConfigurationError from test.test_replica_set_client import TestReplicaSetClientBase from test.test_client import get_client -from test import host, port, SkipTest, unittest, utils, version +from test import client_context, host, port, SkipTest, unittest, utils +from test.version import Version class TestReadPreferencesBase(TestReplicaSetClientBase): def setUp(self): super(TestReadPreferencesBase, self).setUp() # Insert some data so we can use cursors in read_from_which_host - c = self._get_client() - c.pymongo_test.test.drop() - c.pymongo_test.test.insert([{'_id': i} for i in range(10)], w=self.w) + client_context.client.pymongo_test.test.drop() + client_context.client.pymongo_test.test.insert( + [{'_id': i} for i in range(10)], w=self.w) def tearDown(self): super(TestReadPreferencesBase, self).tearDown() - c = self._get_client() - c.pymongo_test.test.drop() + client_context.client.pymongo_test.test.drop() def read_from_which_host(self, client): """Do a find() on the client and return which host was used @@ -204,6 +204,7 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase): # Effectively ignore members' ping times so we can test the effect # of ReadPreference modes only acceptableLatencyMS=1000*1000) + self.client_version = Version.from_client(self.c) def tearDown(self): # We create a lot of collections and indexes in these tests, so drop @@ -330,14 +331,14 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase): ('geoSearch', 'test'), ('near', [33, 33]), ('maxDistance', 6), ('search', {'type': 'restaurant'}), ('limit', 30)]))) - if version.at_least(self.c, (2, 1, 0)): + if self.client_version.at_least(2, 1, 0): self._test_fn(True, lambda: self.c.pymongo_test.command(SON([ ('aggregate', 'test'), ('pipeline', []) ]))) # Text search. - if version.at_least(self.c, (2, 3, 2)): + if self.client_version.at_least(2, 3, 2): with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) @@ -394,10 +395,8 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase): ('out', {'inline': True}) ]))) + @client_context.require_version_min(2, 5, 2) def test_aggregate_command_with_out(self): - if not version.at_least(self.c, (2, 5, 2)): - raise SkipTest("Aggregation with $out requires MongoDB >= 2.5.2") - # Tests aggregate command when pipeline contains $out. self.c.pymongo_test.test.insert({"x": 1, "y": 1}, w=self.w) self.c.pymongo_test.test.insert({"x": 1, "y": 2}, w=self.w) @@ -479,7 +478,7 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase): lambda: self.c.pymongo_test.test.find().distinct('a')) def test_aggregate(self): - if version.at_least(self.c, (2, 1, 0)): + if self.client_version.at_least(2, 1, 0): self._test_fn(True, lambda: self.c.pymongo_test.test.aggregate( [{'$project': {'_id': 1}}])) diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index 3e8813f53..31cb28d4b 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -29,7 +29,6 @@ sys.path[0:0] = [""] from bson.py3compat import thread, u, _unicode from bson.son import SON from bson.tz_util import utc -from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference, Secondary, Nearest from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.mongo_replica_set_client import _partition_node, have_gevent @@ -40,13 +39,14 @@ from pymongo.errors import (AutoReconnect, ConnectionFailure, InvalidName, OperationFailure, InvalidOperation) -from test import pair, port, SkipTest, unittest, version +from test import client_context, pair, port, SkipTest, unittest from test.pymongo_mocks import MockReplicaSetClient from test.utils import ( delay, assertReadFrom, assertReadFromAll, read_from_which_host, remove_all_users, assertRaisesExactly, TestRequestMixin, one, server_started_with_auth, pools_from_rs_client, get_pool, _TestLazyConnectMixin) +from test.version import Version class TestReplicaSetClientAgainstStandalone(unittest.TestCase): @@ -54,9 +54,7 @@ class TestReplicaSetClientAgainstStandalone(unittest.TestCase): but only if the database at DB_IP and DB_PORT is a standalone. """ def setUp(self): - client = MongoClient(pair) - response = client.admin.command('ismaster') - if 'setName' in response: + if client_context.setname: raise SkipTest("Connected to a replica set, not a standalone mongod") def test_connect(self): @@ -66,32 +64,29 @@ class TestReplicaSetClientAgainstStandalone(unittest.TestCase): class TestReplicaSetClientBase(unittest.TestCase): - def setUp(self): - client = MongoClient(pair) - response = client.admin.command('ismaster') - if 'setName' in response: - self.name = str(response['setName']) - self.w = len(response['hosts']) - self.hosts = set([_partition_node(h) - for h in response["hosts"]]) - self.arbiters = set([_partition_node(h) - for h in response.get("arbiters", [])]) - repl_set_status = client.admin.command('replSetGetStatus') - primary_info = [ - m for m in repl_set_status['members'] - if m['stateStr'] == 'PRIMARY' - ][0] + @classmethod + @client_context.require_replica_set + def setUpClass(cls): + cls.name = client_context.setname + ismaster = client_context.ismaster + cls.w = client_context.w + cls.hosts = set(_partition_node(h) for h in ismaster['hosts']) + cls.arbiters = set(_partition_node(h) + for h in ismaster.get("arbiters", [])) - self.primary = _partition_node(primary_info['name']) - self.secondaries = [ - _partition_node(m['name']) for m in repl_set_status['members'] - if m['stateStr'] == 'SECONDARY' - ] - else: - raise SkipTest("Not connected to a replica set") + repl_set_status = client_context.client.admin.command( + 'replSetGetStatus') + primary_info = [ + m for m in repl_set_status['members'] + if m['stateStr'] == 'PRIMARY' + ][0] - super(TestReplicaSetClientBase, self).setUp() + cls.primary = _partition_node(primary_info['name']) + cls.secondaries = [ + _partition_node(m['name']) for m in repl_set_status['members'] + if m['stateStr'] == 'SECONDARY' + ] def _get_client(self, **kwargs): return MongoReplicaSetClient(pair, @@ -138,7 +133,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertTrue(c.primary) self.assertTrue(c.secondaries) - if version.at_least(c, (2, 5, 4, -1)): + if Version.from_client(c).at_least(2, 5, 4, -1): self.assertTrue(c.max_wire_version > 0) else: self.assertEqual(c.max_wire_version, 0) @@ -169,10 +164,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) + @client_context.require_auth def test_init_disconnected_with_auth(self): - c = self._get_client() - if not server_started_with_auth(c): - raise SkipTest('Authentication is not enabled on server') + c = client_context.rs_client c.admin.add_user("admin", "pass") c.admin.authenticate("admin", "pass") @@ -199,6 +193,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): # Clean up. remove_all_users(c.pymongo_test) remove_all_users(c.admin) + c.admin.logout() + c.disconnect() def test_connect(self): assertRaisesExactly(ConnectionFailure, MongoReplicaSetClient, @@ -210,7 +206,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertTrue(MongoReplicaSetClient(pair, replicaSet=self.name)) def test_repr(self): - client = self._get_client() + client = client_context.rs_client # Quirk: the RS client makes a frozenset of hosts from a dict's keys, # so we must do the same to achieve the same order. @@ -223,7 +219,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): "MongoReplicaSetClient([%s])" % hosts_repr) def test_properties(self): - c = MongoReplicaSetClient(pair, replicaSet=self.name) + c = client_context.rs_client c.admin.command('ping') self.assertEqual(c.primary, self.primary) self.assertEqual(c.hosts, self.hosts) @@ -240,7 +236,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): cursor = c.pymongo_test.test.find() self.assertEqual( ReadPreference.PRIMARY, cursor._Cursor__read_preference) - c.close() tag_sets = [{'dc': 'la', 'rack': '2'}, {'foo': 'bar'}] secondary = Secondary(tag_sets) @@ -268,7 +263,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertEqual( nearest, cursor._Cursor__read_preference) - if version.at_least(c, (1, 7, 4)): + if Version.from_client(c).at_least(1, 7, 4): self.assertEqual(c.max_bson_size, 16777216) else: self.assertEqual(c.max_bson_size, 4194304) @@ -283,7 +278,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): pair, replicaSet=self.name, use_greenlets=True).use_greenlets) def test_get_db(self): - client = self._get_client() + client = client_context.rs_client def make_db(base, name): return base[name] @@ -298,10 +293,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertTrue(isinstance(client.test, Database)) self.assertEqual(client.test, client["test"]) self.assertEqual(client.test, Database(client, "test")) - client.close() def test_auto_reconnect_exception_when_read_preference_is_secondary(self): - c = self._get_client() + c = client_context.rs_client db = c.pymongo_test def raise_socket_error(*args, **kwargs): @@ -316,12 +310,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): finally: socket.socket.sendall = old_sendall + @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): # Check if we have the prerequisites to run this test. - c = self._get_client() - if not server_started_with_auth(c): - raise SkipTest('Authentication is not enabled on server') - lazy_client = MongoReplicaSetClient( "mongodb://user:wrong@%s/pymongo_test" % pair, replicaSet=self.name, @@ -331,7 +322,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): OperationFailure, lazy_client.test.collection.find_one) def test_operations(self): - c = self._get_client() + c = client_context.rs_client # Check explicitly for a case we've commonly hit in tests: # a replica set is started with a tiny oplog, a previous @@ -367,10 +358,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): db.test.remove({}) self.assertEqual(0, db.test.count()) db.test.drop() - c.close() def test_database_names(self): - client = self._get_client() + client = client_context.rs_client client.pymongo_test.test.save({"dummy": u("object")}) client.pymongo_test_mike.test.save({"dummy": u("object")}) @@ -378,10 +368,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): dbs = client.database_names() self.assertTrue("pymongo_test" in dbs) self.assertTrue("pymongo_test_mike" in dbs) - client.close() def test_drop_database(self): - client = self._get_client() + client = client_context.rs_client self.assertRaises(TypeError, client.drop_database, 5) self.assertRaises(TypeError, client.drop_database, None) @@ -399,37 +388,33 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): client.drop_database(client.pymongo_test) dbs = client.database_names() self.assertTrue("pymongo_test" not in dbs) - client.close() def test_copy_db(self): - c = self._get_client() + c = client_context.rs_client # We test copy twice; once starting in a request and once not. In # either case the copy should succeed (because it starts a request # internally) and should leave us in the same state as before the copy. - c.start_request() + with c.start_request(): + self.assertRaises(TypeError, c.copy_database, 4, "foo") + self.assertRaises(TypeError, c.copy_database, "foo", 4) - self.assertRaises(TypeError, c.copy_database, 4, "foo") - self.assertRaises(TypeError, c.copy_database, "foo", 4) + self.assertRaises(InvalidName, c.copy_database, "foo", "$foo") - self.assertRaises(InvalidName, c.copy_database, "foo", "$foo") + c.pymongo_test.test.drop() + c.drop_database("pymongo_test1") + c.drop_database("pymongo_test2") - c.pymongo_test.test.drop() - c.drop_database("pymongo_test1") - c.drop_database("pymongo_test2") + c.pymongo_test.test.insert({"foo": "bar"}) - c.pymongo_test.test.insert({"foo": "bar"}) + self.assertFalse("pymongo_test1" in c.database_names()) + self.assertFalse("pymongo_test2" in c.database_names()) - self.assertFalse("pymongo_test1" in c.database_names()) - self.assertFalse("pymongo_test2" in c.database_names()) + c.copy_database("pymongo_test", "pymongo_test1") + # copy_database() didn't accidentally end the request + self.assertTrue(c.in_request()) - c.copy_database("pymongo_test", "pymongo_test1") - # copy_database() didn't accidentally end the request - self.assertTrue(c.in_request()) - - self.assertTrue("pymongo_test1" in c.database_names()) - self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"]) - - c.end_request() + self.assertTrue("pymongo_test1" in c.database_names()) + self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"]) self.assertFalse(c.in_request()) c.copy_database("pymongo_test", "pymongo_test2", pair) @@ -441,7 +426,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertTrue("pymongo_test2" in c.database_names()) self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"]) - if version.at_least(c, (1, 3, 3, 1)) and server_started_with_auth(c): + if (Version.from_client(c).at_least(1, 3, 3, 1) and + server_started_with_auth(c)): c.drop_database("pymongo_test1") c.admin.add_user("admin", "password") @@ -469,7 +455,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): # Cleanup remove_all_users(c.pymongo_test) c.admin.remove_user("admin") - c.close() + c.admin.logout() + c.pymongo_test.logout() + c.disconnect() def test_get_default_database(self): host = one(self.hosts) @@ -498,16 +486,15 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_iteration(self): - client = self._get_client() + client = client_context.rs_client def iterate(): [a for a in client] self.assertRaises(TypeError, iterate) - client.close() def test_disconnect(self): - c = self._get_client() + c = client_context.rs_client coll = c.pymongo_test.bar c.disconnect() @@ -553,7 +540,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): except ImportError: raise SkipTest("No multiprocessing module") - client = self._get_client() + client = client_context.rs_client def f(pipe): try: @@ -584,7 +571,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): pass def test_document_class(self): - c = self._get_client() + c = client_context.rs_client db = c.pymongo_test db.test.insert({"x": 1}) @@ -594,10 +581,12 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): c.document_class = SON - self.assertEqual(SON, c.document_class) - self.assertTrue(isinstance(db.test.find_one(), SON)) - self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) - c.close() + try: + self.assertEqual(SON, c.document_class) + self.assertTrue(isinstance(db.test.find_one(), SON)) + self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) + finally: + c.document_class = dict c = self._get_client(document_class=SON) db = c.pymongo_test @@ -633,7 +622,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self._get_client, socketTimeoutMS='foo') def test_socket_timeout(self): - no_timeout = self._get_client() + no_timeout = client_context.rs_client timeout_sec = 1 timeout = self._get_client(socketTimeoutMS=timeout_sec*1000) @@ -666,7 +655,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): else: self.fail('RS client should have raised timeout error') - no_timeout.close() timeout.close() def test_timeout_does_not_mark_member_down(self): @@ -715,7 +703,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): tz_aware='foo', replicaSet=self.name) aware = self._get_client(tz_aware=True) - naive = self._get_client() + naive = client_context.rs_client aware.pymongo_test.drop_collection("test") now = datetime.datetime.utcnow() @@ -809,7 +797,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): # Test fix for PYTHON-294 -- make sure client closes its socket if it # gets an interrupt while waiting to recv() from it. - c = self._get_client() + c = client_context.rs_client db = c.pymongo_test # A $where clause which takes 1.5 sec to execute @@ -932,14 +920,12 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertInRequestAndSameSock(client, pools) client.close() - client = self._get_client() + client = client_context.rs_client pools = pools_from_rs_client(client) self.assertNotInRequestAndDifferentSock(client, pools) - client.start_request() - self.assertInRequestAndSameSock(client, pools) - client.end_request() + with client.start_request(): + self.assertInRequestAndSameSock(client, pools) self.assertNotInRequestAndDifferentSock(client, pools) - client.close() def test_nested_request(self): client = self._get_client(auto_start_request=True) @@ -984,48 +970,46 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): client.close() def test_request_threads(self): - client = self._get_client() - try: - pools = pools_from_rs_client(client) - self.assertNotInRequestAndDifferentSock(client, pools) + client = client_context.rs_client - started_request, ended_request = threading.Event(), threading.Event() - checked_request = threading.Event() - thread_done = [False] + pools = pools_from_rs_client(client) + self.assertNotInRequestAndDifferentSock(client, pools) - # Starting a request in one thread doesn't put the other thread in a - # request - def f(): - self.assertNotInRequestAndDifferentSock(client, pools) - client.start_request() - self.assertInRequestAndSameSock(client, pools) - started_request.set() - checked_request.wait() - checked_request.clear() - self.assertInRequestAndSameSock(client, pools) - client.end_request() - self.assertNotInRequestAndDifferentSock(client, pools) - ended_request.set() - checked_request.wait() - thread_done[0] = True + started_request, ended_request = threading.Event(), threading.Event() + checked_request = threading.Event() + thread_done = [False] - t = threading.Thread(target=f) - t.setDaemon(True) - t.start() - started_request.wait() + # Starting a request in one thread doesn't put the other thread in a + # request + def f(): self.assertNotInRequestAndDifferentSock(client, pools) - checked_request.set() - ended_request.wait() + client.start_request() + self.assertInRequestAndSameSock(client, pools) + started_request.set() + checked_request.wait() + checked_request.clear() + self.assertInRequestAndSameSock(client, pools) + client.end_request() self.assertNotInRequestAndDifferentSock(client, pools) - checked_request.set() - t.join() - self.assertNotInRequestAndDifferentSock(client, pools) - self.assertTrue(thread_done[0], "Thread didn't complete") - finally: - client.close() + ended_request.set() + checked_request.wait() + thread_done[0] = True + + t = threading.Thread(target=f) + t.setDaemon(True) + t.start() + started_request.wait() + self.assertNotInRequestAndDifferentSock(client, pools) + checked_request.set() + ended_request.wait() + self.assertNotInRequestAndDifferentSock(client, pools) + checked_request.set() + t.join() + self.assertNotInRequestAndDifferentSock(client, pools) + self.assertTrue(thread_done[0], "Thread didn't complete") def test_schedule_refresh(self): - client = self._get_client() + client = client_context.rs_client new_rs_state = rs_state = client._MongoReplicaSetClient__rs_state for host in rs_state.hosts: new_rs_state = new_rs_state.clone_with_host_down(host, 'error!') @@ -1037,8 +1021,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.w, len(rs_state.members), "MongoReplicaSetClient didn't detect members are up") - client.close() - def test_pinned_member(self): latency = 1000 * 1000 client = self._get_client(acceptablelatencyms=latency) @@ -1081,7 +1063,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): ReadPreference.NEAREST, None) def test_alive(self): - client = self._get_client() + client = client_context.rs_client self.assertTrue(client.alive()) client = MongoReplicaSetClient( diff --git a/test/test_son_manipulator.py b/test/test_son_manipulator.py index 9ed87dd15..084d5ebb5 100644 --- a/test/test_son_manipulator.py +++ b/test/test_son_manipulator.py @@ -19,22 +19,18 @@ import sys sys.path[0:0] = [""] from bson.son import SON -from pymongo.database import Database from pymongo.son_manipulator import (NamespaceInjector, ObjectIdInjector, ObjectIdShuffler, SONManipulator) -from test.test_client import get_client -from test import qcheck, unittest +from test import client_context, qcheck, unittest class TestSONManipulator(unittest.TestCase): - def setUp(self): - self.db = Database(get_client(), "pymongo_test") - - def tearDown(self): - self.db = None + @classmethod + def setUpClass(cls): + cls.db = client_context.client.pymongo_test def test_basic(self): manip = SONManipulator() diff --git a/test/test_ssl.py b/test/test_ssl.py index e510c2426..f93131308 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -37,8 +37,9 @@ from pymongo.common import HAS_SSL from pymongo.errors import (ConfigurationError, ConnectionFailure, OperationFailure) -from test import host, pair, port, SkipTest, unittest, version +from test import host, pair, port, SkipTest, unittest from test.utils import server_started_with_auth, remove_all_users +from test.version import Version CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), @@ -75,6 +76,8 @@ def is_server_resolvable(): finally: socket.setdefaulttimeout(socket_timeout) +# Shared ssl-enabled client for the tests +ssl_client = None if HAS_SSL: import ssl @@ -89,8 +92,9 @@ if HAS_SSL: # Is MongoDB configured with server.pem, ca.pem, and crl.pem from # mongodb jstests/lib? try: - MongoClient(host, port, connectTimeoutMS=100, ssl=True, - ssl_certfile=CLIENT_PEM) + global ssl_client + ssl_client = MongoClient(host, port, connectTimeoutMS=100, ssl=True, + ssl_certfile=CLIENT_PEM) CERT_SSL = True except ConnectionFailure: pass @@ -203,7 +207,8 @@ class TestClientSSL(unittest.TestCase): class TestSSL(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): if not HAS_SSL: raise SkipTest("The ssl module is not available.") @@ -243,8 +248,7 @@ class TestSSL(unittest.TestCase): if not CERT_SSL: raise SkipTest("No mongod available over SSL with certs") - client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM) - response = client.admin.command('ismaster') + response = ssl_client.admin.command('ismaster') if 'setName' in response: client = MongoReplicaSetClient(pair, replicaSet=response['setName'], @@ -269,8 +273,7 @@ class TestSSL(unittest.TestCase): if not CERT_SSL: raise SkipTest("No mongod available over SSL with certs") - client = MongoClient(host, port, ssl_certfile=CLIENT_PEM) - response = client.admin.command('ismaster') + response = ssl_client.admin.command('ismaster') if 'setName' in response: client = MongoReplicaSetClient(pair, replicaSet=response['setName'], @@ -376,8 +379,7 @@ class TestSSL(unittest.TestCase): if not CERT_SSL: raise SkipTest("No mongod available over SSL with certs") - client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM) - response = client.admin.command('ismaster') + response = ssl_client.admin.command('ismaster') try: MongoClient(pair, @@ -414,19 +416,18 @@ class TestSSL(unittest.TestCase): if not CERT_SSL: raise SkipTest("No mongod available over SSL with certs") - client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM) - if not version.at_least(client, (2, 5, 3, -1)): + if not Version.from_client(ssl_client).at_least(2, 5, 3, -1): raise SkipTest("MONGODB-X509 tests require MongoDB 2.5.3 or newer") - if not server_started_with_auth(client): + if not server_started_with_auth(ssl_client): raise SkipTest('Authentication is not enabled on server') # Give admin all necessary privileges. - client['$external'].add_user(MONGODB_X509_USERNAME, roles=[ + ssl_client['$external'].add_user(MONGODB_X509_USERNAME, roles=[ {'role': 'readWriteAnyDatabase', 'db': 'admin'}, {'role': 'userAdminAnyDatabase', 'db': 'admin'}]) - coll = client.pymongo_test.test + coll = ssl_client.pymongo_test.test self.assertRaises(OperationFailure, coll.count) - self.assertTrue(client.admin.authenticate(MONGODB_X509_USERNAME, - mechanism='MONGODB-X509')) + self.assertTrue(ssl_client.admin.authenticate(MONGODB_X509_USERNAME, + mechanism='MONGODB-X509')) self.assertTrue(coll.remove()) uri = ('mongodb://%s@%s:%d/?authMechanism=' 'MONGODB-X509' % (quote_plus(MONGODB_X509_USERNAME), host, port)) @@ -443,7 +444,7 @@ class TestSSL(unittest.TestCase): 'MONGODB-X509' % (quote_plus("not the username"), host, port)) self.assertRaises(ConfigurationError, MongoClient, uri, ssl=True, ssl_certfile=CLIENT_PEM) - self.assertRaises(OperationFailure, client.admin.authenticate, + self.assertRaises(OperationFailure, ssl_client.admin.authenticate, "not the username", mechanism="MONGODB-X509") @@ -456,8 +457,8 @@ class TestSSL(unittest.TestCase): ssl=True, ssl_certfile=CA_PEM) # Cleanup - remove_all_users(client['$external']) - client['$external'].logout() + remove_all_users(ssl_client['$external']) + ssl_client['$external'].logout() if __name__ == "__main__": unittest.main() diff --git a/test/test_thread_util.py b/test/test_thread_util.py index 942c3b18a..5bd08399b 100644 --- a/test/test_thread_util.py +++ b/test/test_thread_util.py @@ -158,7 +158,9 @@ class TestIdent(unittest.TestCase): class TestGreenletIdent(unittest.TestCase): - def setUp(self): + + @classmethod + def setUpClass(cls): if not thread_util.have_gevent: raise SkipTest("need Gevent") diff --git a/test/utils.py b/test/utils.py index b31288963..d19ad371b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -23,7 +23,8 @@ import threading from pymongo import MongoClient, MongoReplicaSetClient from pymongo.errors import AutoReconnect from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo -from test import host, port, SkipTest, version +from test import host, port, SkipTest +from test.version import Version try: @@ -124,7 +125,7 @@ def drop_collections(db): db.drop_collection(coll) def remove_all_users(db): - if version.at_least(db.connection, (2, 5, 3, -1)): + if Version.from_client(db.connection).at_least(2, 5, 3, -1): db.command({"dropAllUsersFromDatabase": 1}) else: db.system.users.remove({}) diff --git a/test/version.py b/test/version.py index 1632f8747..917530b7e 100644 --- a/test/version.py +++ b/test/version.py @@ -15,41 +15,49 @@ """Some tools for running tests based on MongoDB server version.""" -def _padded(iter, length, padding=0): - l = list(iter) - if len(l) < length: - for _ in range(length - len(l)): - l.append(0) - return l +class Version(tuple): + def __new__(cls, *version): + padded_version = cls._padded(version, 4) + return super(Version, cls).__new__(cls, tuple(padded_version)) -def _parse_version_string(version_string): - mod = 0 - if version_string.endswith("+"): - version_string = version_string[0:-1] - mod = 1 - elif version_string.endswith("-pre-"): - version_string = version_string[0:-5] - mod = -1 - elif version_string.endswith("-"): - version_string = version_string[0:-1] - mod = -1 - # Deal with '-rcX' substrings - if version_string.find('-rc') != -1: - version_string = version_string[0:version_string.find('-rc')] - mod = -1 + @classmethod + def _padded(cls, iter, length, padding=0): + l = list(iter) + if len(l) < length: + for _ in range(length - len(l)): + l.append(padding) + return l - version = [int(part) for part in version_string.split(".")] - version = _padded(version, 3) - version.append(mod) + @classmethod + def from_string(cls, version_string): + mod = 0 + if version_string.endswith("+"): + version_string = version_string[0:-1] + mod = 1 + elif version_string.endswith("-pre-"): + version_string = version_string[0:-5] + mod = -1 + elif version_string.endswith("-"): + version_string = version_string[0:-1] + mod = -1 + # Deal with '-rcX' substrings + if version_string.find('-rc') != -1: + version_string = version_string[0:version_string.find('-rc')] + mod = -1 - return tuple(version) + version = [int(part) for part in version_string.split(".")] + version = cls._padded(version, 3) + version.append(mod) + return Version(*version) -# Note this is probably broken for very old versions of the database... -def version(client): - return _parse_version_string(client.server_info()["version"]) + @classmethod + def from_client(cls, client): + return cls.from_string(client.server_info()['version']) + def at_least(self, *other_version): + return self >= Version(*other_version) -def at_least(client, min_version): - return version(client) >= tuple(_padded(min_version, 4)) + def __str__(self): + return ".".join(map(str, self))