PYTHON-681 Reuse MongoClient whenever possible in the tests

This commit is contained in:
Luke Lovett 2014-04-30 18:50:50 +00:00
parent cccecff1e9
commit 26fb43cf78
22 changed files with 800 additions and 841 deletions

View File

@ -25,9 +25,12 @@ else:
from unittest import SkipTest
import warnings
from functools import wraps
import pymongo
from bson.py3compat import _unicode
from test.version import Version
# hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode
# type in Python 2, so ensure these hostnames are unicodes, too. It makes tests
@ -43,6 +46,114 @@ host3 = _unicode(os.environ.get("DB_IP3", 'localhost'))
port3 = int(os.environ.get("DB_PORT3", 27019))
class ClientContext(object):
def __init__(self):
"""Create a client and grab essential information from the server."""
try:
self.client = pymongo.MongoClient(host, port)
except pymongo.errors.ConnectionFailure:
self.client = None
else:
self.ismaster = self.client.admin.command('ismaster')
self.w = len(self.ismaster.get("hosts", [])) or 1
self.setname = self.ismaster.get('setName', '')
self.rs_client = None
if self.setname:
self.rs_client = pymongo.MongoReplicaSetClient(
pair, replicaSet=self.setname)
self.cmd_line = self.client.admin.command('getCmdLineOpts')
self.version = Version.from_client(self.client)
self.auth_enabled = self._server_started_with_auth()
self.test_commands_enabled = ('testCommandsEnabled=1'
in self.cmd_line['argv'])
self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
def _server_started_with_auth(self):
# MongoDB >= 2.0
if 'parsed' in self.cmd_line:
parsed = self.cmd_line['parsed']
# MongoDB >= 2.6
if 'security' in parsed:
security = parsed['security']
# >= rc3
if 'authorization' in security:
return security['authorization'] == 'enabled'
# < rc3
return (security.get('auth', False) or
bool(security.get('keyFile')))
return parsed.get('auth', False) or bool(parsed.get('keyFile'))
# Legacy
argv = self.cmd_line['argv']
return '--auth' in argv or '--keyFile' in argv
def _require(self, condition, msg, func=None):
def make_wrapper(f):
@wraps(f)
def wrap(*args, **kwargs):
if condition:
return f(*args, **kwargs)
raise SkipTest(msg)
return wrap
if func is None:
def decorate(f):
return make_wrapper(f)
return decorate
return make_wrapper(func)
def require_version_min(self, *ver):
"""Run a test only if the server version is at least ``version``."""
other_version = Version(*ver)
return self._require(self.version >= other_version,
"Server version must be at least %s"
% str(other_version))
def require_version_max(self, *ver):
"""Run a test only if the server version is at most ``version``."""
other_version = Version(*ver)
return self._require(self.version <= other_version,
"Server version must be at most %s"
% str(other_version))
def require_auth(self, func):
"""Run a test only if the server is running with auth enabled."""
return self.check_auth_with_sharding(
self._require(self.auth_enabled,
"Authentication is not enabled on the server",
func=func))
def require_replica_set(self, func):
"""Run a test only if the client is connected to a replica set."""
return self._require(self.rs_client is not None,
"Not connected to a replica set",
func=func)
def require_no_mongos(self, func):
"""Run a test only if the client is not connected to a mongos."""
return self._require(not self.is_mongos,
"Must be connected to a mongod, not a mongos",
func=func)
def check_auth_with_sharding(self, func):
"""Skip a test when connected to mongos < 2.0 and running with auth."""
condition = not (self.auth_enabled and
self.is_mongos and self.version < (2,))
return self._require(condition,
"Auth with sharding requires MongoDB >= 2.0.0",
func=func)
def require_test_commands(self, func):
"""Run a test only if the server has test commands enabled."""
return self._require(self.test_commands_enabled,
"Test commands must be enabled",
func=func)
# Reusable client context
client_context = ClientContext()
def setup():
warnings.resetwarnings()
warnings.simplefilter("always")

View File

@ -34,8 +34,9 @@ from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.mongo_client import MongoClient, _partition_node
from pymongo.read_preferences import ReadPreference
from test import SkipTest, unittest, utils, version
from test import SkipTest, unittest, utils
from test.utils import one
from test.version import Version
# May be imported from gevent, below.
@ -1013,7 +1014,7 @@ class TestLastErrorDefaults(HATestCase):
use_greenlets=use_greenlets)
def test_get_last_error_defaults(self):
if not version.at_least(self.c, (1, 9, 0)):
if not Version.from_client(self.c).at_least(1, 9, 0):
raise SkipTest("Need MongoDB >= 1.9.0 to test getLastErrorDefaults")
replset = self.c.local.system.replset.find_one()

View File

