PYTHON-785 Don't use requests in tests.
This commit is contained in:
parent
daad4446f8
commit
e8be121a89
@ -31,6 +31,7 @@ from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from test import client_context, host, port, SkipTest, unittest, Version
|
||||
from test.utils import delay
|
||||
|
||||
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS.
|
||||
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
|
||||
@ -46,18 +47,22 @@ SASL_DB = os.environ.get('SASL_DB', '$external')
|
||||
|
||||
class AutoAuthenticateThread(threading.Thread):
|
||||
"""Used in testing threaded authentication.
|
||||
|
||||
This does collection.find_one() with a 1-second delay to ensure it must
|
||||
check out and authenticate multiple sockets from the pool concurrently.
|
||||
|
||||
:Parameters:
|
||||
`collection`: An auth-protected collection containing one document.
|
||||
"""
|
||||
|
||||
def __init__(self, database):
|
||||
def __init__(self, collection):
|
||||
super(AutoAuthenticateThread, self).__init__()
|
||||
self.database = database
|
||||
self.success = True
|
||||
self.collection = collection
|
||||
self.success = False
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.database.command('dbstats')
|
||||
except OperationFailure:
|
||||
self.success = False
|
||||
assert self.collection.find_one({'$where': delay(1)}) is not None
|
||||
self.success = True
|
||||
|
||||
|
||||
class TestGSSAPI(unittest.TestCase):
|
||||
@ -148,14 +153,20 @@ class TestGSSAPI(unittest.TestCase):
|
||||
def test_gssapi_threaded(self):
|
||||
|
||||
client = MongoClient(GSSAPI_HOST)
|
||||
# Make sure each thread uses a different socket.
|
||||
client.start_request()
|
||||
self.assertTrue(client.test.authenticate(PRINCIPAL,
|
||||
mechanism='GSSAPI'))
|
||||
|
||||
# Need one document in the collection. AutoAuthenticateThread does
|
||||
# collection.find_one with a 1-second delay, forcing it to check out
|
||||
# multiple sockets from the pool concurrently, proving that
|
||||
# auto-authentication works with GSSAPI.
|
||||
collection = client.test.collection
|
||||
collection.remove()
|
||||
collection.insert({'_id': 1})
|
||||
|
||||
threads = []
|
||||
for _ in range(4):
|
||||
threads.append(AutoAuthenticateThread(client.test))
|
||||
threads.append(AutoAuthenticateThread(collection))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
@ -174,7 +185,7 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
threads = []
|
||||
for _ in range(4):
|
||||
threads.append(AutoAuthenticateThread(client.foo))
|
||||
threads.append(AutoAuthenticateThread(collection))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
|
||||
@ -29,7 +29,7 @@ from test import (client_context,
|
||||
port,
|
||||
IntegrationTest,
|
||||
SkipTest)
|
||||
from test.utils import oid_generated_on_client, remove_all_users
|
||||
from test.utils import oid_generated_on_client, remove_all_users, wait_until
|
||||
|
||||
|
||||
class BulkTestBase(IntegrationTest):
|
||||
@ -1115,9 +1115,9 @@ class TestBulkNoResults(BulkTestBase):
|
||||
batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}})
|
||||
batch.insert({'_id': 2})
|
||||
batch.find({'_id': 1}).remove_one()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
self.assertEqual(2, self.coll.count())
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
wait_until(lambda: 2 == self.coll.count(),
|
||||
'insert 2 documents')
|
||||
|
||||
def test_no_results_ordered_failure(self):
|
||||
|
||||
@ -1127,9 +1127,9 @@ class TestBulkNoResults(BulkTestBase):
|
||||
batch.insert({'_id': 2})
|
||||
batch.insert({'_id': 1})
|
||||
batch.find({'_id': 1}).remove_one()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
self.assertEqual(3, self.coll.count())
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
wait_until(lambda: 3 == self.coll.count(),
|
||||
'insert 3 documents')
|
||||
|
||||
def test_no_results_unordered_success(self):
|
||||
|
||||
@ -1138,9 +1138,9 @@ class TestBulkNoResults(BulkTestBase):
|
||||
batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}})
|
||||
batch.insert({'_id': 2})
|
||||
batch.find({'_id': 1}).remove_one()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
self.assertEqual(2, self.coll.count())
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
wait_until(lambda: 2 == self.coll.count(),
|
||||
'insert 2 documents')
|
||||
|
||||
def test_no_results_unordered_failure(self):
|
||||
|
||||
@ -1150,10 +1150,10 @@ class TestBulkNoResults(BulkTestBase):
|
||||
batch.insert({'_id': 2})
|
||||
batch.insert({'_id': 1})
|
||||
batch.find({'_id': 1}).remove_one()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
self.assertEqual(2, self.coll.count())
|
||||
self.assertTrue(self.coll.find_one({'_id': 1}) is None)
|
||||
self.assertTrue(batch.execute({'w': 0}) is None)
|
||||
wait_until(lambda: 2 == self.coll.count(),
|
||||
'insert 2 documents')
|
||||
self.assertTrue(self.coll.find_one({'_id': 1}) is None)
|
||||
|
||||
|
||||
class TestBulkAuthorization(BulkTestBase):
|
||||
|
||||
@ -18,7 +18,6 @@ import contextlib
|
||||
import datetime
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
@ -32,7 +31,6 @@ from bson.son import SON
|
||||
from bson.tz_util import utc
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo import auth, message
|
||||
from pymongo.errors import (AutoReconnect,
|
||||
ConfigurationError,
|
||||
@ -63,7 +61,6 @@ from test.utils import (assertRaisesExactly,
|
||||
ignore_deprecations,
|
||||
remove_all_users,
|
||||
server_is_master_with_slave,
|
||||
TestRequestMixin,
|
||||
get_pool,
|
||||
one,
|
||||
connected,
|
||||
@ -74,7 +71,7 @@ from test.utils import (assertRaisesExactly,
|
||||
NTHREADS)
|
||||
|
||||
|
||||
class ClientUnitTest(unittest.TestCase, TestRequestMixin):
|
||||
class ClientUnitTest(unittest.TestCase):
|
||||
"""MongoClient tests that don't require a server."""
|
||||
|
||||
@classmethod
|
||||
@ -166,7 +163,7 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin):
|
||||
self.assertEqual(Database(c, 'foo'), c.get_default_database())
|
||||
|
||||
|
||||
class TestClient(IntegrationTest, TestRequestMixin):
|
||||
class TestClient(IntegrationTest):
|
||||
|
||||
def test_constants(self):
|
||||
# Set bad defaults.
|
||||
@ -605,80 +602,6 @@ class TestClient(IntegrationTest, TestRequestMixin):
|
||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
||||
self.assertEqual(0, len(get_pool(client).sockets))
|
||||
|
||||
def test_with_start_request(self):
|
||||
pool = get_pool(self.client)
|
||||
|
||||
# No request started
|
||||
self.assertNoRequest(pool)
|
||||
self.assertDifferentSock(pool)
|
||||
|
||||
# Start a request
|
||||
request_context_mgr = self.client.start_request()
|
||||
self.assertTrue(
|
||||
isinstance(request_context_mgr, object)
|
||||
)
|
||||
|
||||
self.assertNoSocketYet(pool)
|
||||
self.assertSameSock(pool)
|
||||
self.assertRequestSocket(pool)
|
||||
|
||||
# End request
|
||||
request_context_mgr.__exit__(None, None, None)
|
||||
self.assertNoRequest(pool)
|
||||
self.assertDifferentSock(pool)
|
||||
|
||||
# Test the 'with' statement
|
||||
with self.client.start_request() as request:
|
||||
self.assertEqual(self.client, request.connection)
|
||||
self.assertNoSocketYet(pool)
|
||||
self.assertSameSock(pool)
|
||||
self.assertRequestSocket(pool)
|
||||
|
||||
# Request has ended
|
||||
self.assertNoRequest(pool)
|
||||
self.assertDifferentSock(pool)
|
||||
|
||||
def test_request_threads(self):
|
||||
client = self.client
|
||||
pool = get_pool(client)
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
|
||||
started_request, ended_request = threading.Event(), threading.Event()
|
||||
checked_request = threading.Event()
|
||||
thread_done = [False]
|
||||
|
||||
# Starting a request in one thread doesn't put the other thread in a
|
||||
# request
|
||||
def f():
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
client.start_request()
|
||||
self.assertInRequestAndSameSock(client, pool)
|
||||
started_request.set()
|
||||
checked_request.wait()
|
||||
checked_request.clear()
|
||||
self.assertInRequestAndSameSock(client, pool)
|
||||
client.end_request()
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
ended_request.set()
|
||||
checked_request.wait()
|
||||
thread_done[0] = True
|
||||
|
||||
t = threading.Thread(target=f)
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
# It doesn't matter in what order the main thread or t initially get
|
||||
# to started_request.set() / wait(); by waiting here we ensure that t
|
||||
# has called client.start_request() before we assert on the next line.
|
||||
started_request.wait()
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
checked_request.set()
|
||||
ended_request.wait()
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
checked_request.set()
|
||||
t.join()
|
||||
self.assertNotInRequestAndDifferentSock(client, pool)
|
||||
self.assertTrue(thread_done[0], "Thread didn't complete")
|
||||
|
||||
def test_interrupt_signal(self):
|
||||
if sys.platform.startswith('java'):
|
||||
# We can't figure out how to raise an exception on a thread that's
|
||||
@ -724,7 +647,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
|
||||
next(db.foo.find())
|
||||
)
|
||||
|
||||
def test_operation_failure_without_request(self):
|
||||
def test_operation_failure(self):
|
||||
# Ensure MongoClient doesn't close socket after it gets an error
|
||||
# response to getLastError. PYTHON-395.
|
||||
pool = get_pool(self.client)
|
||||
@ -741,27 +664,6 @@ class TestClient(IntegrationTest, TestRequestMixin):
|
||||
new_sock_info = next(iter(pool.sockets))
|
||||
self.assertEqual(old_sock_info, new_sock_info)
|
||||
|
||||
def test_operation_failure_with_request(self):
|
||||
# Ensure MongoClient doesn't close socket after it gets an error
|
||||
# response to getLastError. PYTHON-395.
|
||||
c = rs_or_single_client()
|
||||
c.start_request()
|
||||
pool = get_pool(c)
|
||||
|
||||
# Pool reserves a socket for this thread.
|
||||
c.pymongo_test.test.find_one()
|
||||
self.assertTrue(isinstance(pool._get_request_state(), SocketInfo))
|
||||
|
||||
old_sock_info = pool._get_request_state()
|
||||
c.pymongo_test.test.drop()
|
||||
c.pymongo_test.test.insert({'_id': 'foo'})
|
||||
self.assertRaises(
|
||||
OperationFailure,
|
||||
c.pymongo_test.test.insert, {'_id': 'foo'})
|
||||
|
||||
# OperationFailure doesn't affect the request socket
|
||||
self.assertEqual(old_sock_info, pool._get_request_state())
|
||||
|
||||
def test_alive(self):
|
||||
self.assertTrue(self.client.alive())
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ from pymongo.errors import (DocumentTooLarge,
|
||||
from test.test_client import IntegrationTest
|
||||
from test.utils import (is_mongos, joinall, enable_text_search, get_pool,
|
||||
oid_generated_on_client, one, ignore_deprecations,
|
||||
rs_or_single_client)
|
||||
rs_or_single_client, wait_until)
|
||||
from test import client_context, host, port, qcheck, unittest
|
||||
|
||||
|
||||
@ -793,14 +793,15 @@ class TestCollection(IntegrationTest):
|
||||
db.drop_collection("test")
|
||||
db.test.ensure_index([('i', ASCENDING)], unique=True)
|
||||
|
||||
with db.connection.start_request():
|
||||
# No error
|
||||
db.test.insert([{'i': i} for i in range(5, 10)], w=0)
|
||||
db.test.remove()
|
||||
# No error
|
||||
db.test.insert([{'i': i} for i in range(5, 10)], w=0)
|
||||
wait_until(lambda: 5 == db.test.count(), 'insert 5 documents')
|
||||
|
||||
# No error
|
||||
db.test.insert([{'i': 1}] * 2, w=0)
|
||||
self.assertEqual(1, db.test.count())
|
||||
db.test.remove()
|
||||
|
||||
# No error
|
||||
db.test.insert([{'i': 1}] * 2, w=0)
|
||||
wait_until(lambda: 1 == db.test.count(), 'insert 1 document')
|
||||
|
||||
self.assertRaises(
|
||||
DuplicateKeyError,
|
||||
@ -811,20 +812,19 @@ class TestCollection(IntegrationTest):
|
||||
wc = db.write_concern
|
||||
db.write_concern = {"w": 0}
|
||||
try:
|
||||
with db.connection.start_request():
|
||||
db.test.ensure_index([('i', ASCENDING)], unique=True)
|
||||
db.test.ensure_index([('i', ASCENDING)], unique=True)
|
||||
|
||||
# No error
|
||||
db.test.insert([{'i': 1}] * 2)
|
||||
self.assertEqual(1, db.test.count())
|
||||
# No error.
|
||||
db.test.insert([{'i': 1}] * 2)
|
||||
wait_until(lambda: 1 == db.test.count(), 'insert 1 document')
|
||||
|
||||
# Implied safe
|
||||
# Implied acknowledged.
|
||||
self.assertRaises(
|
||||
DuplicateKeyError,
|
||||
lambda: db.test.insert([{'i': 2}] * 2, fsync=True),
|
||||
)
|
||||
|
||||
# Explicit safe
|
||||
# Explicit acknowledged.
|
||||
self.assertRaises(
|
||||
DuplicateKeyError,
|
||||
lambda: db.test.insert([{'i': 2}] * 2, w=1),
|
||||
@ -902,7 +902,7 @@ class TestCollection(IntegrationTest):
|
||||
self.db.test.find_one({'_id': 'explicit_id'})['hello']
|
||||
)
|
||||
|
||||
# Safe mode
|
||||
# Acknowledged mode.
|
||||
self.db.test.create_index("hello", unique=True)
|
||||
# No exception, even though we duplicate the first doc's "hello" value
|
||||
self.db.test.save({'_id': 'explicit_id', 'hello': 'world'}, w=0)
|
||||
@ -924,23 +924,19 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
def test_unique_index(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello")
|
||||
|
||||
with self.client.start_request():
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello")
|
||||
# No error.
|
||||
db.test.save({"hello": "world"})
|
||||
db.test.save({"hello": "world"})
|
||||
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello", unique=True)
|
||||
|
||||
with self.assertRaises(DuplicateKeyError):
|
||||
db.test.save({"hello": "world"})
|
||||
db.test.save({"hello": "mike"})
|
||||
db.test.save({"hello": "world"})
|
||||
self.assertFalse(db.error())
|
||||
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello", unique=True)
|
||||
|
||||
db.test.save({"hello": "world"})
|
||||
db.test.save({"hello": "mike"})
|
||||
db.test.save({"hello": "world"}, w=0)
|
||||
self.assertTrue(db.error())
|
||||
|
||||
def test_duplicate_key_error(self):
|
||||
db = self.db
|
||||
@ -949,41 +945,17 @@ class TestCollection(IntegrationTest):
|
||||
db.test.create_index("x", unique=True)
|
||||
|
||||
db.test.insert({"_id": 1, "x": 1})
|
||||
db.test.insert({"_id": 2, "x": 2})
|
||||
|
||||
with db.connection.start_request():
|
||||
# No error
|
||||
db.test.insert({"_id": 1, "x": 1}, w=0)
|
||||
db.test.save({"_id": 1, "x": 1}, w=0)
|
||||
db.test.insert({"_id": 2, "x": 2}, w=0)
|
||||
db.test.save({"_id": 2, "x": 2}, w=0)
|
||||
with self.assertRaises(DuplicateKeyError) as context:
|
||||
db.test.insert({"x": 1})
|
||||
|
||||
# But all those statements didn't do anything
|
||||
self.assertEqual(2, db.test.count())
|
||||
self.assertIsNotNone(context.exception.details)
|
||||
|
||||
expected_error = OperationFailure
|
||||
if client_context.version.at_least(1, 3):
|
||||
expected_error = DuplicateKeyError
|
||||
with self.assertRaises(DuplicateKeyError) as context:
|
||||
db.test.save({"x": 1})
|
||||
|
||||
self.assertRaises(expected_error,
|
||||
db.test.insert, {"_id": 1})
|
||||
self.assertRaises(expected_error,
|
||||
db.test.insert, {"x": 1})
|
||||
|
||||
self.assertRaises(expected_error,
|
||||
db.test.save, {"x": 2})
|
||||
self.assertRaises(expected_error,
|
||||
db.test.update, {"x": 1},
|
||||
{"$inc": {"x": 1}})
|
||||
|
||||
try:
|
||||
db.test.insert({"_id": 1})
|
||||
except expected_error as exc:
|
||||
# Just check that we set the error document. Fields
|
||||
# vary by MongoDB version.
|
||||
self.assertTrue(exc.details is not None)
|
||||
else:
|
||||
self.fail("%s was not raised" % (expected_error.__name__,))
|
||||
self.assertIsNotNone(context.exception.details)
|
||||
self.assertEqual(1, db.test.count())
|
||||
|
||||
def test_wtimeout(self):
|
||||
# Ensure setting wtimeout doesn't disable write concern altogether.
|
||||
@ -1006,33 +978,33 @@ class TestCollection(IntegrationTest):
|
||||
self.assertEqual(1, db.test.count())
|
||||
|
||||
docs = []
|
||||
docs.append({"_id": oid, "two": 2})
|
||||
docs.append({"_id": oid, "two": 2}) # Duplicate _id.
|
||||
docs.append({"three": 3})
|
||||
docs.append({"four": 4})
|
||||
docs.append({"five": 5})
|
||||
|
||||
with self.client.start_request():
|
||||
db.test.insert(docs, manipulate=False, w=0)
|
||||
self.assertEqual(11000, db.error()['code'])
|
||||
self.assertEqual(1, db.test.count())
|
||||
with self.assertRaises(DuplicateKeyError):
|
||||
db.test.insert(docs, manipulate=False)
|
||||
|
||||
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
|
||||
self.assertEqual(11000, db.error()['code'])
|
||||
self.assertEqual(4, db.test.count())
|
||||
self.assertEqual(1, db.test.count())
|
||||
|
||||
db.drop_collection("test")
|
||||
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
|
||||
self.assertEqual(1, db.test.count())
|
||||
docs[0].pop("_id")
|
||||
docs[2]["_id"] = oid
|
||||
with self.assertRaises(DuplicateKeyError):
|
||||
db.test.insert(docs, manipulate=False, continue_on_error=True)
|
||||
|
||||
db.test.insert(docs, manipulate=False, w=0)
|
||||
self.assertEqual(11000, db.error()['code'])
|
||||
self.assertEqual(3, db.test.count())
|
||||
self.assertEqual(4, db.test.count())
|
||||
|
||||
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
|
||||
self.assertEqual(11000, db.error()['code'])
|
||||
self.assertEqual(6, db.test.count())
|
||||
db.drop_collection("test")
|
||||
oid = db.test.insert({"_id": oid, "one": 1}, w=0)
|
||||
self.assertEqual(1, db.test.count())
|
||||
docs[0].pop("_id")
|
||||
docs[2]["_id"] = oid
|
||||
|
||||
with self.assertRaises(DuplicateKeyError):
|
||||
db.test.insert(docs, manipulate=False)
|
||||
|
||||
self.assertEqual(3, db.test.count())
|
||||
db.test.insert(docs, manipulate=False, continue_on_error=True, w=0)
|
||||
self.assertEqual(6, db.test.count())
|
||||
|
||||
def test_error_code(self):
|
||||
try:
|
||||
@ -1062,17 +1034,14 @@ class TestCollection(IntegrationTest):
|
||||
self.assertRaises(DuplicateKeyError,
|
||||
db.test.insert, {"hello": {"a": 4, "b": 10}})
|
||||
|
||||
def test_safe_insert(self):
|
||||
with self.client.start_request():
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
def test_acknowledged_insert(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
|
||||
a = {"hello": "world"}
|
||||
db.test.insert(a)
|
||||
db.test.insert(a, w=0)
|
||||
self.assertTrue("E11000" in db.error()["err"])
|
||||
|
||||
self.assertRaises(OperationFailure, db.test.insert, a)
|
||||
a = {"hello": "world"}
|
||||
db.test.insert(a)
|
||||
db.test.insert(a, w=0)
|
||||
self.assertRaises(OperationFailure, db.test.insert, a)
|
||||
|
||||
def test_update(self):
|
||||
db = self.db
|
||||
@ -1159,36 +1128,25 @@ class TestCollection(IntegrationTest):
|
||||
self.assertEqual(1, db.test.count())
|
||||
self.assertEqual(2, db.test.find_one()["count"])
|
||||
|
||||
def test_safe_update(self):
|
||||
with self.client.start_request():
|
||||
db = self.db
|
||||
v113minus = client_context.version.at_least(1, 1, 3, -1)
|
||||
v19 = client_context.version.at_least(1, 9)
|
||||
def test_acknowledged_update(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("x", unique=True)
|
||||
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("x", unique=True)
|
||||
db.test.insert({"x": 5})
|
||||
id = db.test.insert({"x": 4})
|
||||
|
||||
db.test.insert({"x": 5})
|
||||
id = db.test.insert({"x": 4})
|
||||
self.assertEqual(
|
||||
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
|
||||
|
||||
self.assertEqual(
|
||||
None, db.test.update({"_id": id}, {"$inc": {"x": 1}}, w=0))
|
||||
self.assertRaises(DuplicateKeyError, db.test.update,
|
||||
{"_id": id}, {"$inc": {"x": 1}})
|
||||
|
||||
if v19:
|
||||
self.assertTrue("E11000" in db.error()["err"])
|
||||
elif v113minus:
|
||||
self.assertTrue(db.error()["err"].startswith("E11001"))
|
||||
else:
|
||||
self.assertTrue(db.error()["err"].startswith("E12011"))
|
||||
self.assertEqual(1, db.test.update({"_id": id},
|
||||
{"$inc": {"x": 2}})["n"])
|
||||
|
||||
self.assertRaises(OperationFailure, db.test.update,
|
||||
{"_id": id}, {"$inc": {"x": 1}})
|
||||
|
||||
self.assertEqual(1, db.test.update({"_id": id},
|
||||
{"$inc": {"x": 2}})["n"])
|
||||
|
||||
self.assertEqual(0, db.test.update({"_id": "foo"},
|
||||
{"$inc": {"x": 2}})["n"])
|
||||
self.assertEqual(0, db.test.update({"_id": "foo"},
|
||||
{"$inc": {"x": 2}})["n"])
|
||||
|
||||
def test_update_with_invalid_keys(self):
|
||||
self.db.drop_collection("test")
|
||||
@ -1233,20 +1191,17 @@ class TestCollection(IntegrationTest):
|
||||
self.assertNotEqual(0, self.db.test.update({"hello": "world"},
|
||||
{})['n'])
|
||||
|
||||
def test_safe_save(self):
|
||||
with self.client.start_request():
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello", unique=True)
|
||||
def test_acknowledged_save(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.create_index("hello", unique=True)
|
||||
|
||||
db.test.save({"hello": "world"})
|
||||
db.test.save({"hello": "world"}, w=0)
|
||||
self.assertTrue("E11000" in db.error()["err"])
|
||||
db.test.save({"hello": "world"})
|
||||
db.test.save({"hello": "world"}, w=0)
|
||||
self.assertRaises(DuplicateKeyError, db.test.save,
|
||||
{"hello": "world"})
|
||||
|
||||
self.assertRaises(OperationFailure, db.test.save,
|
||||
{"hello": "world"})
|
||||
|
||||
def test_safe_remove(self):
|
||||
def test_acknowledged_remove(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.create_collection("test", capped=True, size=1000)
|
||||
@ -1254,16 +1209,8 @@ class TestCollection(IntegrationTest):
|
||||
db.test.insert({"x": 1})
|
||||
self.assertEqual(1, db.test.count())
|
||||
|
||||
with db.connection.start_request():
|
||||
self.assertEqual(None, db.test.remove({"x": 1}, w=0))
|
||||
self.assertEqual(1, db.test.count())
|
||||
|
||||
if client_context.version.at_least(1, 1, 3, -1):
|
||||
self.assertRaises(OperationFailure, db.test.remove,
|
||||
{"x": 1})
|
||||
else: # Just test that it doesn't blow up
|
||||
db.test.remove({"x": 1})
|
||||
|
||||
# Can't remove from capped collection.
|
||||
self.assertRaises(OperationFailure, db.test.remove, {"x": 1})
|
||||
db.drop_collection("test")
|
||||
db.test.insert({"x": 1})
|
||||
db.test.insert({"x": 1})
|
||||
@ -1836,16 +1783,16 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
batch[1]['_id'] = batch[0]['_id']
|
||||
|
||||
# Test that inserts fail after first error, acknowledged.
|
||||
# Test that inserts fail after first error.
|
||||
self.db.test.drop()
|
||||
self.assertRaises(DuplicateKeyError, self.db.test.insert, batch, w=1)
|
||||
self.assertRaises(DuplicateKeyError, self.db.test.insert, batch)
|
||||
self.assertEqual(1, self.db.test.count())
|
||||
|
||||
# Test that inserts fail after first error, unacknowledged.
|
||||
# 2 batches, 2 errors, continue on error.
|
||||
self.db.test.drop()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(self.db.test.insert(batch, w=0))
|
||||
self.assertEqual(1, self.db.test.count())
|
||||
self.assertTrue(self.db.test.insert(batch, w=0))
|
||||
wait_until(lambda: 1 == self.db.test.count(),
|
||||
'insert 1 document')
|
||||
|
||||
# 2 batches, 2 errors, acknowledged, continue on error
|
||||
self.db.test.drop()
|
||||
@ -1862,11 +1809,11 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
# 2 batches, 2 errors, unacknowledged, continue on error
|
||||
self.db.test.drop()
|
||||
with self.client.start_request():
|
||||
self.assertTrue(
|
||||
self.db.test.insert(batch, continue_on_error=True, w=0))
|
||||
# Only the first and third documents should be inserted.
|
||||
self.assertEqual(2, self.db.test.count())
|
||||
self.assertTrue(
|
||||
self.db.test.insert(batch, continue_on_error=True, w=0))
|
||||
# Only the first and third documents should be inserted.
|
||||
wait_until(lambda: 2 == self.db.test.count(),
|
||||
'insert 2 documents')
|
||||
|
||||
def test_numerous_inserts(self):
|
||||
# Ensure we don't exceed server's 1000-document batch size limit.
|
||||
|
||||
@ -57,6 +57,7 @@ from test.utils import (
|
||||
ignore_deprecations,
|
||||
remove_all_users,
|
||||
rs_or_single_client_noauth,
|
||||
rs_or_single_client,
|
||||
server_started_with_auth)
|
||||
|
||||
|
||||
@ -285,38 +286,37 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
@client_context.require_no_mongos
|
||||
def test_errors(self):
|
||||
db = self.client.pymongo_test
|
||||
# We must call getlasterror, etc. on same socket as the last operation.
|
||||
db = rs_or_single_client(max_pool_size=1).pymongo_test
|
||||
db.reset_error_history()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertEqual(None, db.previous_error())
|
||||
|
||||
with self.client.start_request():
|
||||
db.reset_error_history()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertEqual(None, db.previous_error())
|
||||
db.command("forceerror", check=False)
|
||||
self.assertTrue(db.error())
|
||||
self.assertTrue(db.previous_error())
|
||||
|
||||
db.command("forceerror", check=False)
|
||||
self.assertTrue(db.error())
|
||||
self.assertTrue(db.previous_error())
|
||||
db.command("forceerror", check=False)
|
||||
self.assertTrue(db.error())
|
||||
prev_error = db.previous_error()
|
||||
self.assertEqual(prev_error["nPrev"], 1)
|
||||
del prev_error["nPrev"]
|
||||
prev_error.pop("lastOp", None)
|
||||
error = db.error()
|
||||
error.pop("lastOp", None)
|
||||
# getLastError includes "connectionId" in recent
|
||||
# server versions, getPrevError does not.
|
||||
error.pop("connectionId", None)
|
||||
self.assertEqual(error, prev_error)
|
||||
|
||||
db.command("forceerror", check=False)
|
||||
self.assertTrue(db.error())
|
||||
prev_error = db.previous_error()
|
||||
self.assertEqual(prev_error["nPrev"], 1)
|
||||
del prev_error["nPrev"]
|
||||
prev_error.pop("lastOp", None)
|
||||
error = db.error()
|
||||
error.pop("lastOp", None)
|
||||
# getLastError includes "connectionId" in recent
|
||||
# server versions, getPrevError does not.
|
||||
error.pop("connectionId", None)
|
||||
self.assertEqual(error, prev_error)
|
||||
db.test.find_one()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertTrue(db.previous_error())
|
||||
self.assertEqual(db.previous_error()["nPrev"], 2)
|
||||
|
||||
db.test.find_one()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertTrue(db.previous_error())
|
||||
self.assertEqual(db.previous_error()["nPrev"], 2)
|
||||
|
||||
db.reset_error_history()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertEqual(None, db.previous_error())
|
||||
db.reset_error_history()
|
||||
self.assertEqual(None, db.error())
|
||||
self.assertEqual(None, db.previous_error())
|
||||
|
||||
def test_command(self):
|
||||
db = self.client.admin
|
||||
@ -338,17 +338,16 @@ class TestDatabase(IntegrationTest):
|
||||
self.assertTrue(isinstance(result['result'][0]['r'], Regex))
|
||||
|
||||
def test_last_status(self):
|
||||
db = self.client.pymongo_test
|
||||
# We must call getlasterror on the same socket as the last operation.
|
||||
db = rs_or_single_client(max_pool_size=1).pymongo_test
|
||||
db.test.remove({})
|
||||
db.test.save({"i": 1})
|
||||
|
||||
with self.client.start_request():
|
||||
db.test.remove({})
|
||||
db.test.save({"i": 1})
|
||||
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
|
||||
self.assertTrue(db.last_status()["updatedExisting"])
|
||||
|
||||
db.test.update({"i": 1}, {"$set": {"i": 2}}, w=0)
|
||||
self.assertTrue(db.last_status()["updatedExisting"])
|
||||
|
||||
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
|
||||
self.assertFalse(db.last_status()["updatedExisting"])
|
||||
db.test.update({"i": 1}, {"$set": {"i": 500}}, w=0)
|
||||
self.assertFalse(db.last_status()["updatedExisting"])
|
||||
|
||||
def test_password_digest(self):
|
||||
self.assertRaises(TypeError, auth._password_digest, 5)
|
||||
|
||||
@ -369,28 +369,6 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
self.assertTrue(iterate_file(f))
|
||||
|
||||
def test_request(self):
|
||||
c = self.db.connection
|
||||
c.start_request()
|
||||
n = 5
|
||||
for i in range(n):
|
||||
file = self.fs.new_file(filename="test")
|
||||
file.write(b"hello")
|
||||
file.close()
|
||||
|
||||
c.end_request()
|
||||
|
||||
self.assertEqual(
|
||||
n,
|
||||
self.db.fs.files.find({'filename':'test'}).count()
|
||||
)
|
||||
|
||||
def test_gridfs_request(self):
|
||||
self.assertFalse(self.db.connection.in_request())
|
||||
self.fs.put(b"hello world")
|
||||
# Request started and ended by put(), we're back to original state
|
||||
self.assertFalse(self.db.connection.in_request())
|
||||
|
||||
def test_gridfs_lazy_connect(self):
|
||||
with client_knobs(server_wait_time=0.01):
|
||||
client = MongoClient('badhost', connect=False)
|
||||
|
||||
@ -22,20 +22,17 @@ import threading
|
||||
import time
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConfigurationError, ConnectionFailure, \
|
||||
ExceededMaxWaiters
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
ConnectionFailure,
|
||||
DuplicateKeyError,
|
||||
ExceededMaxWaiters)
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.py3compat import thread
|
||||
from pymongo.pool import (NO_REQUEST,
|
||||
NO_SOCKET_YET,
|
||||
Pool,
|
||||
PoolOptions,
|
||||
SocketInfo, _closed)
|
||||
from pymongo.pool import Pool, PoolOptions, _closed
|
||||
from test import host, port, SkipTest, unittest, client_context
|
||||
from test.utils import get_pool, one, rs_or_single_client
|
||||
from test.utils import get_pool, one, rs_or_single_client, connected, joinall
|
||||
|
||||
|
||||
@client_context.require_connection
|
||||
@ -86,18 +83,18 @@ class SaveAndFind(MongoThread):
|
||||
class Unique(MongoThread):
|
||||
def run_mongo_thread(self):
|
||||
for _ in range(N):
|
||||
self.client.start_request()
|
||||
self.db.unique.insert({}) # no error
|
||||
self.client.end_request()
|
||||
|
||||
|
||||
class NonUnique(MongoThread):
|
||||
def run_mongo_thread(self):
|
||||
for _ in range(N):
|
||||
self.client.start_request()
|
||||
self.db.unique.insert({"_id": "jesse"}, w=0)
|
||||
assert self.db.error() is not None
|
||||
self.client.end_request()
|
||||
try:
|
||||
self.db.unique.insert({"_id": "jesse"})
|
||||
except DuplicateKeyError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Should have raised DuplicateKeyError")
|
||||
|
||||
|
||||
class Disconnect(MongoThread):
|
||||
@ -106,19 +103,6 @@ class Disconnect(MongoThread):
|
||||
self.client.disconnect()
|
||||
|
||||
|
||||
class NoRequest(MongoThread):
|
||||
def run_mongo_thread(self):
|
||||
self.client.start_request()
|
||||
errors = 0
|
||||
for _ in range(N):
|
||||
self.db.unique.insert({"_id": "jesse"}, w=0)
|
||||
if not self.db.error():
|
||||
errors += 1
|
||||
|
||||
self.client.end_request()
|
||||
assert errors == 0
|
||||
|
||||
|
||||
class SocketGetter(MongoThread):
|
||||
"""Utility for _TestMaxOpenSockets and _TestWaitQueueMultiple"""
|
||||
def __init__(self, client, pool):
|
||||
@ -150,73 +134,6 @@ def run_cases(client, cases):
|
||||
assert t.passed, "%s.run() threw an exception" % repr(t)
|
||||
|
||||
|
||||
class CreateAndReleaseSocket(MongoThread):
|
||||
"""Gets a socket, waits for all threads to reach rendezvous, and quits."""
|
||||
class Rendezvous(object):
|
||||
def __init__(self, nthreads):
|
||||
self.nthreads = nthreads
|
||||
self.nthreads_run = 0
|
||||
self.lock = threading.Lock()
|
||||
self.reset_ready()
|
||||
|
||||
def reset_ready(self):
|
||||
self.ready = threading.Event()
|
||||
|
||||
def __init__(self, client, start_request, end_request, rendezvous):
|
||||
super(CreateAndReleaseSocket, self).__init__(client)
|
||||
self.start_request = start_request
|
||||
self.end_request = end_request
|
||||
self.rendezvous = rendezvous
|
||||
|
||||
def run_mongo_thread(self):
|
||||
# Do an operation that requires a socket.
|
||||
# TestPoolMaxSize uses this to spin up lots of threads requiring
|
||||
# lots of simultaneous connections, to ensure that Pool obeys its
|
||||
# max_size configuration and closes extra sockets as they're returned.
|
||||
for i in range(self.start_request):
|
||||
self.client.start_request()
|
||||
|
||||
# Use a socket
|
||||
self.client[DB].test.find_one()
|
||||
|
||||
# Don't finish until all threads reach this point
|
||||
r = self.rendezvous
|
||||
r.lock.acquire()
|
||||
r.nthreads_run += 1
|
||||
if r.nthreads_run == r.nthreads:
|
||||
# Everyone's here, let them finish.
|
||||
r.ready.set()
|
||||
r.lock.release()
|
||||
else:
|
||||
r.lock.release()
|
||||
r.ready.wait(30) # Wait thirty seconds....
|
||||
assert r.ready.isSet(), "Rendezvous timed out"
|
||||
|
||||
for i in range(self.end_request):
|
||||
self.client.end_request()
|
||||
|
||||
|
||||
class CreateAndReleaseSocketNoRendezvous(MongoThread):
|
||||
"""Gets a socket and quits. No synchronization with other threads."""
|
||||
def __init__(self, client, start_request, end_request):
|
||||
super(CreateAndReleaseSocketNoRendezvous, self).__init__(client)
|
||||
self.start_request = start_request
|
||||
self.end_request = end_request
|
||||
|
||||
def run_mongo_thread(self):
|
||||
# Do an operation that requires a socket.
|
||||
# TestPoolMaxSize uses this to spin up lots of threads requiring
|
||||
# lots of simultaneous connections, to ensure that Pool obeys its
|
||||
# max_size configuration and closes extra sockets as they're returned.
|
||||
for i in range(self.start_request):
|
||||
self.client.start_request()
|
||||
|
||||
# Use a socket
|
||||
self.client[DB].test.find_one()
|
||||
for i in range(self.end_request):
|
||||
self.client.end_request()
|
||||
|
||||
|
||||
class _TestPoolingBase(unittest.TestCase):
|
||||
"""Base class for all connection-pool tests."""
|
||||
|
||||
@ -231,40 +148,6 @@ class _TestPoolingBase(unittest.TestCase):
|
||||
def create_pool(self, pair=(host, port), *args, **kwargs):
|
||||
return Pool(pair, PoolOptions(*args, **kwargs))
|
||||
|
||||
def assert_no_request(self):
|
||||
try:
|
||||
server = self.c._topology.select_server(
|
||||
writable_server_selector,
|
||||
server_wait_time=0)
|
||||
|
||||
self.assertEqual(NO_REQUEST, server.pool._get_request_state())
|
||||
except ConnectionFailure:
|
||||
# Success: we're asserting that we're not in a request, but there's
|
||||
# no pool at all so the assertion is true.
|
||||
pass
|
||||
|
||||
def assert_request_without_socket(self):
|
||||
self.assertEqual(NO_SOCKET_YET, get_pool(self.c)._get_request_state())
|
||||
|
||||
def assert_request_with_socket(self):
|
||||
self.assertTrue(isinstance(
|
||||
get_pool(self.c)._get_request_state(), SocketInfo))
|
||||
|
||||
def assert_pool_size(self, pool_size):
|
||||
if pool_size == 0:
|
||||
try:
|
||||
server = self.c._topology.select_server(
|
||||
writable_server_selector,
|
||||
server_wait_time=0)
|
||||
|
||||
self.assertEqual(0, len(server.pool.sockets))
|
||||
except ConnectionFailure:
|
||||
# Success: we're asserting that pool size is 0, and there's no
|
||||
# pool at all so the assertion is true.
|
||||
pass
|
||||
else:
|
||||
self.assertEqual(pool_size, len(get_pool(self.c).sockets))
|
||||
|
||||
|
||||
class TestPooling(_TestPoolingBase):
|
||||
def test_max_pool_size_validation(self):
|
||||
@ -280,95 +163,11 @@ class TestPooling(_TestPoolingBase):
|
||||
self.assertEqual(c.max_pool_size, 100)
|
||||
|
||||
def test_no_disconnect(self):
|
||||
run_cases(self.c, [NoRequest, NonUnique, Unique, SaveAndFind])
|
||||
|
||||
def test_simple_disconnect(self):
|
||||
# MongoClient just created, expect 1 free socket.
|
||||
self.assert_pool_size(1)
|
||||
self.assert_no_request()
|
||||
|
||||
self.c.start_request()
|
||||
self.assert_request_without_socket()
|
||||
cursor = self.c[DB].stuff.find()
|
||||
|
||||
# Cursor hasn't actually caused a request yet, so there's still 1 free
|
||||
# socket.
|
||||
self.assert_pool_size(1)
|
||||
self.assert_request_without_socket()
|
||||
|
||||
# Actually make a request to server, triggering a socket to be
|
||||
# allocated to the request.
|
||||
list(cursor)
|
||||
self.assert_pool_size(0)
|
||||
self.assert_request_with_socket()
|
||||
|
||||
# Pool returns to its original state
|
||||
self.c.end_request()
|
||||
self.assert_no_request()
|
||||
self.assert_pool_size(1)
|
||||
|
||||
self.c.disconnect()
|
||||
self.assert_pool_size(0)
|
||||
self.assert_no_request()
|
||||
run_cases(self.c, [NonUnique, Unique, SaveAndFind])
|
||||
|
||||
def test_disconnect(self):
|
||||
run_cases(self.c, [SaveAndFind, Disconnect, Unique])
|
||||
|
||||
def test_request(self):
|
||||
# Check that Pool gives two different sockets in two calls to
|
||||
# get_socket().
|
||||
cx_pool = self.create_pool(
|
||||
pair=(host, port),
|
||||
max_pool_size=10,
|
||||
connect_timeout=1000,
|
||||
socket_timeout=1000)
|
||||
|
||||
sock0 = cx_pool.get_socket({}, 0, 0)
|
||||
sock1 = cx_pool.get_socket({}, 0, 0)
|
||||
|
||||
self.assertNotEqual(sock0, sock1)
|
||||
|
||||
# Now in a request, we'll get the same socket both times
|
||||
cx_pool.start_request()
|
||||
|
||||
sock2 = cx_pool.get_socket({}, 0, 0)
|
||||
sock3 = cx_pool.get_socket({}, 0, 0)
|
||||
self.assertEqual(sock2, sock3)
|
||||
|
||||
# Pool didn't keep reference to sock0 or sock1; sock2 and 3 are new
|
||||
self.assertNotEqual(sock0, sock2)
|
||||
self.assertNotEqual(sock1, sock2)
|
||||
|
||||
# Return the request sock to pool
|
||||
cx_pool.end_request()
|
||||
|
||||
sock4 = cx_pool.get_socket({}, 0, 0)
|
||||
sock5 = cx_pool.get_socket({}, 0, 0)
|
||||
|
||||
# Not in a request any more, we get different sockets
|
||||
self.assertNotEqual(sock4, sock5)
|
||||
|
||||
# end_request() returned sock2 to pool
|
||||
self.assertEqual(sock4, sock2)
|
||||
|
||||
for s in [sock0, sock1, sock2, sock3, sock4, sock5]:
|
||||
s.close()
|
||||
|
||||
def test_reset_and_request(self):
|
||||
# reset() is called after a fork, or after a socket error. Ensure that
|
||||
# a new request is begun if a request was in progress when the reset()
|
||||
# occurred, otherwise no request is begun.
|
||||
p = self.create_pool(max_pool_size=10)
|
||||
self.assertFalse(p.in_request())
|
||||
p.start_request()
|
||||
self.assertTrue(p.in_request())
|
||||
p.reset()
|
||||
self.assertTrue(p.in_request())
|
||||
p.end_request()
|
||||
self.assertFalse(p.in_request())
|
||||
p.reset()
|
||||
self.assertFalse(p.in_request())
|
||||
|
||||
def test_pool_reuses_open_socket(self):
|
||||
# Test Pool's _check_closed() method doesn't close a healthy socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
@ -399,236 +198,6 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
def test_pool_removes_dead_request_socket_after_check(self):
|
||||
# Test that Pool keeps request going even if a socket dies in request.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket.
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
sock_info.sock.close()
|
||||
|
||||
# Although the request socket died, we're still in a request with a
|
||||
# new socket.
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertTrue(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
|
||||
cx_pool.end_request()
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
def test_pool_removes_dead_request_socket(self):
|
||||
# Test that Pool keeps request going even if a socket dies in request.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
|
||||
# Unlike in test_pool_removes_dead_request_socket_after_check, we
|
||||
# set sock_info.closed and *don't* wait for it to be checked.
|
||||
sock_info.close()
|
||||
|
||||
# Although the request socket died, we're still in a request with a
|
||||
# new socket
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertTrue(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
|
||||
cx_pool.end_request()
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
def test_pool_removes_dead_socket_after_request(self):
|
||||
# Test that Pool handles a socket dying that *used* to be the request
|
||||
# socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket.
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
|
||||
# End request.
|
||||
cx_pool.end_request()
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
# Kill old request socket.
|
||||
sock_info.sock.close()
|
||||
|
||||
# Dead socket detected and removed.
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertFalse(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertFalse(_closed(new_sock_info.sock))
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
def test_dead_request_socket_with_max_size(self):
|
||||
# When a pool replaces a dead request socket, the semaphore it uses
|
||||
# to enforce max_size should remain unaffected.
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get and close the request socket.
|
||||
with cx_pool.get_socket({}, 0, 0) as request_sock_info:
|
||||
request_sock_info.sock.close()
|
||||
|
||||
# Detects closed socket and creates new one, semaphore value still 0.
|
||||
with cx_pool.get_socket({}, 0, 0) as request_sock_info_2:
|
||||
self.assertNotEqual(request_sock_info, request_sock_info_2)
|
||||
|
||||
cx_pool.end_request()
|
||||
|
||||
# Semaphore value now 1; we can get a socket.
|
||||
sock = cx_pool.get_socket({}, 0, 0)
|
||||
sock.close()
|
||||
|
||||
def test_socket_reclamation(self):
|
||||
if sys.platform.startswith('java'):
|
||||
raise SkipTest("Jython can't do socket reclamation")
|
||||
|
||||
# Check that if a thread starts a request and dies without ending
|
||||
# the request, that the socket is reclaimed into the pool.
|
||||
cx_pool = self.create_pool(
|
||||
max_pool_size=10,
|
||||
connect_timeout=1000,
|
||||
socket_timeout=1000)
|
||||
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
|
||||
lock = None
|
||||
the_sock = [None]
|
||||
|
||||
def leak_request():
|
||||
self.assertEqual(NO_REQUEST, cx_pool._get_request_state())
|
||||
cx_pool.start_request()
|
||||
self.assertEqual(NO_SOCKET_YET, cx_pool._get_request_state())
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
the_sock[0] = id(sock_info.sock)
|
||||
|
||||
lock.release()
|
||||
|
||||
lock = thread.allocate_lock()
|
||||
lock.acquire()
|
||||
|
||||
# Start a thread WITHOUT a threading.Thread - important to test that
|
||||
# Pool can deal with primitive threads.
|
||||
thread.start_new_thread(leak_request, ())
|
||||
|
||||
# Join thread
|
||||
acquired = lock.acquire()
|
||||
self.assertTrue(acquired, "Thread is hung")
|
||||
|
||||
# Make sure thread is really gone
|
||||
time.sleep(1)
|
||||
|
||||
if 'PyPy' in sys.version:
|
||||
gc.collect()
|
||||
|
||||
# Access the thread local from the main thread to trigger the
|
||||
# ThreadVigil's delete callback, returning the request socket to
|
||||
# the pool.
|
||||
# In Python 2.7.0 and lesser, a dead thread's locals are deleted
|
||||
# and those locals' weakref callbacks are fired only when another
|
||||
# thread accesses the locals and finds the thread state is stale,
|
||||
# see http://bugs.python.org/issue1868. Accessing the thread
|
||||
# local from the main thread is a necessary part of this test, and
|
||||
# realistic: in a multithreaded web server a new thread will access
|
||||
# Pool._ident._local soon after an old thread has died.
|
||||
cx_pool._ident.get()
|
||||
|
||||
# Pool reclaimed the socket
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
self.assertEqual(the_sock[0], id(one(cx_pool.sockets).sock))
|
||||
self.assertEqual(0, len(cx_pool._tid_to_sock))
|
||||
|
||||
def test_request_with_fork(self):
|
||||
if sys.platform == "win32":
|
||||
raise SkipTest("Can't test forking on Windows")
|
||||
|
||||
try:
|
||||
from multiprocessing import Process, Pipe
|
||||
except ImportError:
|
||||
raise SkipTest("No multiprocessing module")
|
||||
|
||||
coll = self.c.pymongo_test.test
|
||||
coll.remove()
|
||||
coll.insert({'_id': 1})
|
||||
coll.find_one()
|
||||
self.assert_pool_size(1)
|
||||
self.c.start_request()
|
||||
self.assert_pool_size(1)
|
||||
coll.find_one()
|
||||
self.assert_pool_size(0)
|
||||
self.assert_request_with_socket()
|
||||
|
||||
def f(pipe):
|
||||
# We can still query server without error
|
||||
self.assertEqual({'_id':1}, coll.find_one())
|
||||
|
||||
# Pool has detected that we forked, but resumed the request
|
||||
self.assert_request_with_socket()
|
||||
self.assert_pool_size(0)
|
||||
pipe.send("success")
|
||||
|
||||
parent_conn, child_conn = Pipe()
|
||||
p = Process(target=f, args=(child_conn,))
|
||||
p.start()
|
||||
p.join(30)
|
||||
p.terminate()
|
||||
child_conn.close()
|
||||
self.assertEqual("success", parent_conn.recv())
|
||||
|
||||
def test_primitive_thread(self):
|
||||
p = self.create_pool(max_pool_size=10)
|
||||
|
||||
# Test that start/end_request work with a thread begun from thread
|
||||
# module, rather than threading module
|
||||
lock = thread.allocate_lock()
|
||||
lock.acquire()
|
||||
|
||||
sock_ids = []
|
||||
|
||||
def run_in_request():
|
||||
p.start_request()
|
||||
sock0 = p.get_socket({}, 0, 0)
|
||||
sock1 = p.get_socket({}, 0, 0)
|
||||
sock_ids.extend([id(sock0), id(sock1)])
|
||||
p.maybe_return_socket(sock0)
|
||||
p.maybe_return_socket(sock1)
|
||||
p.end_request()
|
||||
lock.release()
|
||||
|
||||
thread.start_new_thread(run_in_request, ())
|
||||
|
||||
# Join thread
|
||||
acquired = False
|
||||
for i in range(30):
|
||||
time.sleep(0.5)
|
||||
acquired = lock.acquire(0)
|
||||
if acquired:
|
||||
break
|
||||
|
||||
self.assertTrue(acquired, "Thread is hung")
|
||||
self.assertEqual(sock_ids[0], sock_ids[1])
|
||||
|
||||
def test_pool_with_fork(self):
|
||||
# Test that separate MongoClients have separate Pools, and that the
|
||||
# driver can create a new MongoClient after forking
|
||||
@ -770,204 +339,61 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
|
||||
class TestPoolMaxSize(_TestPoolingBase):
|
||||
"""Keep right number of sockets in various start/end_request sequences.
|
||||
"""
|
||||
def _test_max_pool_size(
|
||||
self, start_request, end_request, max_pool_size=4, nthreads=10):
|
||||
"""Start `nthreads` threads. Each calls start_request `start_request`
|
||||
times, then find_one and waits at a barrier; once all reach the barrier
|
||||
each calls end_request `end_request` times. The test asserts that the
|
||||
pool ends with min(max_pool_size, nthreads) sockets or, if
|
||||
start_request wasn't called, at least one socket.
|
||||
|
||||
This tests both max_pool_size enforcement and that leaked request
|
||||
sockets are eventually returned to the pool when their threads end.
|
||||
|
||||
You may need to increase ulimit -n on Mac.
|
||||
"""
|
||||
if start_request:
|
||||
if max_pool_size is not None and max_pool_size < nthreads:
|
||||
raise AssertionError("Deadlock")
|
||||
|
||||
c = rs_or_single_client(max_pool_size=max_pool_size)
|
||||
rendezvous = CreateAndReleaseSocket.Rendezvous(nthreads)
|
||||
|
||||
threads = []
|
||||
for i in range(nthreads):
|
||||
t = CreateAndReleaseSocket(
|
||||
c, start_request, end_request, rendezvous)
|
||||
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
if 'PyPy' in sys.version:
|
||||
# With PyPy we need to kick off the gc whenever the threads hit the
|
||||
# rendezvous since nthreads > max_pool_size.
|
||||
gc_collect_until_done(threads)
|
||||
else:
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# join() returns before the thread state is cleared; give it time.
|
||||
time.sleep(1)
|
||||
|
||||
for t in threads:
|
||||
self.assertTrue(t.passed)
|
||||
|
||||
# Socket-reclamation doesn't work in Jython
|
||||
if not sys.platform.startswith('java'):
|
||||
cx_pool = get_pool(c)
|
||||
|
||||
# Socket-reclamation depends on timely garbage-collection
|
||||
if 'PyPy' in sys.version:
|
||||
gc.collect()
|
||||
|
||||
if start_request:
|
||||
# Trigger final cleanup in Python <= 2.7.0.
|
||||
cx_pool._ident.get()
|
||||
expected_idle = min(max_pool_size, nthreads)
|
||||
message = (
|
||||
'%d idle sockets (expected %d) and %d request sockets'
|
||||
' (expected 0)' % (
|
||||
len(cx_pool.sockets), expected_idle,
|
||||
len(cx_pool._tid_to_sock)))
|
||||
|
||||
self.assertEqual(
|
||||
expected_idle, len(cx_pool.sockets), message)
|
||||
else:
|
||||
# Without calling start_request(), threads can safely share
|
||||
# sockets; the number running concurrently, and hence the
|
||||
# number of sockets needed, is between 1 and 10, depending
|
||||
# on thread-scheduling.
|
||||
self.assertTrue(len(cx_pool.sockets) >= 1)
|
||||
|
||||
# thread.join completes slightly *before* thread locals are
|
||||
# cleaned up, so wait up to 5 seconds for them.
|
||||
time.sleep(0.1)
|
||||
cx_pool._ident.get()
|
||||
start = time.time()
|
||||
|
||||
while (
|
||||
not cx_pool.sockets
|
||||
and cx_pool._socket_semaphore.counter < max_pool_size
|
||||
and (time.time() - start) < 5
|
||||
):
|
||||
time.sleep(0.1)
|
||||
cx_pool._ident.get()
|
||||
|
||||
if max_pool_size is not None:
|
||||
self.assertEqual(
|
||||
max_pool_size,
|
||||
cx_pool._socket_semaphore.counter)
|
||||
|
||||
self.assertEqual(0, len(cx_pool._tid_to_sock))
|
||||
|
||||
def _test_max_pool_size_no_rendezvous(self, start_request, end_request):
|
||||
max_pool_size = 5
|
||||
c = rs_or_single_client(max_pool_size=max_pool_size)
|
||||
def test_max_pool_size(self):
|
||||
max_pool_size = 4
|
||||
c = connected(rs_or_single_client(max_pool_size=max_pool_size))
|
||||
cx_pool = get_pool(c)
|
||||
|
||||
# nthreads had better be much larger than max_pool_size to ensure that
|
||||
# max_pool_size sockets are actually required at some point in this
|
||||
# test's execution.
|
||||
nthreads = 10
|
||||
|
||||
if (sys.platform.startswith('java')
|
||||
and start_request > end_request
|
||||
and nthreads > max_pool_size):
|
||||
|
||||
# Since Jython can't reclaim the socket and release the semaphore
|
||||
# after a thread leaks a request, we'll exhaust the semaphore and
|
||||
# deadlock.
|
||||
raise SkipTest("Jython can't do socket reclamation")
|
||||
|
||||
threads = []
|
||||
for i in range(nthreads):
|
||||
t = CreateAndReleaseSocketNoRendezvous(
|
||||
c, start_request, end_request)
|
||||
|
||||
threads.append(t)
|
||||
lock = threading.Lock()
|
||||
self.n_passed = 0
|
||||
|
||||
for t in threads:
|
||||
def f():
|
||||
for _ in range(N):
|
||||
c[DB].test.find_one()
|
||||
assert len(cx_pool.sockets) <= max_pool_size
|
||||
|
||||
with lock:
|
||||
self.n_passed += 1
|
||||
|
||||
for i in range(nthreads):
|
||||
t = threading.Thread(target=f)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
if 'PyPy' in sys.version:
|
||||
# With PyPy we need to kick off the gc whenever the threads hit the
|
||||
# rendezvous since nthreads > max_pool_size.
|
||||
gc_collect_until_done(threads)
|
||||
else:
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for t in threads:
|
||||
self.assertTrue(t.passed)
|
||||
|
||||
cx_pool = get_pool(c)
|
||||
|
||||
# Socket-reclamation depends on timely garbage-collection
|
||||
if 'PyPy' in sys.version:
|
||||
gc.collect()
|
||||
|
||||
# thread.join completes slightly *before* thread locals are
|
||||
# cleaned up, so wait up to 5 seconds for them.
|
||||
time.sleep(0.1)
|
||||
cx_pool._ident.get()
|
||||
start = time.time()
|
||||
|
||||
while (
|
||||
not cx_pool.sockets
|
||||
and cx_pool._socket_semaphore.counter < max_pool_size
|
||||
and (time.time() - start) < 5
|
||||
):
|
||||
time.sleep(0.1)
|
||||
cx_pool._ident.get()
|
||||
|
||||
self.assertTrue(len(cx_pool.sockets) >= 1)
|
||||
joinall(threads)
|
||||
self.assertEqual(nthreads, self.n_passed)
|
||||
self.assertTrue(len(cx_pool.sockets) > 1)
|
||||
self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter)
|
||||
|
||||
def test_max_pool_size(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=0, end_request=0, nthreads=10, max_pool_size=4)
|
||||
|
||||
def test_max_pool_size_none(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=0, end_request=0, nthreads=10, max_pool_size=None)
|
||||
c = connected(rs_or_single_client(max_pool_size=None))
|
||||
cx_pool = get_pool(c)
|
||||
|
||||
def test_max_pool_size_with_request(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=1, end_request=1, nthreads=10, max_pool_size=10)
|
||||
nthreads = 10
|
||||
threads = []
|
||||
lock = threading.Lock()
|
||||
self.n_passed = 0
|
||||
|
||||
def test_max_pool_size_with_multiple_request(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=10, end_request=10, nthreads=10, max_pool_size=10)
|
||||
def f():
|
||||
for _ in range(N):
|
||||
c[DB].test.find_one()
|
||||
|
||||
def test_max_pool_size_with_redundant_request(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=2, end_request=1, nthreads=10, max_pool_size=10)
|
||||
with lock:
|
||||
self.n_passed += 1
|
||||
|
||||
def test_max_pool_size_with_redundant_request2(self):
|
||||
self._test_max_pool_size(
|
||||
start_request=20, end_request=1, nthreads=10, max_pool_size=10)
|
||||
for i in range(nthreads):
|
||||
t = threading.Thread(target=f)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
def test_max_pool_size_with_redundant_request_no_rendezvous(self):
|
||||
self._test_max_pool_size_no_rendezvous(2, 1)
|
||||
|
||||
def test_max_pool_size_with_redundant_request_no_rendezvous2(self):
|
||||
self._test_max_pool_size_no_rendezvous(20, 1)
|
||||
|
||||
def test_max_pool_size_with_leaked_request(self):
|
||||
# Call start_request() but not end_request() -- when threads die, they
|
||||
# should return their request sockets to the pool.
|
||||
self._test_max_pool_size(
|
||||
start_request=1, end_request=0, nthreads=10, max_pool_size=10)
|
||||
|
||||
def test_max_pool_size_with_leaked_request_no_rendezvous(self):
|
||||
self._test_max_pool_size_no_rendezvous(1, 0)
|
||||
|
||||
def test_max_pool_size_with_end_request_only(self):
|
||||
# Call end_request() but not start_request()
|
||||
self._test_max_pool_size(0, 1)
|
||||
joinall(threads)
|
||||
self.assertEqual(nthreads, self.n_passed)
|
||||
self.assertTrue(len(cx_pool.sockets) >= 1)
|
||||
|
||||
def test_max_pool_size_with_connection_failure(self):
|
||||
# The pool acquires its semaphore before attempting to connect; ensure
|
||||
|
||||
@ -38,17 +38,19 @@ from test import (client_context,
|
||||
IntegrationTest,
|
||||
pair,
|
||||
port,
|
||||
SkipTest,
|
||||
unittest,
|
||||
db_pwd,
|
||||
db_user,
|
||||
MockClientTest)
|
||||
from test.pymongo_mocks import MockClient
|
||||
from test.utils import (
|
||||
delay, assertReadFrom, assertReadFromAll, ignore_deprecations,
|
||||
read_from_which_host, assertRaisesExactly, TestRequestMixin, get_pools,
|
||||
connected, wait_until, single_client, rs_or_single_client, one,
|
||||
rs_client)
|
||||
from test.utils import (assertRaisesExactly,
|
||||
connected,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
one,
|
||||
rs_client,
|
||||
single_client,
|
||||
wait_until)
|
||||
from test.version import Version
|
||||
|
||||
|
||||
@ -79,7 +81,7 @@ class TestReplicaSetClientBase(IntegrationTest):
|
||||
if m['stateStr'] == 'SECONDARY')
|
||||
|
||||
|
||||
class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
|
||||
class TestReplicaSetClient(TestReplicaSetClientBase):
|
||||
def test_deprecated(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", DeprecationWarning)
|
||||
@ -288,88 +290,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
|
||||
def test_kill_cursor_explicit_secondary(self):
|
||||
self._test_kill_cursor_explicit(ReadPreference.SECONDARY)
|
||||
|
||||
def test_nested_request(self):
|
||||
client = rs_or_single_client()
|
||||
connected(client)
|
||||
client.start_request()
|
||||
try:
|
||||
pools = get_pools(client)
|
||||
self.assertTrue(client.in_request())
|
||||
|
||||
# Start and end request - we're still in "outer" original request
|
||||
client.start_request()
|
||||
self.assertInRequestAndSameSock(client, pools)
|
||||
client.end_request()
|
||||
self.assertInRequestAndSameSock(client, pools)
|
||||
|
||||
# Double-nesting
|
||||
client.start_request()
|
||||
client.start_request()
|
||||
|
||||
for pool in pools:
|
||||
# Client only called start_request() once per pool, although
|
||||
# its own counter is 3.
|
||||
self.assertEqual(1, pool._request_counter.get())
|
||||
|
||||
client.end_request()
|
||||
client.end_request()
|
||||
self.assertInRequestAndSameSock(client, pools)
|
||||
|
||||
for pool in pools:
|
||||
self.assertEqual(1, pool._request_counter.get())
|
||||
|
||||
# Finally, end original request
|
||||
client.end_request()
|
||||
for pool in pools:
|
||||
self.assertFalse(pool.in_request())
|
||||
|
||||
self.assertNotInRequestAndDifferentSock(client, pools)
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
def test_pinned_member(self):
|
||||
raise SkipTest("Secondary pinning not implemented in PyMongo 3")
|
||||
|
||||
latency = 1000 * 1000
|
||||
client = rs_client(secondaryacceptablelatencyms=latency)
|
||||
|
||||
host = read_from_which_host(client, ReadPreference.SECONDARY)
|
||||
self.assertTrue(host in client.secondaries)
|
||||
|
||||
# No pinning since we're not in a request
|
||||
assertReadFromAll(
|
||||
self, client, client.secondaries,
|
||||
ReadPreference.SECONDARY, None)
|
||||
|
||||
assertReadFromAll(
|
||||
self, client, list(client.secondaries) + [client.primary],
|
||||
ReadPreference.NEAREST, None)
|
||||
|
||||
client.start_request()
|
||||
host = read_from_which_host(client, ReadPreference.SECONDARY)
|
||||
self.assertTrue(host in client.secondaries)
|
||||
assertReadFrom(self, client, host, ReadPreference.SECONDARY)
|
||||
|
||||
# Changing any part of read preference (mode, tag_sets)
|
||||
# unpins the current host and pins to a new one
|
||||
primary = client.primary
|
||||
assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED)
|
||||
|
||||
host = read_from_which_host(client, ReadPreference.NEAREST)
|
||||
assertReadFrom(self, client, host, ReadPreference.NEAREST)
|
||||
|
||||
assertReadFrom(self, client, primary, ReadPreference.PRIMARY_PREFERRED)
|
||||
|
||||
host = read_from_which_host(client, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertTrue(host in client.secondaries)
|
||||
assertReadFrom(self, client, host, ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# Unpin
|
||||
client.end_request()
|
||||
assertReadFromAll(
|
||||
self, client, list(client.secondaries) + [client.primary],
|
||||
ReadPreference.NEAREST, None)
|
||||
|
||||
def test_not_master_error(self):
|
||||
secondary_address = one(self.secondaries)
|
||||
direct_client = single_client(*secondary_address)
|
||||
|
||||
@ -1,158 +0,0 @@
|
||||
# Copyright 2012-2014 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the thread_util module."""
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo import thread_util
|
||||
from test import SkipTest, unittest
|
||||
from test.utils import RendezvousThread
|
||||
|
||||
|
||||
class TestIdent(unittest.TestCase):
|
||||
def test_thread_ident(self):
|
||||
# 1. Store main thread's id.
|
||||
# 2. Start 2 child threads.
|
||||
# 3. Store their values for Ident.get().
|
||||
# 4. Children reach rendezvous point.
|
||||
# 5. Children call Ident.watch().
|
||||
# 6. One of the children calls Ident.unwatch().
|
||||
# 7. Children terminate.
|
||||
# 8. Assert that children got different ids from each other and from
|
||||
# main, and assert watched child's callback was executed, and that
|
||||
# unwatched child's callback was not.
|
||||
|
||||
if 'java' in sys.platform:
|
||||
raise SkipTest("Can't rely on weakref callbacks in Jython")
|
||||
|
||||
ident = thread_util.ThreadIdent()
|
||||
|
||||
ids = set([ident.get()])
|
||||
unwatched_id = []
|
||||
done = set([ident.get()]) # Start with main thread's id.
|
||||
died = set()
|
||||
|
||||
class WatchedThread(RendezvousThread):
|
||||
def __init__(self, ident, state):
|
||||
super(WatchedThread, self).__init__(state)
|
||||
self._my_ident = ident
|
||||
|
||||
def before_rendezvous(self):
|
||||
self.my_id = self._my_ident.get()
|
||||
ids.add(self.my_id)
|
||||
|
||||
def after_rendezvous(self):
|
||||
assert not self._my_ident.watching()
|
||||
self._my_ident.watch(lambda ref: died.add(self.my_id))
|
||||
assert self._my_ident.watching()
|
||||
done.add(self.my_id)
|
||||
|
||||
class UnwatchedThread(WatchedThread):
|
||||
def before_rendezvous(self):
|
||||
super(UnwatchedThread, self).before_rendezvous()
|
||||
unwatched_id.append(self.my_id)
|
||||
|
||||
def after_rendezvous(self):
|
||||
super(UnwatchedThread, self).after_rendezvous()
|
||||
self._my_ident.unwatch(self.my_id)
|
||||
assert not self._my_ident.watching()
|
||||
|
||||
state = RendezvousThread.create_shared_state(2)
|
||||
t_watched = WatchedThread(ident, state)
|
||||
t_watched.start()
|
||||
|
||||
t_unwatched = UnwatchedThread(ident, state)
|
||||
t_unwatched.start()
|
||||
|
||||
RendezvousThread.wait_for_rendezvous(state)
|
||||
RendezvousThread.resume_after_rendezvous(state)
|
||||
|
||||
t_watched.join()
|
||||
t_unwatched.join()
|
||||
|
||||
self.assertTrue(t_watched.passed)
|
||||
self.assertTrue(t_unwatched.passed)
|
||||
|
||||
# Remove references, let weakref callbacks run
|
||||
del t_watched
|
||||
del t_unwatched
|
||||
|
||||
# Trigger final cleanup in Python <= 2.7.0.
|
||||
# http://bugs.python.org/issue1868
|
||||
ident.get()
|
||||
self.assertEqual(3, len(ids))
|
||||
self.assertEqual(3, len(done))
|
||||
|
||||
# Make sure thread is really gone
|
||||
slept = 0
|
||||
while not died and slept < 10:
|
||||
time.sleep(1)
|
||||
gc.collect()
|
||||
slept += 1
|
||||
|
||||
self.assertEqual(1, len(died))
|
||||
self.assertFalse(unwatched_id[0] in died)
|
||||
|
||||
|
||||
class TestCounter(unittest.TestCase):
|
||||
def test_counter(self):
|
||||
counter = thread_util.Counter()
|
||||
|
||||
self.assertEqual(0, counter.dec())
|
||||
self.assertEqual(0, counter.get())
|
||||
self.assertEqual(0, counter.dec())
|
||||
self.assertEqual(0, counter.get())
|
||||
|
||||
done = set()
|
||||
|
||||
def f(n):
|
||||
for i in range(n):
|
||||
self.assertEqual(i, counter.get())
|
||||
self.assertEqual(i + 1, counter.inc())
|
||||
|
||||
for i in range(n, 0, -1):
|
||||
self.assertEqual(i, counter.get())
|
||||
self.assertEqual(i - 1, counter.dec())
|
||||
|
||||
self.assertEqual(0, counter.get())
|
||||
|
||||
# Extra decrements have no effect
|
||||
self.assertEqual(0, counter.dec())
|
||||
self.assertEqual(0, counter.get())
|
||||
self.assertEqual(0, counter.dec())
|
||||
self.assertEqual(0, counter.get())
|
||||
|
||||
done.add(n)
|
||||
|
||||
threads = [threading.Thread(target=partial(f, i))
|
||||
for i in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
self.assertEqual(10, len(done))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -17,14 +17,9 @@
|
||||
import threading
|
||||
|
||||
from test import unittest, client_context, IntegrationTest, db_user, db_pwd
|
||||
from test.utils import (frequent_thread_switches,
|
||||
joinall,
|
||||
RendezvousThread,
|
||||
rs_or_single_client,
|
||||
rs_or_single_client_noauth)
|
||||
from test.utils import get_pool
|
||||
from pymongo.pool import SocketInfo, _closed
|
||||
from pymongo.errors import AutoReconnect, OperationFailure
|
||||
from test.utils import rs_or_single_client_noauth
|
||||
from test.utils import frequent_thread_switches, joinall
|
||||
from pymongo.errors import OperationFailure
|
||||
|
||||
|
||||
@client_context.require_connection
|
||||
@ -130,57 +125,6 @@ class Disconnect(threading.Thread):
|
||||
self.passed = True
|
||||
|
||||
|
||||
class IgnoreAutoReconnect(threading.Thread):
|
||||
|
||||
def __init__(self, collection, n):
|
||||
threading.Thread.__init__(self)
|
||||
self.c = collection
|
||||
self.n = n
|
||||
self.setDaemon(True)
|
||||
|
||||
def run(self):
|
||||
for _ in range(self.n):
|
||||
try:
|
||||
self.c.find_one()
|
||||
except AutoReconnect:
|
||||
pass
|
||||
|
||||
|
||||
class FindPauseFind(RendezvousThread):
|
||||
"""See test_server_disconnect() for details"""
|
||||
def __init__(self, collection, state):
|
||||
"""Params:
|
||||
`collection`: A collection for testing
|
||||
`state`: A shared state object from RendezvousThread.shared_state()
|
||||
"""
|
||||
super(FindPauseFind, self).__init__(state)
|
||||
self.collection = collection
|
||||
|
||||
def before_rendezvous(self):
|
||||
# acquire a socket
|
||||
client = self.collection.database.connection
|
||||
client.start_request()
|
||||
list(self.collection.find())
|
||||
|
||||
pool = get_pool(client)
|
||||
socket_info = pool._get_request_state()
|
||||
assert isinstance(socket_info, SocketInfo)
|
||||
self.request_sock = socket_info.sock
|
||||
assert not _closed(self.request_sock)
|
||||
|
||||
def after_rendezvous(self):
|
||||
# test_server_disconnect() has closed this socket, but that's ok
|
||||
# because it's not our request socket anymore
|
||||
assert _closed(self.request_sock)
|
||||
|
||||
# if disconnect() properly replaced the pool, then this won't raise
|
||||
# AutoReconnect because it will acquire a new socket
|
||||
list(self.collection.find())
|
||||
assert self.collection.database.connection.in_request()
|
||||
pool = get_pool(self.collection.database.connection)
|
||||
assert self.request_sock != pool._get_request_state().sock
|
||||
|
||||
|
||||
class TestThreads(IntegrationTest):
|
||||
def setUp(self):
|
||||
self.db = client_context.rs_or_standalone_client.pymongo_test
|
||||
@ -236,71 +180,6 @@ class TestThreads(IntegrationTest):
|
||||
error.join()
|
||||
okay.join()
|
||||
|
||||
def test_server_disconnect(self):
|
||||
# PYTHON-345, we need to make sure that threads' request sockets are
|
||||
# closed by disconnect().
|
||||
#
|
||||
# 1. Create a client and start a request.
|
||||
# 2. Start N threads and do a find() in each to get a request socket
|
||||
# 3. Pause all threads
|
||||
# 4. In the main thread close all sockets, including threads' request
|
||||
# sockets
|
||||
# 5. In main thread, do a find(), which raises AutoReconnect and resets
|
||||
# pool
|
||||
# 6. Resume all threads, do a find() in them
|
||||
#
|
||||
# If we've fixed PYTHON-345, then only one AutoReconnect is raised,
|
||||
# and all the threads get new request sockets.
|
||||
cx = rs_or_single_client()
|
||||
cx.start_request()
|
||||
collection = cx.db.pymongo_test
|
||||
|
||||
# acquire a request socket for the main thread
|
||||
collection.find_one()
|
||||
pool = get_pool(collection.database.connection)
|
||||
socket_info = pool._get_request_state()
|
||||
assert isinstance(socket_info, SocketInfo)
|
||||
request_sock = socket_info.sock
|
||||
|
||||
state = FindPauseFind.create_shared_state(nthreads=10)
|
||||
|
||||
threads = [
|
||||
FindPauseFind(collection, state)
|
||||
for _ in range(state.nthreads)
|
||||
]
|
||||
|
||||
# Each thread does a find(), thus acquiring a request socket
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# Wait for the threads to reach the rendezvous
|
||||
FindPauseFind.wait_for_rendezvous(state)
|
||||
|
||||
try:
|
||||
# Simulate an event that closes all sockets, e.g. primary stepdown
|
||||
for t in threads:
|
||||
t.request_sock.close()
|
||||
|
||||
# Finally, ensure the main thread's socket's last_checkout is
|
||||
# updated:
|
||||
collection.find_one()
|
||||
|
||||
# ... and close it:
|
||||
request_sock.close()
|
||||
|
||||
# Doing an operation on the client raises an AutoReconnect and
|
||||
# resets the pool behind the scenes
|
||||
self.assertRaises(AutoReconnect, collection.find_one)
|
||||
|
||||
finally:
|
||||
# Let threads do a second find()
|
||||
FindPauseFind.resume_after_rendezvous(state)
|
||||
|
||||
joinall(threads)
|
||||
|
||||
for t in threads:
|
||||
self.assertTrue(t.passed, "%s threw exception" % t)
|
||||
|
||||
def test_client_disconnect(self):
|
||||
self.db.drop_collection("test")
|
||||
for i in range(1000):
|
||||
|
||||
144
test/utils.py
144
test/utils.py
@ -26,7 +26,6 @@ from functools import partial
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import AutoReconnect, OperationFailure
|
||||
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
|
||||
from pymongo.server_selectors import (any_server_selector,
|
||||
writable_server_selector)
|
||||
from test import (client_context,
|
||||
@ -280,104 +279,6 @@ def ignore_deprecations():
|
||||
yield
|
||||
|
||||
|
||||
class RendezvousThread(threading.Thread):
|
||||
"""A thread that starts and pauses at a rendezvous point before resuming.
|
||||
To be used in tests that must ensure that N threads are all alive
|
||||
simultaneously, regardless of thread-scheduling's vagaries.
|
||||
|
||||
1. Write a subclass of RendezvousThread and override before_rendezvous
|
||||
and / or after_rendezvous.
|
||||
2. Create a state with RendezvousThread.shared_state(N)
|
||||
3. Start N of your subclassed RendezvousThreads, passing the state to each
|
||||
one's __init__
|
||||
4. In the main thread, call RendezvousThread.wait_for_rendezvous
|
||||
5. Test whatever you need to test while threads are paused at rendezvous
|
||||
point
|
||||
6. In main thread, call RendezvousThread.resume_after_rendezvous
|
||||
7. Join all threads from main thread
|
||||
8. Assert that all threads' "passed" attribute is True
|
||||
9. Test post-conditions
|
||||
"""
|
||||
|
||||
class RendezvousState(object):
|
||||
def __init__(self, nthreads):
|
||||
# Number of threads total
|
||||
self.nthreads = nthreads
|
||||
|
||||
# Number of threads that have arrived at rendezvous point
|
||||
self.arrived_threads = 0
|
||||
self.arrived_threads_lock = threading.Lock()
|
||||
|
||||
# Set when all threads reach rendezvous
|
||||
self.ev_arrived = threading.Event()
|
||||
|
||||
# Set by resume_after_rendezvous() so threads can continue.
|
||||
self.ev_resume = threading.Event()
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_shared_state(cls, nthreads):
|
||||
return RendezvousThread.RendezvousState(nthreads)
|
||||
|
||||
def before_rendezvous(self):
|
||||
"""Overridable: Do this before the rendezvous"""
|
||||
pass
|
||||
|
||||
def after_rendezvous(self):
|
||||
"""Overridable: Do this after the rendezvous. If it throws no exception,
|
||||
`passed` is set to True
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def wait_for_rendezvous(cls, state):
|
||||
"""Wait for all threads to reach rendezvous and pause there"""
|
||||
state.ev_arrived.wait(10)
|
||||
assert state.ev_arrived.isSet(), "Thread timeout"
|
||||
assert state.nthreads == state.arrived_threads
|
||||
|
||||
@classmethod
|
||||
def resume_after_rendezvous(cls, state):
|
||||
"""Tell all the paused threads to continue"""
|
||||
state.ev_resume.set()
|
||||
|
||||
def __init__(self, state):
|
||||
"""Params:
|
||||
`state`: A shared state object from RendezvousThread.shared_state()
|
||||
"""
|
||||
super(RendezvousThread, self).__init__()
|
||||
self.state = state
|
||||
self.passed = False
|
||||
|
||||
# If this thread fails to terminate, don't hang the whole program
|
||||
self.setDaemon(True)
|
||||
|
||||
def _rendezvous(self):
|
||||
"""Pause until all threads arrive here"""
|
||||
s = self.state
|
||||
s.arrived_threads_lock.acquire()
|
||||
s.arrived_threads += 1
|
||||
if s.arrived_threads == s.nthreads:
|
||||
s.arrived_threads_lock.release()
|
||||
s.ev_arrived.set()
|
||||
else:
|
||||
s.arrived_threads_lock.release()
|
||||
s.ev_arrived.wait()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.before_rendezvous()
|
||||
finally:
|
||||
self._rendezvous()
|
||||
|
||||
# all threads have passed the rendezvous, wait for
|
||||
# resume_after_rendezvous()
|
||||
self.state.ev_resume.wait()
|
||||
|
||||
self.after_rendezvous()
|
||||
self.passed = True
|
||||
|
||||
|
||||
def read_from_which_host(
|
||||
client,
|
||||
pref,
|
||||
@ -467,51 +368,6 @@ def get_pools(client):
|
||||
client._get_topology().select_servers(any_server_selector)]
|
||||
|
||||
|
||||
class TestRequestMixin(object):
|
||||
"""Inherit from this class and from unittest.TestCase to get some
|
||||
convenient methods for testing connection pools and requests
|
||||
"""
|
||||
|
||||
def assertSameSock(self, pool):
|
||||
sock_info0 = pool.get_socket({}, 0, 0)
|
||||
sock_info1 = pool.get_socket({}, 0, 0)
|
||||
self.assertEqual(sock_info0, sock_info1)
|
||||
pool.maybe_return_socket(sock_info0)
|
||||
pool.maybe_return_socket(sock_info1)
|
||||
|
||||
def assertDifferentSock(self, pool):
|
||||
sock_info0 = pool.get_socket({}, 0, 0)
|
||||
sock_info1 = pool.get_socket({}, 0, 0)
|
||||
self.assertNotEqual(sock_info0, sock_info1)
|
||||
pool.maybe_return_socket(sock_info0)
|
||||
pool.maybe_return_socket(sock_info1)
|
||||
|
||||
def assertNoRequest(self, pool):
|
||||
self.assertEqual(NO_REQUEST, pool._get_request_state())
|
||||
|
||||
def assertNoSocketYet(self, pool):
|
||||
self.assertEqual(NO_SOCKET_YET, pool._get_request_state())
|
||||
|
||||
def assertRequestSocket(self, pool):
|
||||
self.assertTrue(isinstance(pool._get_request_state(), SocketInfo))
|
||||
|
||||
def assertInRequestAndSameSock(self, client, pools):
|
||||
self.assertTrue(client.in_request())
|
||||
if not isinstance(pools, list):
|
||||
pools = [pools]
|
||||
for pool in pools:
|
||||
self.assertTrue(pool.in_request())
|
||||
self.assertSameSock(pool)
|
||||
|
||||
def assertNotInRequestAndDifferentSock(self, client, pools):
|
||||
self.assertFalse(client.in_request())
|
||||
if not isinstance(pools, list):
|
||||
pools = [pools]
|
||||
for pool in pools:
|
||||
self.assertFalse(pool.in_request())
|
||||
self.assertDifferentSock(pool)
|
||||
|
||||
|
||||
# Constants for run_threads and lazy_client_trial.
|
||||
NTRIALS = 5
|
||||
NTHREADS = 10
|
||||
|
||||
Loading…
Reference in New Issue
Block a user