PYTHON-683 Separate unit tests and integration tests in the pymongo test suite.

Raise SkipTest in tests that require a connection to MongoDB when none is available.
This commit is contained in:
Luke Lovett 2014-06-10 19:21:42 +00:00
parent ea74333b01
commit 4fa96c6c2e
20 changed files with 450 additions and 291 deletions

View File

@ -50,11 +50,22 @@ class ClientContext(object):
def __init__(self):
"""Create a client and grab essential information from the server."""
self.connected = False
self.ismaster = {}
self.w = None
self.setname = None
self.rs_client = None
self.cmd_line = None
self.version = Version(-1) # Needs to be comparable with Version
self.auth_enabled = False
self.test_commands_enabled = False
self.is_mongos = False
try:
self.client = pymongo.MongoClient(host, port)
except pymongo.errors.ConnectionFailure:
self.client = None
else:
self.connected = True
self.ismaster = self.client.admin.command('ismaster')
self.w = len(self.ismaster.get("hosts", [])) or 1
self.setname = self.ismaster.get('setName', '')
@ -91,6 +102,9 @@ class ClientContext(object):
def make_wrapper(f):
@wraps(f)
def wrap(*args, **kwargs):
# Always raise SkipTest if we can't connect to MongoDB
if not self.connected:
raise SkipTest("Cannot connect to MongoDB on %s" % pair)
if condition:
return f(*args, **kwargs)
raise SkipTest(msg)
@ -102,6 +116,12 @@ class ClientContext(object):
return decorate
return make_wrapper(func)
def require_connection(self, func):
"""Run a test only if we can connect to MongoDB."""
return self._require(self.connected,
"Cannot connect to MongoDB on %s" % pair,
func=func)
def require_version_min(self, *ver):
"""Run a test only if the server version is at least ``version``."""
other_version = Version(*ver)
@ -154,6 +174,15 @@ class ClientContext(object):
client_context = ClientContext()
class IntegrationTest(unittest.TestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""
@classmethod
@client_context.require_connection
def setUpClass(cls):
pass
def setup():
warnings.resetwarnings()
warnings.simplefilter("always")

View File

@ -62,7 +62,8 @@ class AutoAuthenticateThread(threading.Thread):
class TestGSSAPI(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
if not HAVE_KERBEROS:
raise SkipTest('Kerberos module not available.')
if not GSSAPI_HOST or not PRINCIPAL:
@ -158,7 +159,8 @@ class TestGSSAPI(unittest.TestCase):
class TestSASL(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
if not SASL_HOST or not SASL_USER or not SASL_PASS:
raise SkipTest('Must set SASL_HOST, '
'SASL_USER, and SASL_PASS to test SASL')

View File

@ -27,11 +27,49 @@ import bson
from bson.binary import *
from bson.py3compat import u
from bson.son import SON
from test import client_context, unittest
from test import client_context, unittest, SkipTest
from pymongo.mongo_client import MongoClient
class TestBinary(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Generated by the Java driver
from_java = (
b'bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu'
b'Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND'
b'ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+'
b'XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1'
b'aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR'
b'jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA'
b'AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z'
b'DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf'
b'aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx'
b'29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My'
b'1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB'
b'W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp'
b'bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc'
b'0MQAA')
cls.java_data = base64.b64decode(from_java)
# Generated by the .net driver
from_csharp = (
b'ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl'
b'iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2'
b'ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V'
b'pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl'
b'AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A'
b'ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z'
b'oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU'
b'zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn'
b'dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA'
b'CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT'
b'QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP'
b'MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00'
b'ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=')
cls.csharp_data = base64.b64decode(from_csharp)
def test_binary(self):
a_string = "hello world"
a_binary = Binary(b"hello world")
@ -91,26 +129,8 @@ class TestBinary(unittest.TestCase):
"Binary(%s, 100)" % (repr(b"test"),))
def test_legacy_java_uuid(self):
# Generated by the Java driver
from_java = (b'bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu'
b'Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND'
b'ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+'
b'XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1'
b'aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR'
b'jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA'
b'AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z'
b'DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf'
b'aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx'
b'29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My'
b'1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB'
b'W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp'
b'bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc'
b'0MQAA')
data = base64.b64decode(from_java)
# Test decoding
data = self.java_data
docs = bson.decode_all(data, SON, False, OLD_UUID_SUBTYPE)
for d in docs:
self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring']))
@ -145,7 +165,11 @@ class TestBinary(unittest.TestCase):
for doc in docs])
self.assertEqual(data, encoded)
# Test insert and find
@client_context.require_connection
def test_legacy_java_uuid_roundtrip(self):
data = self.java_data
docs = bson.decode_all(data, SON, False, JAVA_LEGACY)
client_context.client.pymongo_test.drop_collection('java_uuid')
coll = client_context.client.pymongo_test.java_uuid
coll.uuid_subtype = JAVA_LEGACY
@ -161,23 +185,7 @@ class TestBinary(unittest.TestCase):
client_context.client.pymongo_test.drop_collection('java_uuid')
def test_legacy_csharp_uuid(self):
# Generated by the .net driver
from_csharp = (b'ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl'
b'iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2'
b'ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V'
b'pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl'
b'AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A'
b'ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z'
b'oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU'
b'zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn'
b'dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA'
b'CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT'
b'QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP'
b'MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00'
b'ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=')
data = base64.b64decode(from_csharp)
data = self.csharp_data
# Test decoding
docs = bson.decode_all(data, SON, False, OLD_UUID_SUBTYPE)
@ -214,7 +222,11 @@ class TestBinary(unittest.TestCase):
for doc in docs])
self.assertEqual(data, encoded)
# Test insert and find
@client_context.require_connection
def test_legacy_csharp_uuid_roundtrip(self):
data = self.csharp_data
docs = bson.decode_all(data, SON, False, CSHARP_LEGACY)
client_context.client.pymongo_test.drop_collection('csharp_uuid')
coll = client_context.client.pymongo_test.csharp_uuid
coll.uuid_subtype = CSHARP_LEGACY
@ -235,6 +247,7 @@ class TestBinary(unittest.TestCase):
client = MongoClient(uri, _connect=False)
self.assertEqual(client.pymongo_test.test.uuid_subtype, CSHARP_LEGACY)
@client_context.require_connection
def test_uuid_queries(self):
coll = client_context.client.pymongo_test.test