@ -30,8 +30,7 @@ from pymongo import MongoClient, MongoReplicaSetClient
from pymongo.auth import HAVE_KERBEROS
from pymongo.errors import OperationFailure, ConfigurationError
from pymongo.read_preferences import ReadPreference
from test import host, port, SkipTest, unittest, version
from test.utils import is_mongos, server_started_with_auth
from test import client_context, host, port, SkipTest, unittest
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS.
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
@ -229,13 +228,9 @@ class TestSASL(unittest.TestCase):
class TestAuthURIOptions(unittest.TestCase):
@client_context.require_auth
def setUp(self):
client = MongoClient(host, port)
# Sharded auth not supported before MongoDB 2.0
if is_mongos(client) and not version.at_least(client, (2, 0, 0)):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
if not server_started_with_auth(client):
raise SkipTest('Authentication is not enabled on server')
response = client.admin.command('ismaster')
self.set_name = str(response.get('setName', ''))
client.admin.add_user('admin', 'pass', roles=['userAdminAnyDatabase',
@ -313,14 +308,11 @@ class TestAuthURIOptions(unittest.TestCase):
class TestDelegatedAuth(unittest.TestCase):
@client_context.require_auth
@client_context.require_version_max(2, 5, 3)
@client_context.require_version_min(2, 4, 0)
def setUp(self):
self.client = MongoClient(host, port)
if not version.at_least(self.client, (2, 4, 0)):
raise SkipTest('Delegated authentication requires MongoDB >= 2.4.0')
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
if version.at_least(self.client, (2, 5, 3, -1)):
raise SkipTest('Delegated auth does not exist in MongoDB >= 2.5.3')
# Give admin all privileges.
self.client.admin.add_user('admin', 'pass',
roles=['readAnyDatabase',

View File

@ -27,8 +27,7 @@ import bson
from bson.binary import *
from bson.py3compat import u
from bson.son import SON
from test import unittest
from test.test_client import get_client
from test import client_context, unittest
from pymongo.mongo_client import MongoClient
@ -147,9 +146,8 @@ class TestBinary(unittest.TestCase):
self.assertEqual(data, encoded)
# Test insert and find
client = get_client()
client.pymongo_test.drop_collection('java_uuid')
coll = client.pymongo_test.java_uuid
client_context.client.pymongo_test.drop_collection('java_uuid')
coll = client_context.client.pymongo_test.java_uuid
coll.uuid_subtype = JAVA_LEGACY
coll.insert(docs)
@ -160,7 +158,7 @@ class TestBinary(unittest.TestCase):
coll.uuid_subtype = OLD_UUID_SUBTYPE
for d in coll.find():
self.assertNotEqual(d['newguid'], d['newguidstring'])
client.pymongo_test.drop_collection('java_uuid')
client_context.client.pymongo_test.drop_collection('java_uuid')
def test_legacy_csharp_uuid(self):
@ -217,9 +215,8 @@ class TestBinary(unittest.TestCase):
self.assertEqual(data, encoded)
# Test insert and find
client = get_client()
client.pymongo_test.drop_collection('csharp_uuid')
coll = client.pymongo_test.csharp_uuid
client_context.client.pymongo_test.drop_collection('csharp_uuid')
coll = client_context.client.pymongo_test.csharp_uuid
coll.uuid_subtype = CSHARP_LEGACY
coll.insert(docs)
@ -230,7 +227,7 @@ class TestBinary(unittest.TestCase):
coll.uuid_subtype = OLD_UUID_SUBTYPE
for d in coll.find():
self.assertNotEqual(d['newguid'], d['newguidstring'])
client.pymongo_test.drop_collection('csharp_uuid')
client_context.client.pymongo_test.drop_collection('csharp_uuid')
def test_uri_to_uuid(self):
@ -240,8 +237,7 @@ class TestBinary(unittest.TestCase):
def test_uuid_queries(self):
c = get_client()
coll = c.pymongo_test.test
coll = client_context.client.pymongo_test.test
coll.drop()
uu = uuid.uuid4()

View File

@ -21,19 +21,14 @@ sys.path[0:0] = [""]
from bson import InvalidDocument, SON
from bson.py3compat import string_type
from pymongo.errors import BulkWriteError, InvalidOperation, OperationFailure
from test import SkipTest, unittest, version
from test.test_client import get_client
from test.utils import (oid_generated_on_client,
remove_all_users,
server_started_with_auth,
server_started_with_nojournal)
from test import client_context, unittest
from test.utils import oid_generated_on_client, remove_all_users
class BulkTestBase(unittest.TestCase):
def setUp(self):
client = get_client()
self.has_write_commands = (client.max_wire_version > 1)
self.has_write_commands = (client_context.client.max_wire_version > 1)
def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response."""
@ -110,9 +105,12 @@ class BulkTestBase(unittest.TestCase):
class TestBulk(BulkTestBase):
@classmethod
def setUpClass(cls):
cls.coll = client_context.client.pymongo_test.test
def setUp(self):
super(TestBulk, self).setUp()
self.coll = get_client().pymongo_test.test
self.coll.remove()
def test_empty(self):
@ -583,8 +581,7 @@ class TestBulk(BulkTestBase):
self.assertEqual(self.coll.find({'x': 2}).count(), 0)
def test_upsert_large(self):
client = self.coll.database.connection
big = 'a' * (client.max_bson_size - 37)
big = 'a' * (client_context.client.max_bson_size - 37)
bulk = self.coll.initialize_ordered_bulk_op()
bulk.find({'x': 1}).upsert().update({'$set': {'s': big}})
result = bulk.execute()
@ -886,29 +883,26 @@ class TestBulk(BulkTestBase):
class TestBulkWriteConcern(BulkTestBase):
@classmethod
def setUpClass(cls):
cls.is_repl = ('setName' in client_context.ismaster)
cls.w = client_context.w
cls.coll = client_context.client.pymongo_test.test
def setUp(self):
super(TestBulkWriteConcern, self).setUp()
client = get_client()
ismaster = client.test.command('ismaster')
self.is_repl = bool(ismaster.get('setName'))
self.w = len(ismaster.get("hosts", []))
self.client = client
self.coll = client.pymongo_test.test
self.coll.remove()
@client_context.require_version_min(1, 8, 2)
def test_fsync_and_j(self):
if not version.at_least(self.client, (1, 8, 2)):
raise SkipTest("Need at least MongoDB 1.8.2")
batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1})
self.assertRaises(
OperationFailure,
batch.execute, {'fsync': True, 'j': True})
@client_context.require_replica_set
def test_write_concern_failure_ordered(self):
if not self.is_repl:
raise SkipTest("Need a replica set to test.")
# Ensure we don't raise on wnote.
batch = self.coll.initialize_ordered_bulk_op()
batch.find({"something": "that does no exist"}).remove()
@ -985,10 +979,8 @@ class TestBulkWriteConcern(BulkTestBase):
finally:
self.coll.drop_index([('a', 1)])
@client_context.require_replica_set
def test_write_concern_failure_unordered(self):
if not self.is_repl:
raise SkipTest("Need a replica set to test.")
# Ensure we don't raise on wnote.
batch = self.coll.initialize_unordered_bulk_op()
batch.find({"something": "that does no exist"}).remove()
@ -1063,9 +1055,12 @@ class TestBulkWriteConcern(BulkTestBase):
class TestBulkNoResults(BulkTestBase):
@classmethod
def setUpClass(cls):
cls.coll = client_context.client.pymongo_test.test
def setUp(self):
super(TestBulkNoResults, self).setUp()
self.coll = get_client().pymongo_test.test
self.coll.remove()
def test_no_results_ordered_success(self):
@ -1114,21 +1109,21 @@ class TestBulkNoResults(BulkTestBase):
class TestBulkAuthorization(BulkTestBase):
@classmethod
@client_context.require_auth
@client_context.require_version_min(2, 5, 3)
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
cls.coll = cls.db.test
def setUp(self):
super(TestBulkAuthorization, self).setUp()
self.client = client = get_client()
if (not server_started_with_auth(client)
or not version.at_least(client, (2, 5, 3))):
raise SkipTest('Need at least MongoDB 2.5.3 with auth')
db = client.pymongo_test
self.coll = db.test
self.coll.remove()
db.add_user('dbOwner', 'pw', roles=['dbOwner'])
db.authenticate('dbOwner', 'pw')
db.add_user('readonly', 'pw', roles=['read'])
db.command(
self.db.add_user('dbOwner', 'pw', roles=['dbOwner'])
self.db.authenticate('dbOwner', 'pw')
self.db.add_user('readonly', 'pw', roles=['read'])
self.db.command(
'createRole', 'noremove',
privileges=[{
'actions': ['insert', 'update', 'find'],
@ -1136,13 +1131,21 @@ class TestBulkAuthorization(BulkTestBase):
}],
roles=[])
db.add_user('noremove', 'pw', roles=['noremove'])
db.logout()
self.db.add_user('noremove', 'pw', roles=['noremove'])
self.db.logout()
def tearDown(self):
self.db.logout()
self.db.authenticate('dbOwner', 'pw')
self.db.command('dropRole', 'noremove')
remove_all_users(self.db)
self.db.logout()
self.db.connection.disconnect()
def test_readonly(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
db = self.client.pymongo_test
db = client_context.client.pymongo_test
db.authenticate('readonly', 'pw')
bulk = self.coll.initialize_ordered_bulk_op()
bulk.insert({'x': 1})
@ -1151,7 +1154,7 @@ class TestBulkAuthorization(BulkTestBase):
def test_no_remove(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
db = self.client.pymongo_test
db = client_context.client.pymongo_test
db.authenticate('noremove', 'pw')
bulk = self.coll.initialize_ordered_bulk_op()
bulk.insert({'x': 1})
@ -1161,14 +1164,6 @@ class TestBulkAuthorization(BulkTestBase):
self.assertRaises(OperationFailure, bulk.execute)
self.assertEqual(set([1, 2]), set(self.coll.distinct('x')))
def tearDown(self):
db = self.client.pymongo_test
db.logout()
db.authenticate('dbOwner', 'pw')
db.command('dropRole', 'noremove')
remove_all_users(db)
db.logout()
if __name__ == "__main__":
unittest.main()

View File

@ -36,14 +36,12 @@ from pymongo.errors import (AutoReconnect,
InvalidName,
OperationFailure,
PyMongoError)
from test import host, pair, port, SkipTest, unittest, version
from test import client_context, host, pair, port, SkipTest, unittest
from test.pymongo_mocks import MockClient
from test.utils import (assertRaisesExactly,
delay,
is_mongos,
remove_all_users,
server_is_master_with_slave,
server_started_with_auth,
TestRequestMixin,
_TestLazyConnectMixin,
lazy_client_trial,
@ -56,6 +54,10 @@ def get_client(*args, **kwargs):
class TestClient(unittest.TestCase, TestRequestMixin):
@classmethod
def setUpClass(cls):
cls.client = client_context.client
def test_types(self):
self.assertRaises(TypeError, MongoClient, 1)
self.assertRaises(TypeError, MongoClient, 1.14)
@ -108,7 +110,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual(host, c.host)
self.assertEqual(port, c.port)
if version.at_least(c, (2, 5, 4, -1)):
if client_context.version.at_least(2, 5, 4, -1):
self.assertTrue(c.max_wire_version > 0)
else:
self.assertEqual(c.max_wire_version, 0)
@ -136,10 +138,10 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertTrue(MongoClient(host, port))
def test_equality(self):
client = MongoClient(host, port)
self.assertEqual(client, MongoClient(host, port))
# ClientContext.client is constructed as MongoClient(host, port)
self.assertEqual(self.client, MongoClient(host, port))
# Explicitly test inequality
self.assertFalse(client != MongoClient(host, port))
self.assertFalse(self.client != MongoClient(host, port))
def test_host_w_port(self):
self.assertTrue(MongoClient("%s:%d" % (host, port)))
@ -153,10 +155,9 @@ class TestClient(unittest.TestCase, TestRequestMixin):
"MongoClient('%s', %d)" % (host, port))
def test_getters(self):
self.assertEqual(MongoClient(host, port).host, host)
self.assertEqual(MongoClient(host, port).port, port)
self.assertEqual(set([(host, port)]),
MongoClient(host, port).nodes)
self.assertEqual(self.client.host, host)
self.assertEqual(self.client.port, port)
self.assertEqual(set([(host, port)]), self.client.nodes)
def test_use_greenlets(self):
self.assertFalse(MongoClient(host, port).use_greenlets)
@ -166,62 +167,56 @@ class TestClient(unittest.TestCase, TestRequestMixin):
host, port, use_greenlets=True).use_greenlets)
def test_get_db(self):
client = MongoClient(host, port)
def make_db(base, name):
return base[name]
self.assertRaises(InvalidName, make_db, client, "")
self.assertRaises(InvalidName, make_db, client, "te$t")
self.assertRaises(InvalidName, make_db, client, "te.t")
self.assertRaises(InvalidName, make_db, client, "te\\t")
self.assertRaises(InvalidName, make_db, client, "te/t")
self.assertRaises(InvalidName, make_db, client, "te st")
self.assertRaises(InvalidName, make_db, self.client, "")
self.assertRaises(InvalidName, make_db, self.client, "te$t")
self.assertRaises(InvalidName, make_db, self.client, "te.t")
self.assertRaises(InvalidName, make_db, self.client, "te\\t")
self.assertRaises(InvalidName, make_db, self.client, "te/t")
self.assertRaises(InvalidName, make_db, self.client, "te st")
self.assertTrue(isinstance(client.test, Database))
self.assertEqual(client.test, client["test"])
self.assertEqual(client.test, Database(client, "test"))
self.assertTrue(isinstance(self.client.test, Database))
self.assertEqual(self.client.test, self.client["test"])
self.assertEqual(self.client.test, Database(self.client, "test"))
def test_database_names(self):
client = MongoClient(host, port)
self.client.pymongo_test.test.save({"dummy": u("object")})
self.client.pymongo_test_mike.test.save({"dummy": u("object")})
client.pymongo_test.test.save({"dummy": u("object")})
client.pymongo_test_mike.test.save({"dummy": u("object")})
dbs = client.database_names()
dbs = self.client.database_names()
self.assertTrue("pymongo_test" in dbs)
self.assertTrue("pymongo_test_mike" in dbs)
def test_drop_database(self):
client = MongoClient(host, port)
self.assertRaises(TypeError, client.drop_database, 5)
self.assertRaises(TypeError, client.drop_database, None)
self.assertRaises(TypeError, self.client.drop_database, 5)
self.assertRaises(TypeError, self.client.drop_database, None)
raise SkipTest("This test often fails due to SERVER-2329")
client.pymongo_test.test.save({"dummy": u("object")})
dbs = client.database_names()
self.client.pymongo_test.test.save({"dummy": u("object")})
dbs = self.client.database_names()
self.assertTrue("pymongo_test" in dbs)
client.drop_database("pymongo_test")
dbs = client.database_names()
self.client.drop_database("pymongo_test")
dbs = self.client.database_names()
self.assertTrue("pymongo_test" not in dbs)
client.pymongo_test.test.save({"dummy": u("object")})
dbs = client.database_names()
self.client.pymongo_test.test.save({"dummy": u("object")})
dbs = self.client.database_names()
self.assertTrue("pymongo_test" in dbs)
client.drop_database(client.pymongo_test)
dbs = client.database_names()
self.client.drop_database(self.client.pymongo_test)
dbs = self.client.database_names()
self.assertTrue("pymongo_test" not in dbs)
def test_copy_db(self):
c = MongoClient(host, port)
c = self.client
# Due to SERVER-2329, databases may not disappear
# from a master in a master-slave pair.
if server_is_master_with_slave(c):
raise SkipTest("SERVER-2329")
if (not version.at_least(c, (2, 6, 0)) and
is_mongos(c) and server_started_with_auth(c)):
if (client_context.version.at_least(2, 6, 0) and
client_context.is_mongos and client_context.auth_enabled):
raise SkipTest("Need mongos >= 2.6.0 to test with authentication")
# We test copy twice; once starting in a request and once not. In
# either case the copy should succeed (because it starts a request
@ -259,8 +254,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"])
# See SERVER-6427 for mongos
if not is_mongos(c) and server_started_with_auth(c):
if not client_context.is_mongos and client_context.auth_enabled:
c.drop_database("pymongo_test1")
c.admin.add_user("admin", "password")
@ -286,34 +280,31 @@ class TestClient(unittest.TestCase, TestRequestMixin):
# Cleanup
remove_all_users(c.pymongo_test)
c.admin.remove_user("admin")
c.admin.logout()
c.disconnect()
def test_iteration(self):
client = MongoClient(host, port)
def iterate():
[a for a in client]
[a for a in self.client]
self.assertRaises(TypeError, iterate)
def test_disconnect(self):
c = MongoClient(host, port)
coll = c.pymongo_test.bar
coll = self.client.pymongo_test.bar
c.disconnect()
c.disconnect()
self.client.disconnect()
self.client.disconnect()
coll.count()
c.disconnect()
c.disconnect()
self.client.disconnect()
self.client.disconnect()
coll.count()
def test_from_uri(self):
c = MongoClient(host, port)
self.assertEqual(c, MongoClient("mongodb://%s:%d" % (host, port)))
self.assertEqual(self.client,
MongoClient("mongodb://%s:%d" % (host, port)))
def test_get_default_database(self):
c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False)
@ -330,18 +321,13 @@ class TestClient(unittest.TestCase, TestRequestMixin):
c = MongoClient(uri, _connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
@client_context.require_auth
def test_auth_from_uri(self):
c = MongoClient(host, port)
# Sharded auth not supported before MongoDB 2.0
if is_mongos(c) and not version.at_least(c, (2, 0, 0)):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
if not server_started_with_auth(c):
raise SkipTest('Authentication is not enabled on server')
c.admin.add_user("admin", "pass")
c.admin.authenticate("admin", "pass")
self.client.admin.add_user("admin", "pass")
self.client.admin.authenticate("admin", "pass")
try:
c.pymongo_test.add_user("user", "pass", roles=['userAdmin', 'readWrite'])
self.client.pymongo_test.add_user(
"user", "pass", roles=['userAdmin', 'readWrite'])
self.assertRaises(ConfigurationError, MongoClient,
"mongodb://foo:bar@%s:%d" % (host, port))
@ -375,18 +361,13 @@ class TestClient(unittest.TestCase, TestRequestMixin):
finally:
# Clean up.
remove_all_users(c.pymongo_test)
remove_all_users(c.admin)
remove_all_users(self.client.pymongo_test)
remove_all_users(self.client.admin)
self.client.admin.logout()
self.client.disconnect()
@client_context.require_auth
def test_lazy_auth_raises_operation_failure(self):
# Check if we have the prerequisites to run this test.
c = MongoClient(host, port)
if not server_started_with_auth(c):
raise SkipTest('Authentication is not enabled on server')
if is_mongos(c) and not version.at_least(c, (2, 0, 0)):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
lazy_client = MongoClient(
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port),
_connect=False)
@ -397,8 +378,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
def test_unix_socket(self):
if not hasattr(socket, "AF_UNIX"):
raise SkipTest("UNIX-sockets are not supported on this system")
if (sys.platform == 'darwin' and
server_started_with_auth(MongoClient(host, port))):
if sys.platform == 'darwin' and client_context.auth_enabled:
raise SkipTest("SERVER-8492")
mongodb_socket = '/tmp/mongodb-27017.sock'
@ -428,7 +408,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
except ImportError:
raise SkipTest("No multiprocessing module")
db = MongoClient(host, port).pymongo_test
db = self.client.pymongo_test
# Failure occurs if the client is used before the fork
db.test.find_one()
@ -478,7 +458,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
pass
def test_document_class(self):
c = MongoClient(host, port)
c = self.client
db = c.pymongo_test
db.test.insert({"x": 1})
@ -488,9 +468,12 @@ class TestClient(unittest.TestCase, TestRequestMixin):
c.document_class = SON
self.assertEqual(SON, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
try:
self.assertEqual(SON, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
finally:
c.document_class = dict
c = MongoClient(host, port, document_class=SON)
db = c.pymongo_test
@ -531,7 +514,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
get_client, socketTimeoutMS='foo')
def test_socket_timeout(self):
no_timeout = MongoClient(host, port)
no_timeout = self.client
timeout_sec = 1
timeout = MongoClient(
host, port, socketTimeoutMS=1000 * timeout_sec)
@ -562,7 +545,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertRaises(ConfigurationError, MongoClient, tz_aware='foo')
aware = MongoClient(host, port, tz_aware=True)
naive = MongoClient(host, port)
naive = self.client
aware.pymongo_test.drop_collection("test")
now = datetime.datetime.utcnow()
@ -595,28 +578,26 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertTrue("pymongo_test" in dbs)
self.assertTrue("pymongo_test_bernie" in dbs)
@client_context.require_no_mongos
def test_fsync_lock_unlock(self):
c = get_client()
if is_mongos(c):
raise SkipTest('fsync/lock not supported by mongos')
if not version.at_least(c, (2, 0)) and server_started_with_auth(c):
if (not client_context.version.at_least(2, 0) and
client_context.auth_enabled):
raise SkipTest('Requires server >= 2.0 to test with auth')
res = c.admin.command('getCmdLineOpts')
if '--master' in res['argv'] and version.at_least(c, (2, 3, 0)):
if (server_is_master_with_slave(client_context.client) and
client_context.version.at_least(2, 3, 0)):
raise SkipTest('SERVER-7714')
self.assertFalse(c.is_locked)
self.assertFalse(self.client.is_locked)
# async flushing not supported on windows...
if sys.platform not in ('cygwin', 'win32'):
c.fsync(async=True)
self.assertFalse(c.is_locked)
c.fsync(lock=True)
self.assertTrue(c.is_locked)
self.client.fsync(async=True)
self.assertFalse(self.client.is_locked)
self.client.fsync(lock=True)
self.assertTrue(self.client.is_locked)
locked = True
c.unlock()
self.client.unlock()
for _ in range(5):
locked = c.is_locked
locked = self.client.is_locked
if not locked:
break
time.sleep(1)
@ -633,26 +614,23 @@ class TestClient(unittest.TestCase, TestRequestMixin):
# pool
self.assertEqual(1, len(get_pool(client).sockets))
# We need exec here because if the Python version is less than 2.6
# these with-statements won't even compile.
with contextlib.closing(client):
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
self.assertEqual(None, client._MongoClient__member)
with get_client() as client:
with self.client as client:
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
self.assertEqual(None, client._MongoClient__member)
def test_with_start_request(self):
client = get_client()
pool = get_pool(client)
pool = get_pool(self.client)
# No request started
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
# Start a request
request_context_mgr = client.start_request()
request_context_mgr = self.client.start_request()
self.assertTrue(
isinstance(request_context_mgr, object)
)
@ -667,8 +645,8 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertDifferentSock(pool)
# Test the 'with' statement
with client.start_request() as request:
self.assertEqual(client, request.connection)
with self.client.start_request() as request:
self.assertEqual(self.client, request.connection)
self.assertNoSocketYet(pool)
self.assertSameSock(pool)
self.assertRequestSocket(pool)
@ -685,8 +663,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
)
# auto_start_request should default to False
client = get_client()
self.assertFalse(client.auto_start_request)
self.assertFalse(self.client.auto_start_request)
client = get_client(auto_start_request=True)
self.assertTrue(client.auto_start_request)
@ -709,35 +686,34 @@ class TestClient(unittest.TestCase, TestRequestMixin):
def test_nested_request(self):
# auto_start_request is False
client = get_client()
pool = get_pool(client)
self.assertFalse(client.in_request())
pool = get_pool(self.client)
self.assertFalse(self.client.in_request())
# Start and end request
client.start_request()
self.assertInRequestAndSameSock(client, pool)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pool)
self.client.start_request()
self.assertInRequestAndSameSock(self.client, pool)
self.client.end_request()
self.assertNotInRequestAndDifferentSock(self.client, pool)
# Double-nesting
client.start_request()
client.start_request()
client.end_request()
self.assertInRequestAndSameSock(client, pool)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pool)
self.client.start_request()
self.client.start_request()
self.client.end_request()
self.assertInRequestAndSameSock(self.client, pool)
self.client.end_request()
self.assertNotInRequestAndDifferentSock(self.client, pool)
# Extra end_request calls have no effect - count stays at zero
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pool)
self.client.end_request()
self.assertNotInRequestAndDifferentSock(self.client, pool)
client.start_request()
self.assertInRequestAndSameSock(client, pool)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pool)
self.client.start_request()
self.assertInRequestAndSameSock(self.client, pool)
self.client.end_request()
self.assertNotInRequestAndDifferentSock(self.client, pool)
def test_request_threads(self):
client = get_client(auto_start_request=False)
client = self.client
pool = get_pool(client)
self.assertNotInRequestAndDifferentSock(client, pool)
@ -787,8 +763,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
# Test fix for PYTHON-294 -- make sure MongoClient closes its
# socket if it gets an interrupt while waiting to recv() from it.
c = get_client()
db = c.pymongo_test
db = self.client.pymongo_test
# A $where clause which takes 1.5 sec to execute
where = delay(1.5)
@ -826,17 +801,17 @@ class TestClient(unittest.TestCase, TestRequestMixin):
def test_operation_failure_without_request(self):
# Ensure MongoClient doesn't close socket after it gets an error
# response to getLastError. PYTHON-395.
c = get_client()
pool = get_pool(c)
self.assertEqual(1, len(pool.sockets))
pool = get_pool(self.client)
socket_count = len(pool.sockets)
self.assertGreaterEqual(socket_count, 1)
old_sock_info = next(iter(pool.sockets))
c.pymongo_test.test.drop()
c.pymongo_test.test.insert({'_id': 'foo'})
self.client.pymongo_test.test.drop()
self.client.pymongo_test.test.insert({'_id': 'foo'})
self.assertRaises(
OperationFailure,
c.pymongo_test.test.insert, {'_id': 'foo'})
self.client.pymongo_test.test.insert, {'_id': 'foo'})
self.assertEqual(1, len(pool.sockets))
self.assertEqual(socket_count, len(pool.sockets))
new_sock_info = next(iter(pool.sockets))
self.assertEqual(old_sock_info, new_sock_info)
@ -861,7 +836,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual(old_sock_info, pool._get_request_state())
def test_alive(self):
self.assertTrue(get_client().alive())
self.assertTrue(self.client.alive())
client = MongoClient('doesnt exist', _connect=False)
self.assertFalse(client.alive())
@ -938,12 +913,9 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual(expected_min, c.min_wire_version)
self.assertEqual(expected_max, c.max_wire_version)
@client_context.require_replica_set
def test_replica_set(self):
client = MongoClient(host, port)
name = client.pymongo_test.command('ismaster').get('setName')
if not name:
raise SkipTest('Not connected to a replica set')
name = client_context.setname
MongoClient(host, port, replicaSet=name) # No error.
self.assertRaises(

View File

@ -38,7 +38,7 @@ from pymongo import (ASCENDING, DESCENDING, GEO2D,
from pymongo import message as message_module
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.database import Database
from pymongo.read_preferences import ReadPreference
from pymongo.son_manipulator import SONManipulator
from pymongo.errors import (DocumentTooLarge,
@ -51,22 +51,22 @@ from pymongo.errors import (DocumentTooLarge,
from test.test_client import get_client
from test.utils import (is_mongos, joinall, enable_text_search, get_pool,
oid_generated_on_client)
from test import qcheck, SkipTest, unittest, version
from test import (client_context,
qcheck,
SkipTest,
unittest,
version)
class TestCollection(unittest.TestCase):
def setUp(self):
self.client = get_client()
self.db = self.client.pymongo_test
ismaster = self.db.command('ismaster')
self.setname = ismaster.get('setName')
self.w = len(ismaster.get('hosts', [])) or 1
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
cls.w = client_context.w
def tearDown(self):
self.db.drop_collection("test_large_limit")
self.db = None
self.client = None
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -357,9 +357,8 @@ class TestCollection(unittest.TestCase):
index_info = db.test.index_information()['loc_2d']
self.assertEqual([('loc', '2d')], index_info['key'])
@client_context.require_no_mongos
def test_index_haystack(self):
if is_mongos(self.db.connection):
raise SkipTest("geoSearch is not supported by mongos")
db = self.db
db.test.drop_indexes()
db.test.remove()
@ -393,14 +392,10 @@ class TestCollection(unittest.TestCase):
"type": "restaurant"
}, results[0])
@client_context.require_version_min(2, 3, 2)
@client_context.require_no_mongos
def test_index_text(self):
if not version.at_least(self.client, (2, 3, 2)):
raise SkipTest("Text search requires server >=2.3.2.")
if is_mongos(self.client):
raise SkipTest("setParameter does not work through mongos")
enable_text_search(self.client)
enable_text_search(client_context.client)
db = self.db
db.test.drop_indexes()
@ -408,7 +403,7 @@ class TestCollection(unittest.TestCase):
index_info = db.test.index_information()["t_text"]
self.assertTrue("weights" in index_info)
if version.at_least(self.client, (2, 5, 5)):
if client_context.version.at_least(2, 5, 5):
db.test.insert([
{'t': 'spam eggs and spam'},
{'t': 'spam'},
@ -426,10 +421,8 @@ class TestCollection(unittest.TestCase):
db.test.drop_indexes()
@client_context.require_version_min(2, 3, 2)
def test_index_2dsphere(self):
if not version.at_least(self.client, (2, 3, 2)):
raise SkipTest("2dsphere indexing requires server >=2.3.2.")
db = self.db
db.test.drop_indexes()
self.assertEqual("geo_2dsphere",
@ -444,10 +437,8 @@ class TestCollection(unittest.TestCase):
db.test.drop_indexes()
@client_context.require_version_min(2, 3, 2)
def test_index_hashed(self):
if not version.at_least(self.client, (2, 3, 2)):
raise SkipTest("hashed indexing requires server >=2.3.2.")
db = self.db
db.test.drop_indexes()
self.assertEqual("a_hashed",
@ -485,7 +476,7 @@ class TestCollection(unittest.TestCase):
db = self.db
self._drop_dups_setup(db)
if version.at_least(db.connection, (1, 9, 2)):
if client_context.version.at_least(1, 9, 2):
# No error, just drop the duplicate
db.test.create_index(
[('i', ASCENDING)],
@ -589,14 +580,14 @@ class TestCollection(unittest.TestCase):
db.drop_collection("test")
db.test.save({})
expected = {}
if version.at_least(db.connection, (2, 7, 0)):
if client_context.version.at_least(2, 7, 0):
# usePowerOf2Sizes server default
expected["flags"] = 1
self.assertEqual(db.test.options(), expected)
self.assertEqual(db.test.doesnotexist.options(), {})
db.drop_collection("test")
if version.at_least(db.connection, (1, 9)):
if client_context.version.at_least(1, 9):
db.create_collection("test", capped=True, size=4096)
result = db.test.options()
# mongos 2.2.x adds an $auth field when auth is enabled.
@ -703,7 +694,7 @@ class TestCollection(unittest.TestCase):
db.test.insert({"x": [1, 2, 3], "mike": "awesome"})
self.assertEqual([1, 2, 3], db.test.find_one()["x"])
if version.at_least(db.connection, (1, 5, 1)):
if client_context.version.at_least(1, 5, 1):
self.assertEqual([2, 3],
db.test.find_one(fields={"x": {"$slice":
-2}})["x"])
@ -827,24 +818,28 @@ class TestCollection(unittest.TestCase):
)
db.drop_collection("test")
wc = db.write_concern
db.write_concern = {"w": 0}
db.test.ensure_index([('i', ASCENDING)], unique=True)
try:
db.test.ensure_index([('i', ASCENDING)], unique=True)
# No error
db.test.insert([{'i': 1}] * 2)
self.assertEqual(1, db.test.count())
# No error
db.test.insert([{'i': 1}] * 2)
self.assertEqual(1, db.test.count())
# Implied safe
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, fsync=True),
)
# Implied safe
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, fsync=True),
)
# Explicit safe
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, w=1),
)
# Explicit safe
self.assertRaises(
DuplicateKeyError,
lambda: db.test.insert([{'i': 2}] * 2, w=1),
)
finally:
db.write_concern = wc
def test_insert_iterables(self):
db = self.db
@ -864,14 +859,12 @@ class TestCollection(unittest.TestCase):
itertools.repeat(None, 10)))
self.assertEqual(db.test.find().count(), 10)
@client_context.require_version_min(2, 0)
def test_insert_manipulate_false(self):
# Test three aspects of insert with manipulate=False:
# 1. The return value is None or [None] as appropriate.
# 2. _id is not set on the passed-in document object.
# 3. _id is not sent to server.
if not version.at_least(self.db.connection, (2, 0)):
raise SkipTest('Need at least MongoDB 2.0')
collection = self.db.test_insert_manipulate_false
collection.drop()
oid = ObjectId()
@ -934,28 +927,29 @@ class TestCollection(unittest.TestCase):
doc = self.db.test.find_one()
doc['a.b'] = 'c'
expected = InvalidDocument
if version.at_least(self.client, (2, 5, 4, -1)):
if client_context.version.at_least(2, 5, 4, -1):
expected = OperationFailure
self.assertRaises(expected, self.db.test.save, doc)
def test_unique_index(self):
db = self.db
db.drop_collection("test")
db.test.create_index("hello")
with client_context.client.start_request():
db.drop_collection("test")
db.test.create_index("hello")
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"})
self.assertFalse(db.error())
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"})
self.assertFalse(db.error())
db.drop_collection("test")
db.test.create_index("hello", unique=True)
db.drop_collection("test")
db.test.create_index("hello", unique=True)
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue(db.error())
db.test.save({"hello": "world"})
db.test.save({"hello": "mike"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue(db.error())
def test_duplicate_key_error(self):
db = self.db
@ -976,7 +970,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(2, db.test.count())
expected_error = OperationFailure
if version.at_least(db.connection, (1, 3)):
if client_context.version.at_least(1, 3):
expected_error = DuplicateKeyError
self.assertRaises(expected_error,
@ -1012,11 +1006,9 @@ class TestCollection(unittest.TestCase):
collection.write_concern = {'wtimeout': 1000}
self.assertRaises(DuplicateKeyError, collection.insert, {'_id': 1})
@client_context.require_version_min(1, 9, 1)
def test_continue_on_error(self):
db = self.db
if not version.at_least(db.connection, (1, 9, 1)):
raise SkipTest("continue_on_error requires MongoDB >= 1.9.1")
db.drop_collection("test")
oid = db.test.insert({"one": 1})
self.assertEqual(1, db.test.count())
@ -1027,33 +1019,34 @@ class TestCollection(unittest.TestCase):
docs.append({"four": 4})
docs.append({"five": 5})
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(1, db.test.count())
with client_context.client.start_request():
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(1, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(4, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(4, db.test.count())
db.drop_collection("test")
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
self.assertEqual(1, db.test.count())
docs[0].pop("_id")
docs[2]["_id"] = oid
db.drop_collection("test")
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
self.assertEqual(1, db.test.count())
docs[0].pop("_id")
docs[2]["_id"] = oid
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(3, db.test.count())
db.test.insert(docs, manipulate=False, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(3, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(6, db.test.count())
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
self.assertEqual(11000, db.error()['code'])
self.assertEqual(6, db.test.count())
def test_error_code(self):
try:
self.db.test.update({}, {"$thismodifierdoesntexist": 1})
except OperationFailure as exc:
if version.at_least(self.db.connection, (1, 3)):
if client_context.version.at_least(1, 3):
self.assertTrue(exc.code in (9, 10147, 16840, 17009))
# Just check that we set the error document. Fields
# vary by MongoDB version.
@ -1078,15 +1071,16 @@ class TestCollection(unittest.TestCase):
db.test.insert, {"hello": {"a": 4, "b": 10}})
def test_safe_insert(self):
db = self.db
db.drop_collection("test")
with client_context.client.start_request():
db = self.db
db.drop_collection("test")
a = {"hello": "world"}
db.test.insert(a)
db.test.insert(a, w=0)
self.assertTrue("E11000" in db.error()["err"])
a = {"hello": "world"}
db.test.insert(a)
db.test.insert(a, w=0)
self.assertTrue("E11000" in db.error()["err"])
self.assertRaises(OperationFailure, db.test.insert, a)
self.assertRaises(OperationFailure, db.test.insert, a)
def test_update(self):
db = self.db
@ -1129,7 +1123,7 @@ class TestCollection(unittest.TestCase):
def test_update_nmodified(self):
db = self.db
db.drop_collection("test")
used_write_commands = (self.client.max_wire_version > 1)
used_write_commands = (client_context.client.max_wire_version > 1)
db.test.insert({'_id': 1})
result = db.test.update({'_id': 1}, {'$set': {'x': 1}})
@ -1145,11 +1139,9 @@ class TestCollection(unittest.TestCase):
else:
self.assertFalse('nModified' in result)
@client_context.require_version_min(1, 1, 3, -1)
def test_multi_update(self):
db = self.db
if not version.at_least(db.connection, (1, 1, 3, -1)):
raise SkipTest("multi-update requires MongoDB >= 1.1.3")
db.drop_collection("test")
db.test.save({"x": 4, "y": 3})
@ -1176,34 +1168,35 @@ class TestCollection(unittest.TestCase):
self.assertEqual(2, db.test.find_one()["count"])
def test_safe_update(self):
db = self.db
v113minus = version.at_least(db.connection, (1, 1, 3, -1))
v19 = version.at_least(db.connection, (1, 9))
with client_context.client.start_request():
db = self.db
v113minus = client_context.version.at_least(1, 1, 3, -1)
v19 = client_context.version.at_least(1, 9)
db.drop_collection("test")
db.test.create_index("x", unique=True)
db.drop_collection("test")
db.test.create_index("x", unique=True)
db.test.insert({"x": 5})
id = db.test.insert({"x": 4})
db.test.insert({"x": 5})
id = db.test.insert({"x": 4})
self.assertEqual(
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
self.assertEqual(
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
if v19:
self.assertTrue("E11000" in db.error()["err"])
elif v113minus:
self.assertTrue(db.error()["err"].startswith("E11001"))
else:
self.assertTrue(db.error()["err"].startswith("E12011"))
if v19:
self.assertTrue("E11000" in db.error()["err"])
elif v113minus:
self.assertTrue(db.error()["err"].startswith("E11001"))
else:
self.assertTrue(db.error()["err"].startswith("E12011"))
self.assertRaises(OperationFailure, db.test.update,
{"_id": id}, {"$inc": {"x": 1}})
self.assertRaises(OperationFailure, db.test.update,
{"_id": id}, {"$inc": {"x": 1}})
self.assertEqual(1, db.test.update({"_id": id},
{"$inc": {"x": 2}})["n"])
self.assertEqual(1, db.test.update({"_id": id},
{"$inc": {"x": 2}})["n"])
self.assertEqual(0, db.test.update({"_id": "foo"},
{"$inc": {"x": 2}})["n"])
self.assertEqual(0, db.test.update({"_id": "foo"},
{"$inc": {"x": 2}})["n"])
def test_update_with_invalid_keys(self):
self.db.drop_collection("test")
@ -1212,7 +1205,7 @@ class TestCollection(unittest.TestCase):
doc['a.b'] = 'c'
expected = InvalidDocument
if version.at_least(self.client, (2, 5, 4, -1)):
if client_context.version.at_least(2, 5, 4, -1):
expected = OperationFailure
# Replace
@ -1249,16 +1242,17 @@ class TestCollection(unittest.TestCase):
{})['n'])
def test_safe_save(self):
db = self.db
db.drop_collection("test")
db.test.create_index("hello", unique=True)
with client_context.client.start_request():
db = self.db
db.drop_collection("test")
db.test.create_index("hello", unique=True)
db.test.save({"hello": "world"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue("E11000" in db.error()["err"])
db.test.save({"hello": "world"})
db.test.save({"hello": "world"}, w=0)
self.assertTrue("E11000" in db.error()["err"])
self.assertRaises(OperationFailure, db.test.save,
{"hello": "world"})
self.assertRaises(OperationFailure, db.test.save,
{"hello": "world"})
def test_safe_remove(self):
db = self.db
@ -1271,7 +1265,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(None, db.test.remove({"x": 1}, w=0))
self.assertEqual(1, db.test.count())
if version.at_least(db.connection, (1, 1, 3, -1)):
if client_context.version.at_least(1, 1, 3, -1):
self.assertRaises(OperationFailure, db.test.remove,
{"x": 1})
else: # Just test that it doesn't blow up
@ -1283,18 +1277,16 @@ class TestCollection(unittest.TestCase):
self.assertEqual(2, db.test.remove({})["n"])
self.assertEqual(0, db.test.remove({})["n"])
@client_context.require_version_min(1, 5, 1)
def test_last_error_options(self):
if not version.at_least(self.client, (1, 5, 1)):
raise SkipTest("getLastError options require MongoDB >= 1.5.1")
self.db.test.save({"x": 1}, w=1, wtimeout=1)
self.db.test.insert({"x": 1}, w=1, wtimeout=1)
self.db.test.remove({"x": 1}, w=1, wtimeout=1)
self.db.test.update({"x": 1}, {"y": 2}, w=1, wtimeout=1)
ismaster = self.client.admin.command("ismaster")
if ismaster.get("setName"):
w = len(ismaster["hosts"]) + 1
if client_context.setname:
# client_context.w is the number of hosts in the replica set
w = client_context.w + 1
self.assertRaises(WTimeoutError, self.db.test.save,
{"x": 1}, w=w, wtimeout=1)
self.assertRaises(WTimeoutError, self.db.test.insert,
@ -1314,7 +1306,7 @@ class TestCollection(unittest.TestCase):
self.fail("WTimeoutError was not raised")
# can't use fsync and j options together
if version.at_least(self.client, (1, 8, 2)):
if client_context.version.at_least(1, 8, 2):
self.assertRaises(OperationFailure, self.db.test.insert,
{"_id": 1}, j=True, fsync=True)
@ -1335,9 +1327,8 @@ class TestCollection(unittest.TestCase):
self.assertEqual(db.test.find({'foo': 'bar'}).count(), 1)
self.assertEqual(db.test.find({'foo': re.compile(r'ba.*')}).count(), 2)
@client_context.require_version_min(2, 1, 0)
def test_aggregate(self):
if not version.at_least(self.db.connection, (2, 1, 0)):
raise SkipTest("The aggregate command requires MongoDB >= 2.1.0")
db = self.db
db.drop_collection("test")
db.test.save({'foo': [1, 2]})
@ -1353,81 +1344,74 @@ class TestCollection(unittest.TestCase):
self.assertEqual(1.0, result['ok'])
self.assertEqual([{'foo': [1, 2]}], result['result'])
@client_context.require_version_min(2, 3, 2) # See SERVER-6470.
def test_aggregate_with_compile_re(self):
# See SERVER-6470.
if not version.at_least(self.db.connection, (2, 3, 2)):
raise SkipTest(
"Retrieving a regex with aggregation requires "
"MongoDB >= 2.3.2")
self.db.test.drop()
self.db.test.insert({'r': re.compile('.*')})
db = self.client.pymongo_test
db.test.drop()
db.test.insert({'r': re.compile('.*')})
result = db.test.aggregate([])
result = self.db.test.aggregate([])
self.assertTrue(isinstance(result['result'][0]['r'], RE_TYPE))
result = db.test.aggregate([], compile_re=False)
result = self.db.test.aggregate([], compile_re=False)
self.assertTrue(isinstance(result['result'][0]['r'], Regex))
@client_context.require_version_min(2, 5, 1)
def test_aggregation_cursor_validation(self):
if not version.at_least(self.db.connection, (2, 5, 1)):
raise SkipTest("Aggregation cursor requires MongoDB >= 2.5.1")
db = self.db
projection = {'$project': {'_id': '$_id'}}
cursor = db.test.aggregate(projection, cursor={})
self.assertTrue(isinstance(cursor, CommandCursor))
@client_context.require_version_min(2, 5, 1)
def test_aggregation_cursor(self):
if not version.at_least(self.db.connection, (2, 5, 1)):
raise SkipTest("Aggregation cursor requires MongoDB >= 2.5.1")
db = self.db
if self.setname:
db = MongoReplicaSetClient(host=self.client.host,
port=self.client.port,
replicaSet=self.setname)[db.name]
if client_context.setname:
db = client_context.rs_client[db.name]
# Test that getMore messages are sent to the right server.
db.read_preference = ReadPreference.SECONDARY
for collection_size in (10, 1000):
db.drop_collection("test")
db.test.insert([{'_id': i} for i in range(collection_size)],
w=self.w)
expected_sum = sum(range(collection_size))
# Use batchSize to ensure multiple getMore messages
cursor = db.test.aggregate(
{'$project': {'_id': '$_id'}},
cursor={'batchSize': 5})
try:
for collection_size in (10, 1000):
db.drop_collection("test")
db.test.insert([{'_id': i} for i in range(collection_size)],
w=self.w)
expected_sum = sum(range(collection_size))
# Use batchSize to ensure multiple getMore messages
cursor = db.test.aggregate(
{'$project': {'_id': '$_id'}},
cursor={'batchSize': 5})
self.assertEqual(
expected_sum,
sum(doc['_id'] for doc in cursor))
self.assertEqual(
expected_sum,
sum(doc['_id'] for doc in cursor))
finally:
db.read_preference = ReadPreference.PRIMARY
@client_context.require_version_min(2, 5, 5)
@client_context.require_no_mongos
def test_parallel_scan(self):
if is_mongos(self.db.connection):
raise SkipTest("mongos does not support parallel_scan")
if not version.at_least(self.db.connection, (2, 5, 5)):
raise SkipTest("Requires MongoDB >= 2.5.5")
db = self.db
db.drop_collection("test")
if self.setname:
db = MongoReplicaSetClient(host=self.client.host,
port=self.client.port,
replicaSet=self.setname)[db.name]
if client_context.setname:
db = client_context.rs_client[db.name]
# Test that getMore messages are sent to the right server.
db.read_preference = ReadPreference.SECONDARY
coll = db.test
coll.insert(({'_id': i} for i in range(8000)), w=self.w)
docs = []
threads = [threading.Thread(target=docs.extend, args=(cursor,))
for cursor in coll.parallel_scan(3)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(
set(range(8000)),
set(doc['_id'] for doc in docs))
try:
coll = db.test
coll.insert(({'_id': i} for i in range(8000)), w=self.w)
docs = []
threads = [threading.Thread(target=docs.extend, args=(cursor,))
for cursor in coll.parallel_scan(3)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(
set(range(8000)),
set(doc['_id'] for doc in docs))
finally:
db.read_preference = ReadPreference.PRIMARY
def test_group(self):
db = self.db
@ -1515,7 +1499,7 @@ class TestCollection(unittest.TestCase):
Code(reduce_function,
{"inc_value": 0.5}))[0]['count'])
if version.at_least(db.connection, (1, 1)):
if client_context.version.at_least(1, 1):
self.assertEqual(2, db.test.group([], {}, {"count": 0},
Code(reduce_function,
{"inc_value": 1}),
@ -1754,10 +1738,8 @@ class TestCollection(unittest.TestCase):
# The socket should be discarded.
self.assertEqual(0, len(socks))
@client_context.require_version_min(1, 1)
def test_distinct(self):
if not version.at_least(self.db.connection, (1, 1)):
raise SkipTest("distinct command requires MongoDB >= 1.1")
self.db.drop_collection("test")
test = self.db.test
@ -1814,11 +1796,11 @@ class TestCollection(unittest.TestCase):
def test_insert_large_document(self):
max_size = self.db.connection.max_bson_size
half_size = int(max_size / 2)
if version.at_least(self.db.connection, (1, 7, 4)):
if client_context.version.at_least(1, 7, 4):
self.assertEqual(max_size, 16777216)
expected = DocumentTooLarge
if version.at_least(self.client, (2, 5, 4, -1)):
if client_context.version.at_least(2, 5, 4, -1):
# Document too large handled by the server
expected = OperationFailure
self.assertRaises(expected, self.db.test.insert,
@ -1838,8 +1820,8 @@ class TestCollection(unittest.TestCase):
self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 32)})
def test_insert_large_batch(self):
max_bson_size = self.client.max_bson_size
if version.at_least(self.client, (2, 5, 4, -1)):
max_bson_size = client_context.client.max_bson_size
if client_context.version.at_least(2, 5, 4, -1):
# Write commands are limited to 16MB + 16k per batch
big_string = 'x' * int(max_bson_size / 2)
else:
@ -1862,12 +1844,9 @@ class TestCollection(unittest.TestCase):
# Test that inserts fail after first error, unacknowledged.
self.db.test.drop()
self.client.start_request()
try:
with client_context.client.start_request():
self.assertTrue(self.db.test.insert(batch, w=0))
self.assertEqual(1, self.db.test.count())
finally:
self.client.end_request()
# 2 batches, 2 errors, acknowledged, continue on error
self.db.test.drop()
@ -1884,13 +1863,11 @@ class TestCollection(unittest.TestCase):
# 2 batches, 2 errors, unacknowledged, continue on error
self.db.test.drop()
self.client.start_request()
try:
self.assertTrue(self.db.test.insert(batch, continue_on_error=True, w=0))
with client_context.client.start_request():
self.assertTrue(
self.db.test.insert(batch, continue_on_error=True, w=0))
# Only the first and third documents should be inserted.
self.assertEqual(2, self.db.test.count())
finally:
self.client.end_request()
def test_numerous_inserts(self):
# Ensure we don't exceed server's 1000-document batch size limit.
@ -1935,10 +1912,8 @@ class TestCollection(unittest.TestCase):
coll.uuid_subtype = 6
self.assertEqual(doc, coll.find_one({'_id': 2}))
@client_context.require_version_min(1, 1, 1)
def test_map_reduce(self):
if not version.at_least(self.db.connection, (1, 1, 1)):
raise SkipTest("mapReduce command requires MongoDB >= 1.1.1")
db = self.db
db.drop_collection("test")
@ -1964,7 +1939,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(2, result.find_one({"_id": "dog"})["value"])
self.assertEqual(1, result.find_one({"_id": "mouse"})["value"])
if version.at_least(self.db.connection, (1, 7, 4)):
if client_context.version.at_least(1, 7, 4):
db.test.insert({"id": 5, "tags": ["hampster"]})
result = db.test.map_reduce(map, reduce, out='mrunittests')
self.assertEqual(1, result.find_one({"_id": "hampster"})["value"])
@ -1992,8 +1967,8 @@ class TestCollection(unittest.TestCase):
self.assertEqual(2, result.find_one({"_id": "dog"})["value"])
self.assertEqual(1, result.find_one({"_id": "mouse"})["value"])
if (is_mongos(self.db.connection)
and not version.at_least(self.db.connection, (2, 1, 2))):
if (client_context.is_mongos and
not client_context.version.at_least(2, 1, 2)):
pass
else:
result = db.test.map_reduce(map, reduce,
@ -2003,7 +1978,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(3, result.find_one({"_id": "cat"})["value"])
self.assertEqual(2, result.find_one({"_id": "dog"})["value"])
self.assertEqual(1, result.find_one({"_id": "mouse"})["value"])
self.client.drop_database('mrtestdb')
client_context.client.drop_database('mrtestdb')
full_result = db.test.map_reduce(map, reduce,
out='mrunittests', full_response=True)
@ -2014,7 +1989,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(1, result.find_one({"_id": "dog"})["value"])
self.assertEqual(None, result.find_one({"_id": "mouse"}))
if version.at_least(self.db.connection, (1, 7, 4)):
if client_context.version.at_least(1, 7, 4):
result = db.test.map_reduce(map, reduce, out={'inline': 1})
self.assertTrue(isinstance(result, dict))
self.assertTrue('results' in result)
@ -2077,7 +2052,7 @@ class TestCollection(unittest.TestCase):
# Starting with MongoDB 2.5.2 this is no longer possible
# from insert, update, or findAndModify.
if not version.at_least(self.db.connection, (2, 5, 2)):
if not client_context.version.at_least(2, 5, 2):
# Force insert of ref without $id.
c.insert(ref_only, check_keys=False)
self.assertEqual(DBRef('collection', id=None),
@ -2117,7 +2092,7 @@ class TestCollection(unittest.TestCase):
# Test that we raise DuplicateKeyError when appropriate.
# MongoDB doesn't have a code field for DuplicateKeyError
# from commands before 2.2.
if version.at_least(self.db.connection, (2, 2)):
if client_context.version.at_least(2, 2):
c.ensure_index('i', unique=True)
self.assertRaises(DuplicateKeyError,
c.find_and_modify, query={'i': 1, 'j': 1},
@ -2139,7 +2114,7 @@ class TestCollection(unittest.TestCase):
self.assertEqual(None,
c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}))
# The return value changed in 2.1.2. See SERVER-6226.
if version.at_least(self.db.connection, (2, 1, 2)):
if client_context.version.at_least(2, 1, 2):
self.assertEqual(None, c.find_and_modify({'_id': 1},
{'$inc': {'i': 1}},
upsert=True))
@ -2160,8 +2135,8 @@ class TestCollection(unittest.TestCase):
# Test with full_response=True
# No lastErrorObject from mongos until 2.0
if (not is_mongos(self.db.connection) and
version.at_least(self.db.connection, (2, 0))):
if (not client_context.is_mongos and
client_context.version.at_least(2, 0)):
result = c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}},
new=True, upsert=True,
full_response=True,
@ -2251,9 +2226,8 @@ class TestCollection(unittest.TestCase):
self.assertRaises(TypeError, c.find_and_modify,
{}, {'$inc': {'i': 1}}, sort=sort)
@client_context.require_version_min(2, 0, 0)
def test_find_with_nested(self):
if not version.at_least(self.db.connection, (2, 0, 0)):
raise SkipTest("nested $and and $or requires MongoDB >= 2.0")
c = self.db.test
c.drop()
c.insert([{'i': i} for i in range(5)]) # [0, 1, 2, 3, 4]
@ -2309,7 +2283,7 @@ class TestCollection(unittest.TestCase):
son['foo'] += 2
return son
db = self.client.pymongo_test
db = client_context.client.pymongo_test
db.add_son_manipulator(IncByTwo())
c = db.test
c.drop()
@ -2321,7 +2295,7 @@ class TestCollection(unittest.TestCase):
c.remove({})
def test_compile_re(self):
c = self.client.pymongo_test.test
c = self.db.test
c.drop()
c.insert({'r': re.compile('.*')})

View File

@ -26,17 +26,14 @@ from bson.son import SON
from pymongo.mongo_client import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.errors import ConfigurationError, OperationFailure
from test import pair, unittest, SkipTest, version
from test.utils import drop_collections
from test import client_context, pair, unittest
class TestCommon(unittest.TestCase):
def test_uuid_subtype(self):
self.client = MongoClient(pair)
self.db = self.client.pymongo_test
coll = self.client.pymongo_test.uuid
self.db = client_context.client.pymongo_test
coll = self.db.uuid
coll.drop()
def change_subtype(collection, subtype):
@ -102,21 +99,20 @@ class TestCommon(unittest.TestCase):
self.assertEqual(5, coll.find_one({'_id': uu})['i'])
# Test command
db = self.client.pymongo_test
no_obj_error = "No matching object found"
result = db.command('findAndModify', 'uuid',
allowable_errors=[no_obj_error],
uuid_subtype=UUID_SUBTYPE,
query={'_id': uu},
update={'$set': {'i': 6}})
result = self.db.command('findAndModify', 'uuid',
allowable_errors=[no_obj_error],
uuid_subtype=UUID_SUBTYPE,
query={'_id': uu},
update={'$set': {'i': 6}})
self.assertEqual(None, result.get('value'))
self.assertEqual(5, db.command('findAndModify', 'uuid',
update={'$set': {'i': 6}},
query={'_id': uu})['value']['i'])
self.assertEqual(6, db.command('findAndModify', 'uuid',
update={'$set': {'i': 7}},
query={'_id': UUIDLegacy(uu)}
)['value']['i'])
self.assertEqual(5, self.db.command('findAndModify', 'uuid',
update={'$set': {'i': 6}},
query={'_id': uu})['value']['i'])
self.assertEqual(6, self.db.command(
'findAndModify', 'uuid',
update={'$set': {'i': 7}},
query={'_id': UUIDLegacy(uu)})['value']['i'])
# Test (inline)_map_reduce
coll.drop()
@ -140,23 +136,23 @@ class TestCommon(unittest.TestCase):
coll.uuid_subtype = UUID_SUBTYPE
q = {"_id": uu}
if version.at_least(self.db.connection, (1, 7, 4)):
if client_context.version.at_least(1, 7, 4):
result = coll.inline_map_reduce(map, reduce, query=q)
self.assertEqual([], result)
result = coll.map_reduce(map, reduce, "results", query=q)
self.assertEqual(0, db.results.count())
self.assertEqual(0, self.db.results.count())
coll.uuid_subtype = OLD_UUID_SUBTYPE
q = {"_id": uu}
if version.at_least(self.db.connection, (1, 7, 4)):
if client_context.version.at_least(1, 7, 4):
result = coll.inline_map_reduce(map, reduce, query=q)
self.assertEqual(2, len(result))
result = coll.map_reduce(map, reduce, "results", query=q)
self.assertEqual(2, db.results.count())
self.assertEqual(2, self.db.results.count())
db.drop_collection("result")
self.db.drop_collection("result")
coll.drop()
# Test group
@ -233,13 +229,9 @@ class TestCommon(unittest.TestCase):
self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,)))
self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,)))
@client_context.require_replica_set
def test_mongo_replica_set_client(self):
c = MongoClient(pair)
ismaster = c.admin.command('ismaster')
if 'setName' in ismaster:
setname = str(ismaster.get('setName'))
else:
raise SkipTest("Not connected to a replica set.")
setname = client_context.setname
m = MongoReplicaSetClient(pair, replicaSet=setname, w=0)
coll = m.pymongo_test.write_concern_test
coll.drop()
@ -249,7 +241,7 @@ class TestCommon(unittest.TestCase):
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoReplicaSetClient(pair, replicaSet=setname)
m = client_context.rs_client
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc, w=0))
self.assertRaises(OperationFailure, coll.insert, doc)

View File

@ -31,12 +31,10 @@ from pymongo import (ASCENDING,
OFF)
from pymongo.command_cursor import CommandCursor
from pymongo.cursor_manager import CursorManager
from pymongo.database import Database
from pymongo.errors import (InvalidOperation,
OperationFailure,
ExecutionTimeout)
from test import SkipTest, unittest, version
from test.test_client import get_client
from test import client_context, SkipTest, unittest
from test.utils import is_mongos, get_command_line, server_started_with_auth
if PY3:
@ -45,17 +43,12 @@ if PY3:
class TestCursor(unittest.TestCase):
def setUp(self):
self.client = get_client()
self.db = Database(self.client, "pymongo_test")
def tearDown(self):
self.db = None
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
@client_context.require_version_min(2, 5, 3, -1)
def test_max_time_ms(self):
if not version.at_least(self.db.connection, (2, 5, 3, -1)):
raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")
db = self.db
db.pymongo_test.drop()
coll = db.pymongo_test
@ -79,11 +72,12 @@ class TestCursor(unittest.TestCase):
self.assertTrue(coll.find_one(max_time_ms=1000))
if "enableTestCommands=1" in get_command_line(self.client)["argv"]:
client = client_context.client
if "enableTestCommands=1" in client_context.cmd_line['argv']:
# Cursor parses server timeout error in response to initial query.
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
try:
cursor = coll.find().max_time_ms(1)
try:
@ -95,27 +89,23 @@ class TestCursor(unittest.TestCase):
self.assertRaises(ExecutionTimeout,
coll.find_one, max_time_ms=1)
finally:
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_test_commands
def test_max_time_ms_getmore(self):
# Test that Cursor handles server timeout error in response to getmore.
if "enableTestCommands=1" not in get_command_line(self.client)["argv"]:
raise SkipTest("Need test commands enabled")
if not version.at_least(self.db.connection, (2, 5, 3, -1)):
raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")
coll = self.db.pymongo_test
coll.insert({} for _ in range(200))
cursor = coll.find().max_time_ms(100)
# Send initial query before turning on failpoint.
next(cursor)
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
client_context.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
try:
try:
# Iterate up to first getmore.
@ -125,9 +115,9 @@ class TestCursor(unittest.TestCase):
else:
self.fail("ExecutionTimeout not raised")
finally:
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
client_context.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
def test_explain(self):
a = self.db.test.find()
@ -779,7 +769,7 @@ class TestCursor(unittest.TestCase):
self.db.test.drop()
self.db.test.save({"x": 1})
if not version.at_least(self.db.connection, (1, 1, 3, -1)):
if not client_context.version.at_least(1, 1, 3, -1):
for _ in self.db.test.find({}, ["a"]):
self.fail()
@ -875,10 +865,8 @@ class TestCursor(unittest.TestCase):
self.assertRaises(IndexError,
lambda x: self.db.test.find().skip(50)[x], 50)
@client_context.require_version_min(1, 1, 4, -1)
def test_count_with_limit_and_skip(self):
if not version.at_least(self.db.connection, (1, 1, 4, -1)):
raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")
self.assertRaises(TypeError, self.db.test.find().count, "foo")
def check_len(cursor, length):
@ -961,10 +949,8 @@ class TestCursor(unittest.TestCase):
finally:
db.drop_collection("test")
@client_context.require_version_min(1, 1, 3, 1)
def test_distinct(self):
if not version.at_least(self.db.connection, (1, 1, 3, 1)):
raise SkipTest("distinct with query requires MongoDB >= 1.1.3")
self.db.drop_collection("test")
self.db.test.save({"a": 1})
@ -990,10 +976,8 @@ class TestCursor(unittest.TestCase):
self.assertEqual(["b", "c"], distinct)
@client_context.require_version_min(1, 5, 1)
def test_max_scan(self):
if not version.at_least(self.db.connection, (1, 5, 1)):
raise SkipTest("maxScan requires MongoDB >= 1.5.1")
self.db.drop_collection("test")
for _ in range(100):
self.db.test.insert({})
@ -1018,11 +1002,9 @@ class TestCursor(unittest.TestCase):
self.assertFalse(c2.alive)
self.assertTrue(c1.alive)
@client_context.require_version_min(2, 0)
@client_context.require_no_mongos
def test_comment(self):
if is_mongos(self.client):
raise SkipTest("profile is not supported by mongos")
if not version.at_least(self.db.connection, (2, 0)):
raise SkipTest("Requires server >= 2.0")
if server_started_with_auth(self.db.connection):
raise SkipTest("SERVER-4754 - This test uses profiling.")

View File

@ -44,9 +44,8 @@ from pymongo.errors import (CollectionInvalid,
from pymongo.son_manipulator import (AutoReference,
NamespaceInjector,
ObjectIdShuffler)
from test import SkipTest, unittest, version
from test.utils import (get_command_line, is_mongos,
remove_all_users, server_started_with_auth)
from test import client_context, SkipTest, unittest
from test.utils import remove_all_users, server_started_with_auth
from test.test_client import get_client
if PY3:
@ -55,11 +54,9 @@ if PY3:
class TestDatabase(unittest.TestCase):
def setUp(self):
self.client = get_client()
def tearDown(self):
self.client = None
@classmethod
def setUpClass(cls):
cls.client = client_context.client
def test_name(self):
self.assertRaises(TypeError, Database, self.client, 4)
@ -112,7 +109,7 @@ class TestDatabase(unittest.TestCase):
db.create_collection("test.foo")
self.assertTrue(u("test.foo") in db.collection_names())
expected = {}
if version.at_least(self.client, (2, 7, 0)):
if client_context.version.at_least(2, 7, 0):
# usePowerOf2Sizes server default
expected["flags"] = 1
result = db.test.foo.options()
@ -185,9 +182,8 @@ class TestDatabase(unittest.TestCase):
self.assertTrue(db.validate_collection(db.test, scandata=True, full=True))
self.assertTrue(db.validate_collection(db.test, True, True))
@client_context.require_no_mongos
def test_profiling_levels(self):
if is_mongos(self.client):
raise SkipTest('profile is not supported by mongos')
db = self.client.pymongo_test
self.assertEqual(db.profiling_level(), OFF) # default
@ -215,9 +211,8 @@ class TestDatabase(unittest.TestCase):
db.set_profiling_level(OFF, 100) # back to default
self.assertEqual(100, db.command("profile", -1)['slowms'])
@client_context.require_no_mongos
def test_profiling_info(self):
if is_mongos(self.client):
raise SkipTest('profile is not supported by mongos')
db = self.client.pymongo_test
db.set_profiling_level(ALL)
@ -236,7 +231,7 @@ class TestDatabase(unittest.TestCase):
self.assertTrue(len(info) >= 1)
# These basically clue us in to server changes.
if version.at_least(db.connection, (1, 9, 1, -1)):
if client_context.version.at_least(1, 9, 1, -1):
self.assertTrue(isinstance(info[0]['responseLength'], int))
self.assertTrue(isinstance(info[0]['millis'], int))
self.assertTrue(isinstance(info[0]['client'], string_type))
@ -256,9 +251,8 @@ class TestDatabase(unittest.TestCase):
self.assertRaises(TypeError, iterate)
@client_context.require_no_mongos
def test_errors(self):
if is_mongos(self.client):
raise SkipTest('getpreverror not supported by mongos')
db = self.client.pymongo_test
db.reset_error_history()
@ -296,15 +290,11 @@ class TestDatabase(unittest.TestCase):
self.assertEqual(db.command("buildinfo"), db.command({"buildinfo": 1}))
# We use 'aggregate' as our example command, since it's an easy way to
# retrieve a BSON regex from a collection using a command. But until
# MongoDB 2.3.2, aggregation turned regexes into strings: SERVER-6470.
@client_context.require_version_min(2, 3, 2)
def test_command_with_compile_re(self):
# We use 'aggregate' as our example command, since it's an easy way to
# retrieve a BSON regex from a collection using a command. But until
# MongoDB 2.3.2, aggregation turned regexes into strings: SERVER-6470.
if not version.at_least(self.client, (2, 3, 2)):
raise SkipTest(
"Retrieving a regex with aggregation requires "
"MongoDB >= 2.3.2")
db = self.client.pymongo_test
db.test.drop()
db.test.insert({'r': re.compile('.*')})
@ -317,14 +307,15 @@ class TestDatabase(unittest.TestCase):
def test_last_status(self):
db = self.client.pymongo_test
db.test.remove({})
db.test.save({"i": 1})
with self.client.start_request():
db.test.remove({})
db.test.save({"i": 1})
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
self.assertTrue(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
self.assertTrue(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
self.assertFalse(db.last_status()["updatedExisting"])
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
self.assertFalse(db.last_status()["updatedExisting"])
def test_password_digest(self):
self.assertRaises(TypeError, auth._password_digest, 5)
@ -340,13 +331,8 @@ class TestDatabase(unittest.TestCase):
self.assertEqual(auth._password_digest("Gustave", u("Dor\xe9")),
u("81e0e2364499209f466e75926a162d73"))
@client_context.require_auth
def test_authenticate_add_remove_user(self):
if (is_mongos(self.client) and not
version.at_least(self.client, (2, 0, 0))):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
db = self.client.pymongo_test
# Configuration errors
@ -357,7 +343,7 @@ class TestDatabase(unittest.TestCase):
self.assertRaises(ConfigurationError, db.add_user,
"user", 'password', True, roles=['read'])
if version.at_least(self.client, (2, 5, 3, -1)):
if client_context.version.at_least(2, 5, 3, -1):
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
self.assertRaises(DeprecationWarning, db.add_user,
@ -400,7 +386,7 @@ class TestDatabase(unittest.TestCase):
db.authenticate, "Gustave", u("Dor\xe9"))
self.assertTrue(db.authenticate("Gustave", u("password")))
if not version.at_least(self.client, (2, 5, 3, -1)):
if not client_context.version.at_least(2, 5, 3, -1):
# Add a readOnly user
db.add_user("Ross", "password", read_only=True)
db.logout()
@ -411,17 +397,13 @@ class TestDatabase(unittest.TestCase):
# Cleanup
finally:
remove_all_users(db)
self.client.pymongo_test.logout()
self.client.admin.remove_user("admin")
self.client.admin.logout()
self.client.disconnect()
@client_context.require_auth
def test_make_user_readonly(self):
if (is_mongos(self.client)
and not version.at_least(self.client, (2, 0, 0))):
raise SkipTest('Auth with sharding requires MongoDB >= 2.0.0')
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
admin = self.client.admin
admin.add_user('admin', 'pw')
admin.authenticate('admin', 'pw')
@ -451,13 +433,12 @@ class TestDatabase(unittest.TestCase):
remove_all_users(db)
admin.remove_user("admin")
admin.logout()
db.logout()
self.client.disconnect()
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_auth
def test_default_roles(self):
if not version.at_least(self.client, (2, 5, 3, -1)):
raise SkipTest("Default roles only exist in MongoDB >= 2.5.3")
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
# "Admin" user
db = self.client.admin
db.add_user('admin', 'pass')
@ -503,14 +484,11 @@ class TestDatabase(unittest.TestCase):
db.authenticate('user', 'pass')
remove_all_users(db)
db.logout()
self.client.disconnect()
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_auth
def test_new_user_cmds(self):
if not version.at_least(self.client, (2, 5, 3, -1)):
raise SkipTest("User manipulation through commands "
"requires MongoDB >= 2.5.3")
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
db = self.client.pymongo_test
db.add_user("amalia", "password", roles=["userAdmin"])
db.authenticate("amalia", "password")
@ -527,13 +505,10 @@ class TestDatabase(unittest.TestCase):
finally:
db.remove_user("amalia")
db.logout()
self.client.disconnect()
@client_context.require_auth
def test_authenticate_and_safe(self):
if (is_mongos(self.client) and not
version.at_least(self.client, (2, 0, 0))):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
db = self.client.auth_test
db.add_user("bernie", "password",
@ -555,14 +530,10 @@ class TestDatabase(unittest.TestCase):
finally:
db.remove_user("bernie")
db.logout()
self.client.disconnect()
@client_context.require_auth
def test_authenticate_and_request(self):
if (is_mongos(self.client) and not
version.at_least(self.client, (2, 0, 0))):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0")
if not server_started_with_auth(self.client):
raise SkipTest('Authentication is not enabled on server')
# Database.authenticate() needs to be in a request - check that it
# always runs in a request, and that it restores the request state
# (in or not in a request) properly when it's finished.
@ -584,19 +555,14 @@ class TestDatabase(unittest.TestCase):
db.remove_user("mike")
db.logout()
request_db.logout()
self.client.disconnect()
@client_context.require_auth
def test_authenticate_multiple(self):
client = get_client()
if (is_mongos(client) and not
version.at_least(self.client, (2, 2, 0))):
raise SkipTest("Need mongos >= 2.2.0")
if not server_started_with_auth(client):
raise SkipTest("Authentication is not enabled on server")
# Setup
users_db = client.pymongo_test
admin_db = client.admin
other_db = client.pymongo_test1
users_db = self.client.pymongo_test
admin_db = self.client.admin
other_db = self.client.pymongo_test1
users_db.test.remove()
other_db.test.remove()
@ -606,7 +572,7 @@ class TestDatabase(unittest.TestCase):
try:
self.assertTrue(admin_db.authenticate('admin', 'pass'))
if version.at_least(self.client, (2, 5, 3, -1)):
if client_context.version.at_least(2, 5, 3, -1):
admin_db.add_user('ro-admin', 'pass',
roles=["userAdmin", "readAnyDatabase"])
else:
@ -632,7 +598,7 @@ class TestDatabase(unittest.TestCase):
other_db.test.insert, {})
# Force close all sockets
client.disconnect()
self.client.disconnect()
# We should still be able to write to the regular user's db
self.assertTrue(users_db.test.remove())
@ -644,11 +610,11 @@ class TestDatabase(unittest.TestCase):
# Cleanup
finally:
admin_db.logout()
users_db.logout()
admin_db.authenticate('admin', 'pass')
remove_all_users(users_db)
remove_all_users(admin_db)
admin_db.logout()
users_db.logout()
self.client.disconnect()
def test_id_ordering(self):
# PyMongo attempts to have _id show up first
@ -718,7 +684,7 @@ class TestDatabase(unittest.TestCase):
# TODO some of these tests belong in the collection level testing.
def test_save_find_one(self):
db = Database(self.client, "pymongo_test")
db = self.client.pymongo_test
db.test.remove({})
a_doc = SON({"hello": u("world")})
@ -873,7 +839,7 @@ class TestDatabase(unittest.TestCase):
del db.system_js['add']
self.assertEqual(0, db.system.js.count())
if version.at_least(db.connection, (1, 3, 2, -1)):
if client_context.version.at_least(1, 3, 2, -1):
self.assertRaises(OperationFailure, db.system_js.add, 1, 5)
# TODO right now CodeWScope doesn't work w/ system js
@ -883,7 +849,7 @@ class TestDatabase(unittest.TestCase):
self.assertRaises(OperationFailure, db.system_js.non_existant)
# XXX: Broken in V8, works in SpiderMonkey
if not version.at_least(db.connection, (2, 3, 0)):
if not client_context.version.at_least(2, 3, 0):
db.system_js.no_param = Code("return 5;")
self.assertEqual(5, db.system_js.no_param())
@ -939,17 +905,15 @@ class TestDatabase(unittest.TestCase):
self.assertRaises(UserWarning, self.client.pymongo_test.command,
'ping', read_preference=ReadPreference.SECONDARY)
try:
self.client.pymongo_test.command('dbStats',
self.client.pymongo_test.command(
'dbStats',
read_preference=ReadPreference.SECONDARY_PREFERRED)
except UserWarning:
self.fail("Shouldn't have raised UserWarning.")
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_test_commands
def test_command_max_time_ms(self):
if not version.at_least(self.client, (2, 5, 3, -1)):
raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")
if "enableTestCommands=1" not in get_command_line(self.client)["argv"]:
raise SkipTest("Test commands must be enabled.")
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")

View File

@ -35,20 +35,19 @@ from gridfs.errors import (NoFile,
UnsupportedAPI)
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from test.test_client import get_client
from test import qcheck, unittest
from test import client_context, qcheck, unittest
class TestGridFile(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
def setUp(self):
self.db = get_client().pymongo_test
self.db.fs.files.remove({})
self.db.fs.chunks.remove({})
def tearDown(self):
self.db = None
def test_basic(self):
f = GridIn(self.db.fs, filename="test")
f.write(b"hello world")

View File

@ -32,8 +32,7 @@ import gridfs
from bson.py3compat import u, StringIO, string_type
from gridfs.errors import (FileExists,
NoFile)
from test import unittest
from test.test_client import get_client
from test import client_context, unittest
from test.utils import joinall
@ -71,17 +70,17 @@ class JustRead(threading.Thread):
class TestGridfs(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
cls.fs = gridfs.GridFS(cls.db)
cls.alt = gridfs.GridFS(cls.db, "alt")
def setUp(self):
self.db = get_client().pymongo_test
self.db.drop_collection("fs.files")
self.db.drop_collection("fs.chunks")
self.db.drop_collection("alt.files")
self.db.drop_collection("alt.chunks")
self.fs = gridfs.GridFS(self.db)
self.alt = gridfs.GridFS(self.db, "alt")
def tearDown(self):
self.db = self.fs = self.alt = None
def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo")
@ -278,17 +277,14 @@ class TestGridfs(unittest.TestCase):
self.assertEqual(b"hello world", self.fs.get(oid).read())
def test_file_exists(self):
db = get_client(w=1).pymongo_test
fs = gridfs.GridFS(db)
oid = self.fs.put(b"hello")
self.assertRaises(FileExists, self.fs.put, b"world", _id=oid)
oid = fs.put(b"hello")
self.assertRaises(FileExists, fs.put, b"world", _id=oid)
one = fs.new_file(_id=123)
one = self.fs.new_file(_id=123)
one.write(b"some content")
one.close()
two = fs.new_file(_id=123)
two = self.fs.new_file(_id=123)
self.assertRaises(FileExists, two.write, b'x' * 262146)
def test_exists(self):
@ -446,10 +442,9 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase):
self.assertRaises(ConnectionFailure, fs.put, 'data')
def tearDown(self):
rsc = self._get_client()
rsc = client_context.rs_client
rsc.pymongo_test.drop_collection('fs.files')
rsc.pymongo_test.drop_collection('fs.chunks')
rsc.close()
if __name__ == "__main__":

View File

@ -33,23 +33,21 @@ from bson.son import RE_TYPE
from bson.timestamp import Timestamp
from bson.tz_util import utc
from test import SkipTest, unittest
from test.test_client import get_client
from test import client_context, SkipTest, unittest
PY3 = sys.version_info[0] == 3
class TestJsonUtil(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
def setUp(self):
if not json_util.json_lib:
raise SkipTest("No json or simplejson module")
self.db = get_client().pymongo_test
def tearDown(self):
self.db = None
def round_tripped(self, doc):
return json_util.loads(json_util.dumps(doc))

View File

@ -30,9 +30,10 @@ from pymongo.mongo_client import MongoClient
from pymongo.pool import Pool, NO_REQUEST, NO_SOCKET_YET, SocketInfo
from pymongo.errors import ConfigurationError, ConnectionFailure
from pymongo.errors import ExceededMaxWaiters
from test import version, host, port, SkipTest
from test import host, port, SkipTest
from test.test_client import get_client
from test.utils import delay, is_mongos, one, get_pool
from test.version import Version
N = 10
DB = "pymongo-pooling-tests"
@ -1191,7 +1192,7 @@ class _TestPoolSocketSharing(_TestPoolingBase):
# Javascript function that pauses N seconds per document
fn = delay(10)
if (is_mongos(db.connection) or not
version.at_least(db.connection, (1, 7, 2))):
Version.from_client(db.connection).at_least(1, 7, 2)):
# mongos doesn't support eval so we have to use $where
# which is less reliable in this context.
self.assertEqual(1, db.test.find({"$where": fn}).count())

View File

@ -14,20 +14,18 @@
"""Test the pymongo module itself."""
import os
import sys
sys.path[0:0] = [""]
import pymongo
from test import host, port, unittest
from test import unittest
class TestPyMongo(unittest.TestCase):
def test_mongo_client_alias(self):
# Testing that pymongo module imports mongo_client.MongoClient
c = pymongo.MongoClient(host, port)
self.assertEqual(c.host, host)
self.assertEqual(c.port, port)
self.assertEqual(pymongo.MongoClient,
pymongo.mongo_client.MongoClient)
if __name__ == "__main__":

View File

@ -32,21 +32,21 @@ from pymongo.errors import ConfigurationError
from test.test_replica_set_client import TestReplicaSetClientBase
from test.test_client import get_client
from test import host, port, SkipTest, unittest, utils, version
from test import client_context, host, port, SkipTest, unittest, utils
from test.version import Version
class TestReadPreferencesBase(TestReplicaSetClientBase):
def setUp(self):
super(TestReadPreferencesBase, self).setUp()
# Insert some data so we can use cursors in read_from_which_host
c = self._get_client()
c.pymongo_test.test.drop()
c.pymongo_test.test.insert([{'_id': i} for i in range(10)], w=self.w)
client_context.client.pymongo_test.test.drop()
client_context.client.pymongo_test.test.insert(
[{'_id': i} for i in range(10)], w=self.w)
def tearDown(self):
super(TestReadPreferencesBase, self).tearDown()
c = self._get_client()
c.pymongo_test.test.drop()
client_context.client.pymongo_test.test.drop()
def read_from_which_host(self, client):
"""Do a find() on the client and return which host was used
@ -204,6 +204,7 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Effectively ignore members' ping times so we can test the effect
# of ReadPreference modes only
acceptableLatencyMS=1000*1000)
self.client_version = Version.from_client(self.c)
def tearDown(self):
# We create a lot of collections and indexes in these tests, so drop
@ -330,14 +331,14 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
('geoSearch', 'test'), ('near', [33, 33]), ('maxDistance', 6),
('search', {'type': 'restaurant'}), ('limit', 30)])))
if version.at_least(self.c, (2, 1, 0)):
if self.client_version.at_least(2, 1, 0):
self._test_fn(True, lambda: self.c.pymongo_test.command(SON([
('aggregate', 'test'),
('pipeline', [])
])))
# Text search.
if version.at_least(self.c, (2, 3, 2)):
if self.client_version.at_least(2, 3, 2):
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@ -394,10 +395,8 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
('out', {'inline': True})
])))
@client_context.require_version_min(2, 5, 2)
def test_aggregate_command_with_out(self):
if not version.at_least(self.c, (2, 5, 2)):
raise SkipTest("Aggregation with $out requires MongoDB >= 2.5.2")
# Tests aggregate command when pipeline contains $out.
self.c.pymongo_test.test.insert({"x": 1, "y": 1}, w=self.w)
self.c.pymongo_test.test.insert({"x": 1, "y": 2}, w=self.w)
@ -479,7 +478,7 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
lambda: self.c.pymongo_test.test.find().distinct('a'))
def test_aggregate(self):
if version.at_least(self.c, (2, 1, 0)):
if self.client_version.at_least(2, 1, 0):
self._test_fn(True,
lambda: self.c.pymongo_test.test.aggregate(
[{'$project': {'_id': 1}}]))

View File

@ -29,7 +29,6 @@ sys.path[0:0] = [""]
from bson.py3compat import thread, u, _unicode
from bson.son import SON
from bson.tz_util import utc
from pymongo.mongo_client import MongoClient
from pymongo.read_preferences import ReadPreference, Secondary, Nearest
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.mongo_replica_set_client import _partition_node, have_gevent
@ -40,13 +39,14 @@ from pymongo.errors import (AutoReconnect,
ConnectionFailure,
InvalidName,
OperationFailure, InvalidOperation)
from test import pair, port, SkipTest, unittest, version
from test import client_context, pair, port, SkipTest, unittest
from test.pymongo_mocks import MockReplicaSetClient
from test.utils import (
delay, assertReadFrom, assertReadFromAll, read_from_which_host,
remove_all_users, assertRaisesExactly, TestRequestMixin, one,
server_started_with_auth, pools_from_rs_client, get_pool,
_TestLazyConnectMixin)
from test.version import Version
class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
@ -54,9 +54,7 @@ class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
but only if the database at DB_IP and DB_PORT is a standalone.
"""
def setUp(self):
client = MongoClient(pair)
response = client.admin.command('ismaster')
if 'setName' in response:
if client_context.setname:
raise SkipTest("Connected to a replica set, not a standalone mongod")
def test_connect(self):
@ -66,32 +64,29 @@ class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
class TestReplicaSetClientBase(unittest.TestCase):
def setUp(self):
client = MongoClient(pair)
response = client.admin.command('ismaster')
if 'setName' in response:
self.name = str(response['setName'])
self.w = len(response['hosts'])
self.hosts = set([_partition_node(h)
for h in response["hosts"]])
self.arbiters = set([_partition_node(h)
for h in response.get("arbiters", [])])
repl_set_status = client.admin.command('replSetGetStatus')
primary_info = [
m for m in repl_set_status['members']
if m['stateStr'] == 'PRIMARY'
][0]
@classmethod
@client_context.require_replica_set
def setUpClass(cls):
cls.name = client_context.setname
ismaster = client_context.ismaster
cls.w = client_context.w
cls.hosts = set(_partition_node(h) for h in ismaster['hosts'])
cls.arbiters = set(_partition_node(h)
for h in ismaster.get("arbiters", []))
self.primary = _partition_node(primary_info['name'])
self.secondaries = [
_partition_node(m['name']) for m in repl_set_status['members']
if m['stateStr'] == 'SECONDARY'
]
else:
raise SkipTest("Not connected to a replica set")
repl_set_status = client_context.client.admin.command(
'replSetGetStatus')
primary_info = [
m for m in repl_set_status['members']
if m['stateStr'] == 'PRIMARY'
][0]
super(TestReplicaSetClientBase, self).setUp()
cls.primary = _partition_node(primary_info['name'])
cls.secondaries = [
_partition_node(m['name']) for m in repl_set_status['members']
if m['stateStr'] == 'SECONDARY'
]
def _get_client(self, **kwargs):
return MongoReplicaSetClient(pair,
@ -138,7 +133,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue(c.primary)
self.assertTrue(c.secondaries)
if version.at_least(c, (2, 5, 4, -1)):
if Version.from_client(c).at_least(2, 5, 4, -1):
self.assertTrue(c.max_wire_version > 0)
else:
self.assertEqual(c.max_wire_version, 0)
@ -169,10 +164,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
@client_context.require_auth
def test_init_disconnected_with_auth(self):
c = self._get_client()
if not server_started_with_auth(c):
raise SkipTest('Authentication is not enabled on server')
c = client_context.rs_client
c.admin.add_user("admin", "pass")
c.admin.authenticate("admin", "pass")
@ -199,6 +193,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
# Clean up.
remove_all_users(c.pymongo_test)
remove_all_users(c.admin)
c.admin.logout()
c.disconnect()
def test_connect(self):
assertRaisesExactly(ConnectionFailure, MongoReplicaSetClient,
@ -210,7 +206,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue(MongoReplicaSetClient(pair, replicaSet=self.name))
def test_repr(self):
client = self._get_client()
client = client_context.rs_client
# Quirk: the RS client makes a frozenset of hosts from a dict's keys,
# so we must do the same to achieve the same order.
@ -223,7 +219,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
"MongoReplicaSetClient([%s])" % hosts_repr)
def test_properties(self):
c = MongoReplicaSetClient(pair, replicaSet=self.name)
c = client_context.rs_client
c.admin.command('ping')
self.assertEqual(c.primary, self.primary)
self.assertEqual(c.hosts, self.hosts)
@ -240,7 +236,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
cursor = c.pymongo_test.test.find()
self.assertEqual(
ReadPreference.PRIMARY, cursor._Cursor__read_preference)
c.close()
tag_sets = [{'dc': 'la', 'rack': '2'}, {'foo': 'bar'}]
secondary = Secondary(tag_sets)
@ -268,7 +263,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertEqual(
nearest, cursor._Cursor__read_preference)
if version.at_least(c, (1, 7, 4)):
if Version.from_client(c).at_least(1, 7, 4):
self.assertEqual(c.max_bson_size, 16777216)
else:
self.assertEqual(c.max_bson_size, 4194304)
@ -283,7 +278,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
pair, replicaSet=self.name, use_greenlets=True).use_greenlets)
def test_get_db(self):
client = self._get_client()
client = client_context.rs_client
def make_db(base, name):
return base[name]
@ -298,10 +293,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue(isinstance(client.test, Database))
self.assertEqual(client.test, client["test"])
self.assertEqual(client.test, Database(client, "test"))
client.close()
def test_auto_reconnect_exception_when_read_preference_is_secondary(self):
c = self._get_client()
c = client_context.rs_client
db = c.pymongo_test
def raise_socket_error(*args, **kwargs):
@ -316,12 +310,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
finally:
socket.socket.sendall = old_sendall
@client_context.require_auth
def test_lazy_auth_raises_operation_failure(self):
# Check if we have the prerequisites to run this test.
c = self._get_client()
if not server_started_with_auth(c):
raise SkipTest('Authentication is not enabled on server')
lazy_client = MongoReplicaSetClient(
"mongodb://user:wrong@%s/pymongo_test" % pair,
replicaSet=self.name,
@ -331,7 +322,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
OperationFailure, lazy_client.test.collection.find_one)
def test_operations(self):
c = self._get_client()
c = client_context.rs_client
# Check explicitly for a case we've commonly hit in tests:
# a replica set is started with a tiny oplog, a previous
@ -367,10 +358,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
db.test.remove({})
self.assertEqual(0, db.test.count())
db.test.drop()
c.close()
def test_database_names(self):
client = self._get_client()
client = client_context.rs_client
client.pymongo_test.test.save({"dummy": u("object")})
client.pymongo_test_mike.test.save({"dummy": u("object")})
@ -378,10 +368,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
dbs = client.database_names()
self.assertTrue("pymongo_test" in dbs)
self.assertTrue("pymongo_test_mike" in dbs)
client.close()
def test_drop_database(self):
client = self._get_client()
client = client_context.rs_client
self.assertRaises(TypeError, client.drop_database, 5)
self.assertRaises(TypeError, client.drop_database, None)
@ -399,37 +388,33 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
client.drop_database(client.pymongo_test)
dbs = client.database_names()
self.assertTrue("pymongo_test" not in dbs)
client.close()
def test_copy_db(self):
c = self._get_client()
c = client_context.rs_client
# We test copy twice; once starting in a request and once not. In
# either case the copy should succeed (because it starts a request
# internally) and should leave us in the same state as before the copy.
c.start_request()
with c.start_request():
self.assertRaises(TypeError, c.copy_database, 4, "foo")
self.assertRaises(TypeError, c.copy_database, "foo", 4)
self.assertRaises(TypeError, c.copy_database, 4, "foo")
self.assertRaises(TypeError, c.copy_database, "foo", 4)
self.assertRaises(InvalidName, c.copy_database, "foo", "$foo")
self.assertRaises(InvalidName, c.copy_database, "foo", "$foo")
c.pymongo_test.test.drop()
c.drop_database("pymongo_test1")
c.drop_database("pymongo_test2")
c.pymongo_test.test.drop()
c.drop_database("pymongo_test1")
c.drop_database("pymongo_test2")
c.pymongo_test.test.insert({"foo": "bar"})
c.pymongo_test.test.insert({"foo": "bar"})
self.assertFalse("pymongo_test1" in c.database_names())
self.assertFalse("pymongo_test2" in c.database_names())
self.assertFalse("pymongo_test1" in c.database_names())
self.assertFalse("pymongo_test2" in c.database_names())
c.copy_database("pymongo_test", "pymongo_test1")
# copy_database() didn't accidentally end the request
self.assertTrue(c.in_request())
c.copy_database("pymongo_test", "pymongo_test1")
# copy_database() didn't accidentally end the request
self.assertTrue(c.in_request())
self.assertTrue("pymongo_test1" in c.database_names())
self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"])
c.end_request()
self.assertTrue("pymongo_test1" in c.database_names())
self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"])
self.assertFalse(c.in_request())
c.copy_database("pymongo_test", "pymongo_test2", pair)
@ -441,7 +426,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue("pymongo_test2" in c.database_names())
self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"])
if version.at_least(c, (1, 3, 3, 1)) and server_started_with_auth(c):
if (Version.from_client(c).at_least(1, 3, 3, 1) and
server_started_with_auth(c)):
c.drop_database("pymongo_test1")
c.admin.add_user("admin", "password")
@ -469,7 +455,9 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
# Cleanup
remove_all_users(c.pymongo_test)
c.admin.remove_user("admin")
c.close()
c.admin.logout()
c.pymongo_test.logout()
c.disconnect()
def test_get_default_database(self):
host = one(self.hosts)
@ -498,16 +486,15 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_iteration(self):
client = self._get_client()
client = client_context.rs_client
def iterate():
[a for a in client]
self.assertRaises(TypeError, iterate)
client.close()
def test_disconnect(self):
c = self._get_client()
c = client_context.rs_client
coll = c.pymongo_test.bar
c.disconnect()
@ -553,7 +540,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
except ImportError:
raise SkipTest("No multiprocessing module")
client = self._get_client()
client = client_context.rs_client
def f(pipe):
try:
@ -584,7 +571,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
pass
def test_document_class(self):
c = self._get_client()
c = client_context.rs_client
db = c.pymongo_test
db.test.insert({"x": 1})
@ -594,10 +581,12 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
c.document_class = SON
self.assertEqual(SON, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
c.close()
try:
self.assertEqual(SON, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
finally:
c.document_class = dict
c = self._get_client(document_class=SON)
db = c.pymongo_test
@ -633,7 +622,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self._get_client, socketTimeoutMS='foo')
def test_socket_timeout(self):
no_timeout = self._get_client()
no_timeout = client_context.rs_client
timeout_sec = 1
timeout = self._get_client(socketTimeoutMS=timeout_sec*1000)
@ -666,7 +655,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
else:
self.fail('RS client should have raised timeout error')
no_timeout.close()
timeout.close()
def test_timeout_does_not_mark_member_down(self):
@ -715,7 +703,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
tz_aware='foo', replicaSet=self.name)
aware = self._get_client(tz_aware=True)
naive = self._get_client()
naive = client_context.rs_client
aware.pymongo_test.drop_collection("test")
now = datetime.datetime.utcnow()
@ -809,7 +797,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
# Test fix for PYTHON-294 -- make sure client closes its socket if it
# gets an interrupt while waiting to recv() from it.
c = self._get_client()
c = client_context.rs_client
db = c.pymongo_test
# A $where clause which takes 1.5 sec to execute
@ -932,14 +920,12 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertInRequestAndSameSock(client, pools)
client.close()
client = self._get_client()
client = client_context.rs_client
pools = pools_from_rs_client(client)
self.assertNotInRequestAndDifferentSock(client, pools)
client.start_request()
self.assertInRequestAndSameSock(client, pools)
client.end_request()
with client.start_request():
self.assertInRequestAndSameSock(client, pools)
self.assertNotInRequestAndDifferentSock(client, pools)
client.close()
def test_nested_request(self):
client = self._get_client(auto_start_request=True)
@ -984,48 +970,46 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
client.close()
def test_request_threads(self):
client = self._get_client()
try:
pools = pools_from_rs_client(client)
self.assertNotInRequestAndDifferentSock(client, pools)
client = client_context.rs_client
started_request, ended_request = threading.Event(), threading.Event()
checked_request = threading.Event()
thread_done = [False]
pools = pools_from_rs_client(client)
self.assertNotInRequestAndDifferentSock(client, pools)
# Starting a request in one thread doesn't put the other thread in a
# request
def f():
self.assertNotInRequestAndDifferentSock(client, pools)
client.start_request()
self.assertInRequestAndSameSock(client, pools)
started_request.set()
checked_request.wait()
checked_request.clear()
self.assertInRequestAndSameSock(client, pools)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pools)
ended_request.set()
checked_request.wait()
thread_done[0] = True
started_request, ended_request = threading.Event(), threading.Event()
checked_request = threading.Event()
thread_done = [False]
t = threading.Thread(target=f)
t.setDaemon(True)
t.start()
started_request.wait()
# Starting a request in one thread doesn't put the other thread in a
# request
def f():
self.assertNotInRequestAndDifferentSock(client, pools)
checked_request.set()
ended_request.wait()
client.start_request()
self.assertInRequestAndSameSock(client, pools)
started_request.set()
checked_request.wait()
checked_request.clear()
self.assertInRequestAndSameSock(client, pools)
client.end_request()
self.assertNotInRequestAndDifferentSock(client, pools)
checked_request.set()
t.join()
self.assertNotInRequestAndDifferentSock(client, pools)
self.assertTrue(thread_done[0], "Thread didn't complete")
finally:
client.close()
ended_request.set()
checked_request.wait()
thread_done[0] = True
t = threading.Thread(target=f)
t.setDaemon(True)
t.start()
started_request.wait()
self.assertNotInRequestAndDifferentSock(client, pools)
checked_request.set()
ended_request.wait()
self.assertNotInRequestAndDifferentSock(client, pools)
checked_request.set()
t.join()
self.assertNotInRequestAndDifferentSock(client, pools)
self.assertTrue(thread_done[0], "Thread didn't complete")
def test_schedule_refresh(self):
client = self._get_client()
client = client_context.rs_client
new_rs_state = rs_state = client._MongoReplicaSetClient__rs_state
for host in rs_state.hosts:
new_rs_state = new_rs_state.clone_with_host_down(host, 'error!')
@ -1037,8 +1021,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.w, len(rs_state.members),
"MongoReplicaSetClient didn't detect members are up")
client.close()
def test_pinned_member(self):
latency = 1000 * 1000
client = self._get_client(acceptablelatencyms=latency)
@ -1081,7 +1063,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
ReadPreference.NEAREST, None)
def test_alive(self):
client = self._get_client()
client = client_context.rs_client
self.assertTrue(client.alive())
client = MongoReplicaSetClient(

View File

@ -19,22 +19,18 @@ import sys
sys.path[0:0] = [""]
from bson.son import SON
from pymongo.database import Database
from pymongo.son_manipulator import (NamespaceInjector,
ObjectIdInjector,
ObjectIdShuffler,
SONManipulator)
from test.test_client import get_client
from test import qcheck, unittest
from test import client_context, qcheck, unittest
class TestSONManipulator(unittest.TestCase):
def setUp(self):
self.db = Database(get_client(), "pymongo_test")
def tearDown(self):
self.db = None
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
def test_basic(self):
manip = SONManipulator()

View File

@ -37,8 +37,9 @@ from pymongo.common import HAS_SSL
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
OperationFailure)
from test import host, pair, port, SkipTest, unittest, version
from test import host, pair, port, SkipTest, unittest
from test.utils import server_started_with_auth, remove_all_users
from test.version import Version
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
@ -75,6 +76,8 @@ def is_server_resolvable():
finally:
socket.setdefaulttimeout(socket_timeout)
# Shared ssl-enabled client for the tests
ssl_client = None
if HAS_SSL:
import ssl
@ -89,8 +92,9 @@ if HAS_SSL:
# Is MongoDB configured with server.pem, ca.pem, and crl.pem from
# mongodb jstests/lib?
try:
MongoClient(host, port, connectTimeoutMS=100, ssl=True,
ssl_certfile=CLIENT_PEM)
global ssl_client
ssl_client = MongoClient(host, port, connectTimeoutMS=100, ssl=True,
ssl_certfile=CLIENT_PEM)
CERT_SSL = True
except ConnectionFailure:
pass
@ -203,7 +207,8 @@ class TestClientSSL(unittest.TestCase):
class TestSSL(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
if not HAS_SSL:
raise SkipTest("The ssl module is not available.")
@ -243,8 +248,7 @@ class TestSSL(unittest.TestCase):
if not CERT_SSL:
raise SkipTest("No mongod available over SSL with certs")
client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM)
response = client.admin.command('ismaster')
response = ssl_client.admin.command('ismaster')
if 'setName' in response:
client = MongoReplicaSetClient(pair,
replicaSet=response['setName'],
@ -269,8 +273,7 @@ class TestSSL(unittest.TestCase):
if not CERT_SSL:
raise SkipTest("No mongod available over SSL with certs")
client = MongoClient(host, port, ssl_certfile=CLIENT_PEM)
response = client.admin.command('ismaster')
response = ssl_client.admin.command('ismaster')
if 'setName' in response:
client = MongoReplicaSetClient(pair,
replicaSet=response['setName'],
@ -376,8 +379,7 @@ class TestSSL(unittest.TestCase):
if not CERT_SSL:
raise SkipTest("No mongod available over SSL with certs")
client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM)
response = client.admin.command('ismaster')
response = ssl_client.admin.command('ismaster')
try:
MongoClient(pair,
@ -414,19 +416,18 @@ class TestSSL(unittest.TestCase):
if not CERT_SSL:
raise SkipTest("No mongod available over SSL with certs")
client = MongoClient(host, port, ssl=True, ssl_certfile=CLIENT_PEM)
if not version.at_least(client, (2, 5, 3, -1)):
if not Version.from_client(ssl_client).at_least(2, 5, 3, -1):
raise SkipTest("MONGODB-X509 tests require MongoDB 2.5.3 or newer")
if not server_started_with_auth(client):
if not server_started_with_auth(ssl_client):
raise SkipTest('Authentication is not enabled on server')
# Give admin all necessary privileges.
client['$external'].add_user(MONGODB_X509_USERNAME, roles=[
ssl_client['$external'].add_user(MONGODB_X509_USERNAME, roles=[
{'role': 'readWriteAnyDatabase', 'db': 'admin'},
{'role': 'userAdminAnyDatabase', 'db': 'admin'}])
coll = client.pymongo_test.test
coll = ssl_client.pymongo_test.test
self.assertRaises(OperationFailure, coll.count)
self.assertTrue(client.admin.authenticate(MONGODB_X509_USERNAME,
mechanism='MONGODB-X509'))
self.assertTrue(ssl_client.admin.authenticate(MONGODB_X509_USERNAME,
mechanism='MONGODB-X509'))
self.assertTrue(coll.remove())
uri = ('mongodb://%s@%s:%d/?authMechanism='
'MONGODB-X509' % (quote_plus(MONGODB_X509_USERNAME), host, port))
@ -443,7 +444,7 @@ class TestSSL(unittest.TestCase):
'MONGODB-X509' % (quote_plus("not the username"), host, port))
self.assertRaises(ConfigurationError, MongoClient, uri,
ssl=True, ssl_certfile=CLIENT_PEM)
self.assertRaises(OperationFailure, client.admin.authenticate,
self.assertRaises(OperationFailure, ssl_client.admin.authenticate,
"not the username",
mechanism="MONGODB-X509")
@ -456,8 +457,8 @@ class TestSSL(unittest.TestCase):
ssl=True, ssl_certfile=CA_PEM)
# Cleanup
remove_all_users(client['$external'])
client['$external'].logout()
remove_all_users(ssl_client['$external'])
ssl_client['$external'].logout()
if __name__ == "__main__":
unittest.main()

View File

@ -158,7 +158,9 @@ class TestIdent(unittest.TestCase):
class TestGreenletIdent(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
if not thread_util.have_gevent:
raise SkipTest("need Gevent")

View File

@ -23,7 +23,8 @@ import threading
from pymongo import MongoClient, MongoReplicaSetClient
from pymongo.errors import AutoReconnect
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
from test import host, port, SkipTest, version
from test import host, port, SkipTest
from test.version import Version
try:
@ -124,7 +125,7 @@ def drop_collections(db):
db.drop_collection(coll)
def remove_all_users(db):
if version.at_least(db.connection, (2, 5, 3, -1)):
if Version.from_client(db.connection).at_least(2, 5, 3, -1):
db.command({"dropAllUsersFromDatabase": 1})
else:
db.system.users.remove({})

View File

@ -15,41 +15,49 @@
"""Some tools for running tests based on MongoDB server version."""
def _padded(iter, length, padding=0):
l = list(iter)
if len(l) < length:
for _ in range(length - len(l)):
l.append(0)
return l
class Version(tuple):
def __new__(cls, *version):
padded_version = cls._padded(version, 4)
return super(Version, cls).__new__(cls, tuple(padded_version))
def _parse_version_string(version_string):
mod = 0
if version_string.endswith("+"):
version_string = version_string[0:-1]
mod = 1
elif version_string.endswith("-pre-"):
version_string = version_string[0:-5]
mod = -1
elif version_string.endswith("-"):
version_string = version_string[0:-1]
mod = -1
# Deal with '-rcX' substrings
if version_string.find('-rc') != -1:
version_string = version_string[0:version_string.find('-rc')]
mod = -1
@classmethod
def _padded(cls, iter, length, padding=0):
l = list(iter)
if len(l) < length:
for _ in range(length - len(l)):
l.append(padding)
return l
version = [int(part) for part in version_string.split(".")]
version = _padded(version, 3)
version.append(mod)
@classmethod
def from_string(cls, version_string):
mod = 0
if version_string.endswith("+"):
version_string = version_string[0:-1]
mod = 1
elif version_string.endswith("-pre-"):
version_string = version_string[0:-5]
mod = -1
elif version_string.endswith("-"):
version_string = version_string[0:-1]
mod = -1
# Deal with '-rcX' substrings
if version_string.find('-rc') != -1:
version_string = version_string[0:version_string.find('-rc')]
mod = -1
return tuple(version)
version = [int(part) for part in version_string.split(".")]
version = cls._padded(version, 3)
version.append(mod)
return Version(*version)
# Note this is probably broken for very old versions of the database...
def version(client):
return _parse_version_string(client.server_info()["version"])
@classmethod
def from_client(cls, client):
return cls.from_string(client.server_info()['version'])
def at_least(self, *other_version):
return self >= Version(*other_version)
def at_least(client, min_version):
return version(client) >= tuple(_padded(min_version, 4))
def __str__(self):
return ".".join(map(str, self))