View File

@ -25,6 +25,11 @@ from test import client_context, unittest
from test.utils import oid_generated_on_client, remove_all_users
@client_context.require_connection
def setUpModule():
pass
class BulkTestBase(unittest.TestCase):
def setUp(self):

View File

@ -36,7 +36,13 @@ from pymongo.errors import (AutoReconnect,
InvalidName,
OperationFailure,
PyMongoError)
from test import client_context, host, pair, port, SkipTest, unittest
from test import (client_context,
host,
pair,
port,
SkipTest,
unittest,
IntegrationTest)
from test.pymongo_mocks import MockClient
from test.utils import (assertRaisesExactly,
delay,
@ -53,10 +59,12 @@ def get_client(*args, **kwargs):
return MongoClient(host, port, *args, **kwargs)
class TestClient(unittest.TestCase, TestRequestMixin):
class TestClientNoConnect(unittest.TestCase):
"""MongoClient unit tests that don't require connecting to MongoDB."""
@classmethod
def setUpClass(cls):
cls.client = client_context.client
cls.client = MongoClient(host, port, _connect=False)
def test_types(self):
self.assertRaises(TypeError, MongoClient, 1)
@ -67,6 +75,50 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertRaises(ConfigurationError, MongoClient, [])
def test_get_db(self):
def make_db(base, name):
return base[name]
self.assertRaises(InvalidName, make_db, self.client, "")
self.assertRaises(InvalidName, make_db, self.client, "te$t")
self.assertRaises(InvalidName, make_db, self.client, "te.t")
self.assertRaises(InvalidName, make_db, self.client, "te\\t")
self.assertRaises(InvalidName, make_db, self.client, "te/t")
self.assertRaises(InvalidName, make_db, self.client, "te st")
self.assertTrue(isinstance(self.client.test, Database))
self.assertEqual(self.client.test, self.client["test"])
self.assertEqual(self.client.test, Database(self.client, "test"))
def test_iteration(self):
def iterate():
[a for a in self.client]
self.assertRaises(TypeError, iterate)
def test_get_default_database(self):
c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
# URI with no database.
c = MongoClient("mongodb://%s:%d/" % (host, port), _connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (host, port)
c = MongoClient(uri, _connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
class TestClient(IntegrationTest, TestRequestMixin):
@classmethod
def setUpClass(cls):
super(TestClient, cls).setUpClass()
cls.client = client_context.client
def test_constants(self):
MongoClient.HOST = host
MongoClient.PORT = port
@ -166,21 +218,6 @@ class TestClient(unittest.TestCase, TestRequestMixin):
MongoClient(
host, port, use_greenlets=True).use_greenlets)
def test_get_db(self):
def make_db(base, name):
return base[name]
self.assertRaises(InvalidName, make_db, self.client, "")
self.assertRaises(InvalidName, make_db, self.client, "te$t")
self.assertRaises(InvalidName, make_db, self.client, "te.t")
self.assertRaises(InvalidName, make_db, self.client, "te\\t")
self.assertRaises(InvalidName, make_db, self.client, "te/t")
self.assertRaises(InvalidName, make_db, self.client, "te st")
self.assertTrue(isinstance(self.client.test, Database))
self.assertEqual(self.client.test, self.client["test"])
self.assertEqual(self.client.test, Database(self.client, "test"))
def test_database_names(self):
self.client.pymongo_test.test.save({"dummy": u("object")})
self.client.pymongo_test_mike.test.save({"dummy": u("object")})
@ -283,12 +320,6 @@ class TestClient(unittest.TestCase, TestRequestMixin):
c.admin.logout()
c.disconnect()
def test_iteration(self):
def iterate():
[a for a in self.client]
self.assertRaises(TypeError, iterate)
def test_disconnect(self):
coll = self.client.pymongo_test.bar
@ -306,21 +337,6 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual(self.client,
MongoClient("mongodb://%s:%d" % (host, port)))
def test_get_default_database(self):
c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
# URI with no database.
c = MongoClient("mongodb://%s:%d/" % (host, port), _connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (host, port)
c = MongoClient(uri, _connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
@client_context.require_auth
def test_auth_from_uri(self):
self.client.admin.add_user("admin", "pass")
@ -933,12 +949,14 @@ class TestClient(unittest.TestCase, TestRequestMixin):
client.pymongo_test.test.remove(w=0)
class TestClientLazyConnect(unittest.TestCase, _TestLazyConnectMixin):
class TestClientLazyConnect(IntegrationTest, _TestLazyConnectMixin):
def _get_client(self, **kwargs):
return get_client(**kwargs)
class TestClientLazyConnectBadSeeds(unittest.TestCase):
class TestClientLazyConnectBadSeeds(IntegrationTest):
def _get_client(self, **kwargs):
kwargs.setdefault('connectTimeoutMS', 100)
@ -963,7 +981,7 @@ class TestClientLazyConnectBadSeeds(unittest.TestCase):
class TestClientLazyConnectOneGoodSeed(
unittest.TestCase,
IntegrationTest,
_TestLazyConnectMixin):
def _get_client(self, **kwargs):
@ -992,7 +1010,8 @@ class TestClientLazyConnectOneGoodSeed(
self._get_client, use_greenlets=False)
class TestMongoClientFailover(unittest.TestCase):
class TestMongoClientFailover(IntegrationTest):
def test_discover_primary(self):
c = MockClient(
standalones=[],

View File

@ -36,6 +36,7 @@ from bson.son import SON, RE_TYPE
from pymongo import (ASCENDING, DESCENDING, GEO2D,
GEOHAYSTACK, GEOSPHERE, HASHED, TEXT)
from pymongo import message as message_module
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from pymongo.database import Database
@ -48,25 +49,18 @@ from pymongo.errors import (DocumentTooLarge,
InvalidOperation,
OperationFailure,
WTimeoutError)
from test.test_client import get_client
from test.test_client import get_client, IntegrationTest
from test.utils import (is_mongos, joinall, enable_text_search, get_pool,
oid_generated_on_client)
from test import (client_context,
qcheck,
SkipTest,
unittest,
version)
from test import client_context, host, port, qcheck, unittest
class TestCollection(unittest.TestCase):
class TestCollectionNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
cls.w = client_context.w
def tearDown(self):
self.db.drop_collection("test_large_limit")
client = MongoClient(host, port, _connect=False)
cls.db = client.pymongo_test
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -92,6 +86,23 @@ class TestCollection(unittest.TestCase):
self.assertEqual(self.db.test.mike, self.db["test.mike"])
self.assertEqual(self.db.test["mike"], self.db["test.mike"])
def test_iteration(self):
self.assertRaises(TypeError, next, self.db)
class TestCollection(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestCollection, cls).setUpClass()
cls.db = client_context.client.pymongo_test
cls.w = client_context.w
@classmethod
def tearDownClass(cls):
cls.db.drop_collection("test_large_limit")
def test_drop_nonexistent_collection(self):
self.db.drop_collection('test')
self.assertFalse('test' in self.db.collection_names())
@ -740,14 +751,6 @@ class TestCollection(unittest.TestCase):
self.assertEqual(x["hello"], u("world"))
self.assertTrue("_id" in x)
def test_iteration(self):
db = self.db
def iterate():
[a for a in db.test]
self.assertRaises(TypeError, iterate)
def test_invalid_key_names(self):
db = self.db
db.test.drop()

View File

@ -29,6 +29,11 @@ from pymongo.errors import ConfigurationError, OperationFailure
from test import client_context, pair, unittest
@client_context.require_connection
def setUpModule():
pass
class TestCommon(unittest.TestCase):
def test_uuid_subtype(self):

View File

@ -25,7 +25,8 @@ sys.path[0:0] = [""]
from bson.code import Code
from bson.py3compat import u, PY3
from bson.son import SON
from pymongo import (ASCENDING,
from pymongo import (MongoClient,
ASCENDING,
DESCENDING,
ALL,
OFF)
@ -34,17 +35,113 @@ from pymongo.cursor_manager import CursorManager
from pymongo.errors import (InvalidOperation,
OperationFailure,
ExecutionTimeout)
from test import client_context, SkipTest, unittest
from test import client_context, SkipTest, unittest, host, port, IntegrationTest
from test.utils import is_mongos, get_command_line, server_started_with_auth
if PY3:
long = int
class TestCursor(unittest.TestCase):
class TestCursorNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
cls.db = client.test
def test_deepcopy_cursor_littered_with_regexes(self):
cursor = self.db.test.find({
"x": re.compile("^hmmm.*"),
"y": [re.compile("^hmm.*")],
"z": {"a": [re.compile("^hm.*")]},
re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})
cursor2 = copy.deepcopy(cursor)
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)
def test_add_remove_option(self):
cursor = self.db.test.find()
self.assertEqual(0, cursor._Cursor__query_flags)
cursor.add_option(2)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.add_option(32)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.add_option(128)
cursor2 = self.db.test.find(tailable=True,
await_data=True).add_option(128)
self.assertEqual(162, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertEqual(162, cursor._Cursor__query_flags)
cursor.add_option(128)
self.assertEqual(162, cursor._Cursor__query_flags)
cursor.remove_option(128)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(32)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertEqual(2, cursor._Cursor__query_flags)
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_flags)
# Timeout
cursor = self.db.test.find(timeout=False)
self.assertEqual(16, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(16)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(16)
self.assertEqual(0, cursor._Cursor__query_flags)
# Tailable / Await data
cursor = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(34)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_flags)
# Exhaust - which mongos doesn't support
cursor = self.db.test.find(exhaust=True)
self.assertEqual(64, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(64)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertTrue(cursor._Cursor__exhaust)
cursor.remove_option(64)
self.assertEqual(0, cursor._Cursor__query_flags)
self.assertFalse(cursor._Cursor__exhaust)
# Partial
cursor = self.db.test.find(partial=True)
self.assertEqual(128, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(128)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(128)
self.assertEqual(0, cursor._Cursor__query_flags)
class TestCursor(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestCursor, cls).setUpClass()
cls.db = client_context.client.pymongo_test
@client_context.require_version_min(2, 5, 3, -1)
@ -677,94 +774,6 @@ class TestCursor(unittest.TestCase):
self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)
def test_deepcopy_cursor_littered_with_regexes(self):
cursor = self.db.test.find({"x": re.compile("^hmmm.*"),
"y": [re.compile("^hmm.*")],
"z": {"a": [re.compile("^hm.*")]},
re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})
cursor2 = copy.deepcopy(cursor)
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)
def test_add_remove_option(self):
cursor = self.db.test.find()
self.assertEqual(0, cursor._Cursor__query_flags)
cursor.add_option(2)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.add_option(32)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.add_option(128)
cursor2 = self.db.test.find(tailable=True,
await_data=True).add_option(128)
self.assertEqual(162, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertEqual(162, cursor._Cursor__query_flags)
cursor.add_option(128)
self.assertEqual(162, cursor._Cursor__query_flags)
cursor.remove_option(128)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(32)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertEqual(2, cursor._Cursor__query_flags)
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_flags)
# Timeout
cursor = self.db.test.find(timeout=False)
self.assertEqual(16, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(16)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(16)
self.assertEqual(0, cursor._Cursor__query_flags)
# Tailable / Await data
cursor = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(34)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_flags)
# Exhaust - which mongos doesn't support
if not is_mongos(self.db.connection):
cursor = self.db.test.find(exhaust=True)
self.assertEqual(64, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(64)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertTrue(cursor._Cursor__exhaust)
cursor.remove_option(64)
self.assertEqual(0, cursor._Cursor__query_flags)
self.assertFalse(cursor._Cursor__exhaust)
# Partial
cursor = self.db.test.find(partial=True)
self.assertEqual(128, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(128)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
cursor.remove_option(128)
self.assertEqual(0, cursor._Cursor__query_flags)
def test_count_with_fields(self):
self.db.test.drop()
self.db.test.save({"x": 1})

View File

@ -28,7 +28,8 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from bson.py3compat import u, string_type, text_type, PY3
from bson.son import SON, RE_TYPE
from pymongo import (ALL,
from pymongo import (MongoClient,
ALL,
auth,
OFF,
SLOW_ONLY,
@ -45,7 +46,7 @@ from pymongo.son_manipulator import (AutoReference,
NamespaceInjector,
SONManipulator,
ObjectIdShuffler)
from test import client_context, SkipTest, unittest
from test import client_context, SkipTest, unittest, host, port, IntegrationTest
from test.utils import remove_all_users, server_started_with_auth
from test.test_client import get_client
@ -53,11 +54,11 @@ if PY3:
long = int
class TestDatabase(unittest.TestCase):
class TestDatabaseNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.client = client_context.client
cls.client = MongoClient(host, port, _connect=False)
def test_name(self):
self.assertRaises(TypeError, Database, self.client, 4)
@ -77,11 +78,6 @@ class TestDatabase(unittest.TestCase):
self.assertFalse(Database(self.client, "test") !=
Database(self.client, "test"))
def test_repr(self):
self.assertEqual(repr(Database(self.client, "pymongo_test")),
"Database(%r, %s)" % (self.client,
repr(u("pymongo_test"))))
def test_get_coll(self):
db = Database(self.client, "pymongo_test")
self.assertEqual(db.test, db["test"])
@ -89,6 +85,39 @@ class TestDatabase(unittest.TestCase):
self.assertNotEqual(db.test, Collection(db, "mike"))
self.assertEqual(db.test.mike, db["test.mike"])
def test_iteration(self):
self.assertRaises(TypeError, next, self.client.pymongo_test)
def test_manipulator_properties(self):
db = self.client.foo
self.assertEqual([], db.incoming_manipulators)
self.assertEqual([], db.incoming_copying_manipulators)
self.assertEqual([], db.outgoing_manipulators)
self.assertEqual([], db.outgoing_copying_manipulators)
db.add_son_manipulator(AutoReference(db))
db.add_son_manipulator(NamespaceInjector())
db.add_son_manipulator(ObjectIdShuffler())
self.assertEqual(1, len(db.incoming_manipulators))
self.assertEqual(db.incoming_manipulators, ['NamespaceInjector'])
self.assertEqual(2, len(db.incoming_copying_manipulators))
for name in db.incoming_copying_manipulators:
self.assertTrue(name in ('ObjectIdShuffler', 'AutoReference'))
self.assertEqual([], db.outgoing_manipulators)
self.assertEqual(['AutoReference'], db.outgoing_copying_manipulators)
class TestDatabase(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestDatabase, cls).setUpClass()
cls.client = client_context.client
def test_repr(self):
self.assertEqual(repr(Database(self.client, "pymongo_test")),
"Database(%r, %s)" % (self.client,
repr(u("pymongo_test"))))
def test_create_collection(self):
db = Database(self.client, "pymongo_test")
@ -244,14 +273,6 @@ class TestDatabase(unittest.TestCase):
self.assertTrue(isinstance(info[0]["millis"], float))
self.assertTrue(isinstance(info[0]["ts"], datetime.datetime))
def test_iteration(self):
db = self.client.pymongo_test
def iterate():
[a for a in db]
self.assertRaises(TypeError, iterate)
@client_context.require_no_mongos
def test_errors(self):
db = self.client.pymongo_test
@ -869,23 +890,6 @@ class TestDatabase(unittest.TestCase):
del db.system_js.foo
self.assertEqual(["bar"], db.system_js.list())
def test_manipulator_properties(self):
db = self.client.foo
self.assertEqual([], db.incoming_manipulators)
self.assertEqual([], db.incoming_copying_manipulators)
self.assertEqual([], db.outgoing_manipulators)
self.assertEqual([], db.outgoing_copying_manipulators)
db.add_son_manipulator(AutoReference(db))
db.add_son_manipulator(NamespaceInjector())
db.add_son_manipulator(ObjectIdShuffler())
self.assertEqual(1, len(db.incoming_manipulators))
self.assertEqual(db.incoming_manipulators, ['NamespaceInjector'])
self.assertEqual(2, len(db.incoming_copying_manipulators))
for name in db.incoming_copying_manipulators:
self.assertTrue(name in ('ObjectIdShuffler', 'AutoReference'))
self.assertEqual([], db.outgoing_manipulators)
self.assertEqual(['AutoReference'], db.outgoing_copying_manipulators)
def test_command_response_without_ok(self):
# Sometimes (SERVER-10891) the server's response to a badly-formatted
# command document will have no 'ok' field. We should raise

View File

@ -35,13 +35,62 @@ from gridfs.errors import (NoFile,
UnsupportedAPI)
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from test import client_context, qcheck, unittest
from test import client_context, qcheck, unittest, host, port, IntegrationTest
class TestGridFile(unittest.TestCase):
class TestGridFileNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
cls.db = client.pymongo_test
def test_grid_file(self):
self.assertRaises(UnsupportedAPI, GridFile)
def test_grid_in_custom_opts(self):
self.assertRaises(TypeError, GridIn, "foo")
a = GridIn(self.db.fs, _id=5, filename="my_file",
contentType="text/html", chunkSize=1000, aliases=["foo"],
metadata={"foo": 1, "bar": 2}, bar=3, baz="hello")
self.assertEqual(5, a._id)
self.assertEqual("my_file", a.filename)
self.assertEqual("my_file", a.name)
self.assertEqual("text/html", a.content_type)
self.assertEqual(1000, a.chunk_size)
self.assertEqual(["foo"], a.aliases)
self.assertEqual({"foo": 1, "bar": 2}, a.metadata)
self.assertEqual(3, a.bar)
self.assertEqual("hello", a.baz)
self.assertRaises(AttributeError, getattr, a, "mike")
b = GridIn(self.db.fs,
content_type="text/html", chunk_size=1000, baz=100)
self.assertEqual("text/html", b.content_type)
self.assertEqual(1000, b.chunk_size)
self.assertEqual(100, b.baz)
def test_grid_out_cursor_options(self):
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
tailable=True)
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
fields={"filename": 1})
cursor = GridOutCursor(self.db.fs, {})
cursor_clone = cursor.clone()
self.assertEqual(cursor_clone.__dict__, cursor.__dict__)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
class TestGridFile(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestGridFile, cls).setUpClass()
cls.db = client_context.client.pymongo_test
def setUp(self):
@ -96,9 +145,6 @@ class TestGridFile(unittest.TestCase):
# test that md5 still works...
self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", g.md5)
def test_grid_file(self):
self.assertRaises(UnsupportedAPI, GridFile)
def test_grid_in_default_opts(self):
self.assertRaises(TypeError, GridIn, "foo")
@ -173,30 +219,6 @@ class TestGridFile(unittest.TestCase):
self.assertEqual(a.aliases, b.aliases)
self.assertEqual(a.forty_two, b.forty_two)
def test_grid_in_custom_opts(self):
self.assertRaises(TypeError, GridIn, "foo")
a = GridIn(self.db.fs, _id=5, filename="my_file",
contentType="text/html", chunkSize=1000, aliases=["foo"],
metadata={"foo": 1, "bar": 2}, bar=3, baz="hello")
self.assertEqual(5, a._id)
self.assertEqual("my_file", a.filename)
self.assertEqual("my_file", a.name)
self.assertEqual("text/html", a.content_type)
self.assertEqual(1000, a.chunk_size)
self.assertEqual(["foo"], a.aliases)
self.assertEqual({"foo": 1, "bar": 2}, a.metadata)
self.assertEqual(3, a.bar)
self.assertEqual("hello", a.baz)
self.assertRaises(AttributeError, getattr, a, "mike")
b = GridIn(self.db.fs,
content_type="text/html", chunk_size=1000, baz=100)
self.assertEqual("text/html", b.content_type)
self.assertEqual(1000, b.chunk_size)
self.assertEqual(100, b.baz)
def test_grid_out_default_opts(self):
self.assertRaises(TypeError, GridOut, "foo")
@ -584,19 +606,6 @@ Bye"""))
self.assertRaises(ConnectionFailure, infile.write, b'data goes here')
self.assertRaises(ConnectionFailure, infile.close)
def test_grid_out_cursor_options(self):
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
tailable=True)
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
fields={"filename":1})
cursor = GridOutCursor(self.db.fs, {})
cursor_clone = cursor.clone()
self.assertEqual(cursor_clone.__dict__, cursor.__dict__)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
if __name__ == "__main__":
unittest.main()

View File

@ -32,7 +32,7 @@ import gridfs
from bson.py3compat import u, StringIO, string_type
from gridfs.errors import (FileExists,
NoFile)
from test import client_context, unittest
from test import client_context, unittest, host, port, IntegrationTest
from test.utils import joinall
@ -68,10 +68,23 @@ class JustRead(threading.Thread):
assert data == b"hello"
class TestGridfs(unittest.TestCase):
class TestGridfsNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
cls.db = client.pymongo_test
def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo")
self.assertRaises(TypeError, gridfs.GridFS, self.db, 5)
class TestGridfs(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestGridfs, cls).setUpClass()
cls.db = client_context.client.pymongo_test
cls.fs = gridfs.GridFS(cls.db)
cls.alt = gridfs.GridFS(cls.db, "alt")
@ -82,10 +95,6 @@ class TestGridfs(unittest.TestCase):
self.db.drop_collection("alt.files")
self.db.drop_collection("alt.chunks")
def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo")
self.assertRaises(TypeError, gridfs.GridFS, self.db, 5)
def test_basic(self):
oid = self.fs.put(b"hello world")
self.assertEqual(b"hello world", self.fs.get(oid).read())
@ -393,6 +402,7 @@ class TestGridfs(unittest.TestCase):
class TestGridfsReplicaSet(TestReplicaSetClientBase):
def test_gridfs_replica_set(self):
rsc = self._get_client(
w=self.w, wtimeout=5000,

View File

@ -33,17 +33,13 @@ from bson.son import RE_TYPE
from bson.timestamp import Timestamp
from bson.tz_util import utc
from test import client_context, SkipTest, unittest
from test import client_context, SkipTest, unittest, IntegrationTest
PY3 = sys.version_info[0] == 3
class TestJsonUtil(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
def setUp(self):
if not json_util.json_lib:
raise SkipTest("No json or simplejson module")
@ -206,6 +202,18 @@ class TestJsonUtil(unittest.TestCase):
# Check order.
self.assertEqual('{"$code": "return z", "$scope": {"z": 2}}', res)
class TestJsonUtilRoundtrip(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestJsonUtilRoundtrip, cls).setUpClass()
cls.db = client_context.client.pymongo_test
def setUp(self):
if not json_util.json_lib:
raise SkipTest("No json or simplejson module")
def test_cursor(self):
db = self.db

View File

@ -20,10 +20,15 @@ import threading
sys.path[0:0] = [""]
from pymongo.errors import AutoReconnect
from test import unittest
from test import unittest, client_context
from test.pymongo_mocks import MockClient
@client_context.require_connection
def setUpModule():
pass
class FindOne(threading.Thread):
def __init__(self, client):
super(FindOne, self).__init__()

View File

@ -21,13 +21,18 @@ sys.path[0:0] = [""]
from bson.py3compat import thread
from test import host, port, SkipTest, unittest
from test import host, port, SkipTest, unittest, client_context
from test.test_pooling_base import (
_TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets,
_TestPoolSocketSharing, _TestWaitQueueMultiple, one)
from test.utils import get_pool
@client_context.require_connection
def setUpModule():
pass
class TestPoolingThreads(_TestPooling, unittest.TestCase):
use_greenlets = False

View File

@ -20,13 +20,18 @@ import time
from pymongo import pool
from pymongo.errors import ConfigurationError
from test import host, port, SkipTest, unittest
from test import host, port, SkipTest, unittest, client_context
from test.utils import looplet
from test.test_pooling_base import (
_TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets,
_TestPoolSocketSharing, _TestWaitQueueMultiple, has_gevent)
@client_context.require_connection
def setUpModule():
pass
class TestPoolingGevent(_TestPooling, unittest.TestCase):
"""Apply all the standard pool tests with greenlets and Gevent"""
use_greenlets = True

View File

@ -32,11 +32,18 @@ from pymongo.errors import ConfigurationError
from test.test_replica_set_client import TestReplicaSetClientBase
from test.test_client import get_client
from test import client_context, host, port, SkipTest, unittest, utils
from test import (client_context,
host,
port,
SkipTest,
unittest,
utils,
IntegrationTest)
from test.version import Version
class TestReadPreferencesBase(TestReplicaSetClientBase):
def setUp(self):
super(TestReadPreferencesBase, self).setUp()
# Insert some data so we can use cursors in read_from_which_host
@ -194,6 +201,7 @@ class ReadPrefTester(MongoReplicaSetClient):
class TestCommandAndReadPreference(TestReplicaSetClientBase):
def setUp(self):
super(TestCommandAndReadPreference, self).setUp()
@ -505,7 +513,8 @@ class TestMovingAverage(unittest.TestCase):
self.assertEqual((30 - 100 + 17 + 43 - 1111) / 5., avg7.get())
class TestMongosConnection(unittest.TestCase):
class TestMongosConnection(IntegrationTest):
def test_mongos_connection(self):
c = get_client()
is_mongos = utils.is_mongos(c)

View File

@ -53,6 +53,7 @@ class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
"""This is a funny beast -- we want to run tests for MongoReplicaSetClient
but only if the database at DB_IP and DB_PORT is a standalone.
"""
@client_context.require_connection
def setUp(self):
if client_context.setname:
raise SkipTest("Connected to a replica set, not a standalone mongod")
@ -1073,6 +1074,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
class TestReplicaSetWireVersion(unittest.TestCase):
@client_context.require_connection
def test_wire_version(self):
c = MockReplicaSetClient(
standalones=[],
@ -1142,6 +1145,8 @@ class TestReplicaSetClientLazyConnectBadSeeds(
class TestReplicaSetClientInternalIPs(unittest.TestCase):
@client_context.require_connection
def test_connect_with_internal_ips(self):
# Client is passed an IP it can reach, 'a:1', but the RS config
# only contains unreachable IPs like 'internal-ip'. PYTHON-608.
@ -1157,6 +1162,8 @@ class TestReplicaSetClientInternalIPs(unittest.TestCase):
class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase):
@client_context.require_connection
def test_max_write_batch_size(self):
c = MockReplicaSetClient(
standalones=[],

View File

@ -20,10 +20,15 @@ sys.path[0:0] = [""]
from pymongo.errors import ConfigurationError, ConnectionFailure
from pymongo import ReadPreference
from test import unittest
from test import unittest, client_context
from test.pymongo_mocks import MockClient, MockReplicaSetClient
@client_context.require_connection
def setUpModule():
pass
class TestSecondaryBecomesStandalone(unittest.TestCase):
# An administrator removes a secondary from a 3-node set and
# brings it back up as standalone, without updating the other

View File

@ -19,18 +19,20 @@ import sys
sys.path[0:0] = [""]
from bson.son import SON
from pymongo import MongoClient
from pymongo.son_manipulator import (NamespaceInjector,
ObjectIdInjector,
ObjectIdShuffler,
SONManipulator)
from test import client_context, qcheck, unittest
from test import qcheck, unittest, host, port
class TestSONManipulator(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = client_context.client.pymongo_test
client = MongoClient(host, port, _connect=False)
cls.db = client.pymongo_test
def test_basic(self):
manip = SONManipulator()

View File

@ -17,7 +17,7 @@
import threading
import traceback
from test import SkipTest, unittest
from test import SkipTest, unittest, client_context
from test.utils import (joinall, remove_all_users,
server_started_with_auth, RendezvousThread)
from test.test_client import get_client
@ -26,6 +26,11 @@ from pymongo.pool import SocketInfo, _closed
from pymongo.errors import AutoReconnect, OperationFailure
@client_context.require_connection
def setUpModule():
pass
class AutoAuthenticateThreads(threading.Thread):
def __init__(self, collection, num):