PYTHON-3064 Add typings to test package (#844)

This commit is contained in:
Steven Silvester 2022-02-07 19:33:41 -06:00 committed by GitHub
parent 561ee7cf77
commit f4cef37328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 540 additions and 259 deletions

View File

@ -46,4 +46,5 @@ jobs:
pip install -e ".[zstd, srv]"
- name: Run mypy
run: |
mypy --install-types --non-interactive bson gridfs tools
mypy --install-types --non-interactive bson gridfs tools pymongo
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test

View File

@ -28,7 +28,7 @@ from typing import (Any, Dict, Iterable, Iterator, List, Mapping,
# This is essentially the same as re._pattern_type
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))
_Key = TypeVar("_Key", bound=str)
_Key = TypeVar("_Key")
_Value = TypeVar("_Value")
_T = TypeVar("_T")

View File

@ -11,6 +11,9 @@ warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true
[mypy-gevent.*]
ignore_missing_imports = True
[mypy-kerberos.*]
ignore_missing_imports = True
@ -29,5 +32,12 @@ ignore_missing_imports = True
[mypy-snappy.*]
ignore_missing_imports = True
[mypy-test.*]
allow_redefinition = true
allow_untyped_globals = true
[mypy-winkerberos.*]
ignore_missing_imports = True
[mypy-xmlrunner.*]
ignore_missing_imports = True

View File

@ -16,9 +16,8 @@
import errno
import select
import socket
import sys
from typing import Any, Optional
from typing import Any, Optional, Union
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
# https://bugs.jython.org/issue2900
@ -43,7 +42,7 @@ class SocketChecker(object):
else:
self._poller = None
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
"""Select for reads or writes with a timeout in seconds (or None).
Returns True if the socket is readable/writable, False on timeout.

View File

@ -39,7 +39,7 @@ def maybe_decode(text):
def _resolve(*args, **kwargs):
if hasattr(resolver, 'resolve'):
# dnspython >= 2
return resolver.resolve(*args, **kwargs) # type: ignore
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)

View File

@ -14,7 +14,7 @@
"""Type aliases used by PyMongo"""
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
Tuple, Type, TypeVar, Union)
Sequence, Tuple, Type, TypeVar, Union)
if TYPE_CHECKING:
from bson.raw_bson import RawBSONDocument
@ -25,5 +25,5 @@ if TYPE_CHECKING:
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
_Pipeline = List[Mapping[str, Any]]
_Pipeline = Sequence[Mapping[str, Any]]
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])

View File

@ -40,6 +40,7 @@ except ImportError:
from contextlib import contextmanager
from functools import wraps
from typing import Dict, no_type_check
from unittest import SkipTest
import pymongo
@ -48,7 +49,9 @@ import pymongo.errors
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.database import Database
from pymongo.hello import HelloCompat
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl
from pymongo.uri_parser import parse_uri
@ -86,7 +89,7 @@ CLIENT_PEM = os.environ.get('CLIENT_PEM',
os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
TLS_OPTIONS = dict(tls=True)
TLS_OPTIONS: Dict = dict(tls=True)
if CLIENT_PEM:
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
if CA_PEM:
@ -102,13 +105,13 @@ if TEST_LOADBALANCER:
# Remove after PYTHON-2712
from pymongo import pool
pool._MOCK_SERVICE_ID = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
@ -184,6 +187,7 @@ class client_knobs(object):
def __enter__(self):
self.enable()
@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
@ -224,6 +228,8 @@ def _all_users(db):
class ClientContext(object):
client: MongoClient
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
def __init__(self):
@ -247,9 +253,9 @@ class ClientContext(object):
self.tls = False
self.tlsCertificateKeyFile = False
self.server_is_resolvable = is_server_resolvable()
self.default_client_options = {}
self.default_client_options: Dict = {}
self.sessions_enabled = False
self.client = None
self.client = None # type: ignore
self.conn_lock = threading.Lock()
self.is_data_lake = False
self.load_balancer = TEST_LOADBALANCER
@ -340,6 +346,7 @@ class ClientContext(object):
try:
self.cmd_line = self.client.admin.command('getCmdLineOpts')
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
# Unauthorized.
@ -418,6 +425,7 @@ class ClientContext(object):
else:
self.server_parameters = self.client.admin.command(
'getParameter', '*')
assert self.cmd_line is not None
if 'enableTestCommands=1' in self.cmd_line['argv']:
self.test_commands_enabled = True
elif 'parsed' in self.cmd_line:
@ -436,7 +444,8 @@ class ClientContext(object):
self.mongoses.append(address)
if not self.serverless:
# Check for another mongos on the next port.
next_address = address[0], address[1] + 1
assert address is not None
next_address = address[0], address[1] + 1
mongos_client = self._connect(
*next_address, **self.default_client_options)
if mongos_client:
@ -496,6 +505,7 @@ class ClientContext(object):
try:
return db_user in _all_users(client.admin)
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 18 or 'auth fails' in msg:
# Auth failed.
@ -505,6 +515,7 @@ class ClientContext(object):
def _server_started_with_auth(self):
# MongoDB >= 2.0
assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
parsed = self.cmd_line['parsed']
# MongoDB >= 2.6
@ -525,6 +536,7 @@ class ClientContext(object):
if not socket.has_ipv6:
return False
assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
return False
@ -932,6 +944,9 @@ class PyMongoTestCase(unittest.TestCase):
class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""
client: MongoClient
db: Database
credentials: Dict[str, str]
@classmethod
@client_context.require_connection
@ -1073,7 +1088,7 @@ class PymongoTestRunner(unittest.TextTestRunner):
if HAVE_XML:
class PymongoXMLTestRunner(XMLTestRunner):
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
def run(self, test):
setup()
result = super(PymongoXMLTestRunner, self).run(test)

View File

@ -26,6 +26,7 @@ from pymongo.uri_parser import parse_uri
class TestAuthAWS(unittest.TestCase):
uri: str
@classmethod
def setUpClass(cls):

View File

@ -21,6 +21,9 @@ import unittest
class TestCursorNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient
@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
@ -69,6 +72,9 @@ class TestCursorNamespace(unittest.TestCase):
class TestKillCursorsNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient
@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})

View File

@ -27,7 +27,7 @@ class TestGetmoreSharded(unittest.TestCase):
servers = [MockupDB(), MockupDB()]
# Collect queries to either server in one queue.
q = Queue()
q: Queue = Queue()
for server in servers:
server.subscribe(q.put)
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',

View File

@ -48,7 +48,7 @@ def test_hello_with_option(self, protocol, **kwargs):
ServerApiVersion.V1))}
client = MongoClient("mongodb://"+primary.address_string,
appname='my app', # For _check_handshake_data()
**dict([k_map.get((k, v), (k, v)) for k, v
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
in kwargs.items()]))
self.addCleanup(client.close)
@ -58,7 +58,7 @@ def test_hello_with_option(self, protocol, **kwargs):
# We do this checking here rather than in the autoresponder `respond()`
# because it runs in another Python thread so there are some funky things
# with error handling within that thread, and we want to be able to use
# with error handling within that thread, and we want to be able to use
# self.assertRaises().
self.handshake_req.assert_matches(protocol(hello, **kwargs))
_check_handshake_data(self.handshake_req)

View File

@ -30,7 +30,7 @@ class TestMixedVersionSharded(unittest.TestCase):
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()
# Collect queries to either server in one queue.
self.q = Queue()
self.q: Queue = Queue()
for server in self.mongos_old, self.mongos_new:
server.subscribe(self.q.put)
server.autoresponds('getlasterror')
@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
def test(self):
self.setup_server(upgrade)
start = time.time()
servers_used = set()
servers_used: set = set()
while len(servers_used) < 2:
go(upgrade.function, self.client)
request = self.q.get(timeout=1)

View File

@ -233,6 +233,8 @@ operations_312 = [
class TestOpMsg(unittest.TestCase):
server: MockupDB
client: MongoClient
@classmethod
def setUpClass(cls):

View File

@ -14,6 +14,7 @@
import copy
import itertools
from typing import Any
from mockupdb import MockupDB, going, CommandBase
from pymongo import MongoClient, ReadPreference
@ -27,6 +28,8 @@ from operations import operations
class OpMsgReadPrefBase(unittest.TestCase):
single_mongod = False
primary: MockupDB
secondary: MockupDB
@classmethod
def setUpClass(cls):
@ -142,7 +145,7 @@ def create_op_msg_read_mode_test(mode, operation):
tag_sets=None)
client = self.setup_client(read_preference=pref)
expected_pref: Any
if operation.op_type == 'always-use-secondary':
expected_server = self.secondary
expected_pref = ReadPreference.SECONDARY

View File

@ -27,7 +27,7 @@ from operations import operations
class TestResetAndRequestCheck(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
self.ismaster_time = 0
self.ismaster_time = 0.0
self.client = None
self.server = None
@ -45,7 +45,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
kwargs = {'socketTimeoutMS': 100}
# Disable retryable reads when pymongo supports it.
kwargs['retryReads'] = False
self.client = MongoClient(self.server.uri, **kwargs)
self.client = MongoClient(self.server.uri, **kwargs) # type: ignore
wait_until(lambda: self.client.nodes, 'connect to standalone')
def tearDown(self):
@ -56,6 +56,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
# Application operation fails. Test that client resets server
# description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None
# Network error on application operation.
with self.assertRaises(ConnectionFailure):
@ -81,6 +83,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
# Application operation times out. Test that client does *not* reset
# server description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None
with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
@ -91,6 +95,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
# Server is *not* Unknown.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
after = self.ismaster_time
@ -99,6 +104,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
def _test_not_master(self, operation):
# Application operation gets a "not master" error.
self.setup_server()
assert self.server is not None
assert self.client is not None
with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
@ -110,6 +117,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
# Server is rediscovered.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
after = self.ismaster_time

View File

@ -37,7 +37,7 @@ class TestSlaveOkaySharded(unittest.TestCase):
self.mongos1, self.mongos2 = MockupDB(), MockupDB()
# Collect queries to either server in one queue.
self.q = Queue()
self.q: Queue = Queue()
for server in self.mongos1, self.mongos2:
server.subscribe(self.q.put)
server.run()

View File

@ -67,6 +67,10 @@ class Timer(object):
class PerformanceTest(object):
dataset: Any
data_size: Any
do_task: Any
fail: Any
@classmethod
def setUpClass(cls):
@ -386,6 +390,7 @@ def mp_map(map_func, files):
def insert_json_file(filename):
assert proc_client is not None
with open(filename, 'r') as data:
coll = proc_client.perftest.corpus
coll.insert_many([json.loads(line) for line in data])
@ -398,11 +403,13 @@ def insert_json_file_with_file_id(filename):
doc = json.loads(line)
doc['file'] = filename
documents.append(doc)
assert proc_client is not None
coll = proc_client.perftest.corpus
coll.insert_many(documents)
def read_json_file(filename):
assert proc_client is not None
coll = proc_client.perftest.corpus
temp = tempfile.TemporaryFile(mode='w')
try:
@ -414,6 +421,7 @@ def read_json_file(filename):
def insert_gridfs_file(filename):
assert proc_client is not None
bucket = GridFSBucket(proc_client.perftest)
with open(filename, 'rb') as gfile:
@ -421,6 +429,7 @@ def insert_gridfs_file(filename):
def read_gridfs_file(filename):
assert proc_client is not None
bucket = GridFSBucket(proc_client.perftest)
temp = tempfile.TemporaryFile()

View File

@ -76,6 +76,8 @@ class AutoAuthenticateThread(threading.Thread):
class TestGSSAPI(unittest.TestCase):
mech_properties: str
service_realm_required: bool
@classmethod
def setUpClass(cls):
@ -116,6 +118,7 @@ class TestGSSAPI(unittest.TestCase):
@ignore_deprecations
def test_gssapi_simple(self):
assert GSSAPI_PRINCIPAL is not None
if GSSAPI_PASS is not None:
uri = ('mongodb://%s:%s@%s:%d/?authMechanism='
'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL),
@ -264,6 +267,8 @@ class TestSASLPlain(unittest.TestCase):
authMechanism='PLAIN')
client.ldap.test.find_one()
assert SASL_USER is not None
assert SASL_PASS is not None
uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;'
'authSource=%s' % (quote_plus(SASL_USER),
quote_plus(SASL_PASS),
@ -540,7 +545,6 @@ class TestSCRAM(IntegrationTest):
self.assertIsInstance(iterations, int)
def test_scram_threaded(self):
coll = client_context.client.db.test
coll.drop()
coll.insert_one({'_id': 1})

View File

@ -41,6 +41,8 @@ from test.utils import ignore_deprecations
class TestBinary(unittest.TestCase):
csharp_data: bytes
java_data: bytes
@classmethod
def setUpClass(cls):
@ -354,6 +356,8 @@ class TestBinary(unittest.TestCase):
class TestUuidSpecExplicitCoding(unittest.TestCase):
uuid: uuid.UUID
@classmethod
def setUpClass(cls):
super(TestUuidSpecExplicitCoding, cls).setUpClass()
@ -457,6 +461,8 @@ class TestUuidSpecExplicitCoding(unittest.TestCase):
class TestUuidSpecImplicitCoding(IntegrationTest):
uuid: uuid.UUID
@classmethod
def setUpClass(cls):
super(TestUuidSpecImplicitCoding, cls).setUpClass()

View File

@ -186,7 +186,7 @@ class TestBSON(unittest.TestCase):
decoder=lambda *args: BSON(args[0]).decode(*args[1:]))
def test_encoding_defaultdict(self):
dct = collections.defaultdict(dict, [('foo', 'bar')])
dct = collections.defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type]
encode(dct)
self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
@ -302,7 +302,7 @@ class TestBSON(unittest.TestCase):
def test_decode_all_buffer_protocol(self):
docs = [{'foo': 'bar'}, {}]
bs = b"".join(map(encode, docs))
bs = b"".join(map(encode, docs)) # type: ignore[arg-type]
self.assertEqual(docs, decode_all(bytearray(bs)))
self.assertEqual(docs, decode_all(memoryview(bs)))
self.assertEqual(docs, decode_all(memoryview(b'1' + bs + b'1')[1:-1]))
@ -530,7 +530,9 @@ class TestBSON(unittest.TestCase):
def test_aware_datetime(self):
aware = datetime.datetime(1993, 4, 4, 2,
tzinfo=FixedOffset(555, "SomeZone"))
as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc)
offset = aware.utcoffset()
assert offset is not None
as_utc = (aware - offset).replace(tzinfo=utc)
self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc),
as_utc)
after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[
@ -591,7 +593,9 @@ class TestBSON(unittest.TestCase):
def test_naive_decode(self):
aware = datetime.datetime(1993, 4, 4, 2,
tzinfo=FixedOffset(555, "SomeZone"))
naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None)
offset = aware.utcoffset()
assert offset is not None
naive_utc = (aware - offset).replace(tzinfo=None)
self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc)
after = decode(encode({"date": aware}))["date"]
self.assertEqual(None, after.tzinfo)
@ -603,9 +607,9 @@ class TestBSON(unittest.TestCase):
@unittest.skip('Disabled due to http://bugs.python.org/issue25222')
def test_bad_encode(self):
evil_list = {'a': []}
evil_list: dict = {'a': []}
evil_list['a'].append(evil_list)
evil_dict = {}
evil_dict: dict = {}
evil_dict['a'] = evil_dict
for evil_data in [evil_dict, evil_list]:
self.assertRaises(Exception, encode, evil_data)
@ -1039,8 +1043,8 @@ class TestCodecOptions(unittest.TestCase):
def test_regex_pickling(self):
reg = Regex(".?")
pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n'
b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}'
pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n'
b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}'
b'\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag'
b's\x94K\x00ub.')
self.round_trip_pickle(reg, pickled_with_3)
@ -1083,8 +1087,8 @@ class TestCodecOptions(unittest.TestCase):
def test_maxkey_pickling(self):
maxk = MaxKey()
pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c'
b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)'
pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c'
b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)'
b'\x81\x94.')
self.round_trip_pickle(maxk, pickled_with_3)

View File

@ -16,13 +16,15 @@
import sys
import uuid
from bson.binary import UuidRepresentation
from bson.codec_options import CodecOptions
from pymongo.mongo_client import MongoClient
sys.path[0:0] = [""]
from bson import Binary
from bson.binary import Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from pymongo.collection import Collection
from pymongo.common import partition_node
from pymongo.errors import (BulkWriteError,
ConfigurationError,
@ -40,6 +42,8 @@ from test.utils import (remove_all_users,
class BulkTestBase(IntegrationTest):
coll: Collection
coll_w0: Collection
@classmethod
def setUpClass(cls):
@ -280,6 +284,7 @@ class TestBulk(BulkTestBase):
upsert=True)])
self.assertEqualResponse(expected, result.bulk_api_result)
self.assertEqual(1, result.upserted_count)
assert result.upserted_ids is not None
self.assertEqual(1, len(result.upserted_ids))
self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId))
@ -341,11 +346,11 @@ class TestBulk(BulkTestBase):
# The requests argument must be a list.
generator = (InsertOne({}) for _ in range(10))
with self.assertRaises(TypeError):
self.coll.bulk_write(generator)
self.coll.bulk_write(generator) # type: ignore[arg-type]
# Document is not wrapped in a bulk write operation.
with self.assertRaises(TypeError):
self.coll.bulk_write([{}])
self.coll.bulk_write([{}]) # type: ignore[list-item]
def test_upsert_large(self):
big = 'a' * (client_context.max_bson_size - 37)
@ -425,7 +430,7 @@ class TestBulk(BulkTestBase):
def test_upsert_uuid_standard_subdocuments(self):
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
coll = self.coll.with_options(codec_options=options)
ids = [
ids: list = [
{'f': Binary(bytes(i)), 'f2': uuid.uuid4()}
for i in range(3)
]
@ -472,7 +477,7 @@ class TestBulk(BulkTestBase):
def test_single_error_ordered_batch(self):
self.coll.create_index('a', unique=True)
self.addCleanup(self.coll.drop_index, [('a', 1)])
requests = [
requests: list = [
InsertOne({'b': 1, 'a': 1}),
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
InsertOne({'b': 3, 'a': 2}),
@ -506,7 +511,7 @@ class TestBulk(BulkTestBase):
def test_multiple_error_ordered_batch(self):
self.coll.create_index('a', unique=True)
self.addCleanup(self.coll.drop_index, [('a', 1)])
requests = [
requests: list = [
InsertOne({'b': 1, 'a': 1}),
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
UpdateOne({'b': 3}, {'$set': {'a': 2}}, upsert=True),
@ -542,7 +547,7 @@ class TestBulk(BulkTestBase):
result)
def test_single_unordered_batch(self):
requests = [
requests: list = [
InsertOne({'a': 1}),
UpdateOne({'a': 1}, {'$set': {'b': 1}}),
UpdateOne({'a': 2}, {'$set': {'b': 2}}, upsert=True),
@ -564,7 +569,7 @@ class TestBulk(BulkTestBase):
def test_single_error_unordered_batch(self):
self.coll.create_index('a', unique=True)
self.addCleanup(self.coll.drop_index, [('a', 1)])
requests = [
requests: list = [
InsertOne({'b': 1, 'a': 1}),
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
InsertOne({'b': 3, 'a': 2}),
@ -599,7 +604,7 @@ class TestBulk(BulkTestBase):
def test_multiple_error_unordered_batch(self):
self.coll.create_index('a', unique=True)
self.addCleanup(self.coll.drop_index, [('a', 1)])
requests = [
requests: list = [
InsertOne({'b': 1, 'a': 1}),
UpdateOne({'b': 2}, {'$set': {'a': 3}}, upsert=True),
UpdateOne({'b': 3}, {'$set': {'a': 4}}, upsert=True),
@ -662,7 +667,7 @@ class TestBulk(BulkTestBase):
self.coll.delete_many({})
big = 'x' * (1024 * 1024 * 4)
result = self.coll.bulk_write([
write_result = self.coll.bulk_write([
InsertOne({'a': 1, 'big': big}),
InsertOne({'a': 2, 'big': big}),
InsertOne({'a': 3, 'big': big}),
@ -671,7 +676,7 @@ class TestBulk(BulkTestBase):
InsertOne({'a': 6, 'big': big}),
])
self.assertEqual(6, result.inserted_count)
self.assertEqual(6, write_result.inserted_count)
self.assertEqual(6, self.coll.count_documents({}))
def test_large_inserts_unordered(self):
@ -685,12 +690,12 @@ class TestBulk(BulkTestBase):
try:
self.coll.bulk_write(requests, ordered=False)
except BulkWriteError as exc:
result = exc.details
details = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
self.assertEqual(2, result['nInserted'])
self.assertEqual(2, details['nInserted'])
self.coll.delete_many({})
@ -741,7 +746,7 @@ class TestBulkUnacknowledged(BulkTestBase):
self.coll.delete_many({})
def test_no_results_ordered_success(self):
requests = [
requests: list = [
InsertOne({'a': 1}),
UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True),
InsertOne({'a': 2}),
@ -755,7 +760,7 @@ class TestBulkUnacknowledged(BulkTestBase):
'removed {"_id": 1}')
def test_no_results_ordered_failure(self):
requests = [
requests: list = [
InsertOne({'_id': 1}),
UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True),
InsertOne({'_id': 2}),
@ -771,7 +776,7 @@ class TestBulkUnacknowledged(BulkTestBase):
self.assertEqual({'_id': 1}, self.coll.find_one({'_id': 1}))
def test_no_results_unordered_success(self):
requests = [
requests: list = [
InsertOne({'a': 1}),
UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True),
InsertOne({'a': 2}),
@ -785,7 +790,7 @@ class TestBulkUnacknowledged(BulkTestBase):
'removed {"_id": 1}')
def test_no_results_unordered_failure(self):
requests = [
requests: list = [
InsertOne({'_id': 1}),
UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True),
InsertOne({'_id': 2}),
@ -832,13 +837,15 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
class TestBulkWriteConcern(BulkTestBase):
w: Optional[int]
secondary: MongoClient
@classmethod
def setUpClass(cls):
super(TestBulkWriteConcern, cls).setUpClass()
cls.w = client_context.w
cls.secondary = None
if cls.w > 1:
if cls.w is not None and cls.w > 1:
for member in client_context.hello['hosts']:
if member != client_context.hello['primary']:
cls.secondary = single_client(*partition_node(member))
@ -886,7 +893,7 @@ class TestBulkWriteConcern(BulkTestBase):
try:
self.cause_wtimeout(requests, ordered=True)
except BulkWriteError as exc:
result = exc.details
details = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
@ -899,13 +906,13 @@ class TestBulkWriteConcern(BulkTestBase):
'nRemoved': 0,
'upserted': [],
'writeErrors': []},
result)
details)
# When talking to legacy servers there will be a
# write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 0)
self.assertTrue(len(details['writeConcernErrors']) > 0)
failed = result['writeConcernErrors'][0]
failed = details['writeConcernErrors'][0]
self.assertEqual(64, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], str))
@ -924,7 +931,7 @@ class TestBulkWriteConcern(BulkTestBase):
try:
self.cause_wtimeout(requests, ordered=True)
except BulkWriteError as exc:
result = exc.details
details = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
@ -941,10 +948,10 @@ class TestBulkWriteConcern(BulkTestBase):
'code': 11000,
'errmsg': '...',
'op': {'_id': '...', 'a': 1}}]},
result)
details)
self.assertTrue(len(result['writeConcernErrors']) > 1)
failed = result['writeErrors'][0]
self.assertTrue(len(details['writeConcernErrors']) > 1)
failed = details['writeErrors'][0]
self.assertTrue("duplicate" in failed['errmsg'])
@client_context.require_replica_set
@ -966,17 +973,17 @@ class TestBulkWriteConcern(BulkTestBase):
try:
self.cause_wtimeout(requests, ordered=False)
except BulkWriteError as exc:
result = exc.details
details = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
self.assertEqual(2, result['nInserted'])
self.assertEqual(1, result['nUpserted'])
self.assertEqual(0, len(result['writeErrors']))
self.assertEqual(2, details['nInserted'])
self.assertEqual(1, details['nUpserted'])
self.assertEqual(0, len(details['writeErrors']))
# When talking to legacy servers there will be a
# write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 1)
self.assertTrue(len(details['writeConcernErrors']) > 1)
self.coll.delete_many({})
self.coll.create_index('a', unique=True)
@ -984,7 +991,7 @@ class TestBulkWriteConcern(BulkTestBase):
# Fail due to write concern support as well
# as duplicate key error on unordered batch.
requests = [
requests: list = [
InsertOne({'a': 1}),
UpdateOne({'a': 3}, {'$set': {'a': 3, 'b': 1}}, upsert=True),
InsertOne({'a': 1}),
@ -993,29 +1000,29 @@ class TestBulkWriteConcern(BulkTestBase):
try:
self.cause_wtimeout(requests, ordered=False)
except BulkWriteError as exc:
result = exc.details
details = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
self.assertEqual(2, result['nInserted'])
self.assertEqual(1, result['nUpserted'])
self.assertEqual(1, len(result['writeErrors']))
self.assertEqual(2, details['nInserted'])
self.assertEqual(1, details['nUpserted'])
self.assertEqual(1, len(details['writeErrors']))
# When talking to legacy servers there will be a
# write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 1)
self.assertTrue(len(details['writeConcernErrors']) > 1)
failed = result['writeErrors'][0]
failed = details['writeErrors'][0]
self.assertEqual(2, failed['index'])
self.assertEqual(11000, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], str))
self.assertEqual(1, failed['op']['a'])
failed = result['writeConcernErrors'][0]
failed = details['writeConcernErrors'][0]
self.assertEqual(64, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], str))
upserts = result['upserted']
upserts = details['upserted']
self.assertEqual(1, len(upserts))
self.assertEqual(1, upserts[0]['index'])
self.assertTrue(upserts[0].get('_id'))

View File

@ -24,6 +24,7 @@ import time
import uuid
from itertools import product
from typing import no_type_check
sys.path[0:0] = ['']
@ -121,6 +122,7 @@ class TestChangeStreamBase(IntegrationTest):
class APITestsMixin(object):
@no_type_check
def test_watch(self):
with self.change_stream(
[{'$project': {'foo': 0}}], full_document='updateLookup',
@ -145,6 +147,7 @@ class APITestsMixin(object):
with self.change_stream(resume_after=resume_token):
pass
@no_type_check
def test_try_next(self):
# ChangeStreams only read majority committed data so use w:majority.
coll = self.watched_collection().with_options(
@ -161,6 +164,7 @@ class APITestsMixin(object):
wait_until(lambda: stream.try_next() is not None,
"get change from try_next")
@no_type_check
def test_try_next_runs_one_getmore(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
@ -216,6 +220,7 @@ class APITestsMixin(object):
set(["getMore"]))
self.assertIsNone(stream.try_next())
@no_type_check
def test_batch_size_is_honored(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
@ -245,6 +250,7 @@ class APITestsMixin(object):
self.assertEqual(expected[key], cmd[key])
# $changeStream.startAtOperationTime was added in 4.0.0.
@no_type_check
@client_context.require_version_min(4, 0, 0)
def test_start_at_operation_time(self):
optime = self.get_start_at_operation_time()
@ -258,6 +264,7 @@ class APITestsMixin(object):
for i in range(ndocs):
cs.next()
@no_type_check
def _test_full_pipeline(self, expected_cs_stage):
client, listener = self.client_with_listener("aggregate")
results = listener.results
@ -273,12 +280,14 @@ class APITestsMixin(object):
{'$project': {'foo': 0}}],
command.command['pipeline'])
@no_type_check
def test_full_pipeline(self):
"""$changeStream must be the first stage in a change stream pipeline
sent to the server.
"""
self._test_full_pipeline({})
@no_type_check
def test_iteration(self):
with self.change_stream(batch_size=2) as change_stream:
num_inserted = 10
@ -292,6 +301,7 @@ class APITestsMixin(object):
break
self._test_invalidate_stops_iteration(change_stream)
@no_type_check
def _test_next_blocks(self, change_stream):
inserted_doc = {'_id': ObjectId()}
changes = []
@ -311,18 +321,21 @@ class APITestsMixin(object):
self.assertEqual(changes[0]['operationType'], 'insert')
self.assertEqual(changes[0]['fullDocument'], inserted_doc)
@no_type_check
def test_next_blocks(self):
"""Test that next blocks until a change is readable"""
# Use a short await time to speed up the test.
with self.change_stream(max_await_time_ms=250) as change_stream:
self._test_next_blocks(change_stream)
@no_type_check
def test_aggregate_cursor_blocks(self):
"""Test that an aggregate cursor blocks until a change is readable."""
with self.watched_collection().aggregate(
[{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream:
self._test_next_blocks(change_stream)
@no_type_check
def test_concurrent_close(self):
"""Ensure a ChangeStream can be closed from another thread."""
# Use a short await time to speed up the test.
@ -338,6 +351,7 @@ class APITestsMixin(object):
t.join(3)
self.assertFalse(t.is_alive())
@no_type_check
def test_unknown_full_document(self):
"""Must rely on the server to raise an error on unknown fullDocument.
"""
@ -347,6 +361,7 @@ class APITestsMixin(object):
except OperationFailure:
pass
@no_type_check
def test_change_operations(self):
"""Test each operation type."""
expected_ns = {'db': self.watched_collection().database.name,
@ -393,6 +408,7 @@ class APITestsMixin(object):
# Invalidate.
self._test_get_invalidate_event(change_stream)
@no_type_check
@client_context.require_version_min(4, 1, 1)
def test_start_after(self):
resume_token = self.get_resume_token(invalidate=True)
@ -408,6 +424,7 @@ class APITestsMixin(object):
self.assertEqual(change['operationType'], 'insert')
self.assertEqual(change['fullDocument'], {'_id': 2})
@no_type_check
@client_context.require_version_min(4, 1, 1)
def test_start_after_resume_process_with_changes(self):
resume_token = self.get_resume_token(invalidate=True)
@ -427,6 +444,7 @@ class APITestsMixin(object):
self.assertEqual(change['operationType'], 'insert')
self.assertEqual(change['fullDocument'], {'_id': 3})
@no_type_check
@client_context.require_no_mongos # Remove after SERVER-41196
@client_context.require_version_min(4, 1, 1)
def test_start_after_resume_process_without_changes(self):
@ -444,12 +462,14 @@ class APITestsMixin(object):
class ProseSpecTestsMixin(object):
@no_type_check
def _client_with_listener(self, *commands):
listener = AllowListEventListener(*commands)
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
return client, listener
@no_type_check
def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3):
self.watched_collection().insert_many(
[{"data": k} for k in range(batch_size)])
@ -485,6 +505,7 @@ class ProseSpecTestsMixin(object):
response = listener.results['succeeded'][-1].reply
return response['cursor']['postBatchResumeToken']
@no_type_check
def _test_raises_error_on_missing_id(self, expected_exception):
"""ChangeStream will raise an exception if the server response is
missing the resume token.
@ -497,6 +518,7 @@ class ProseSpecTestsMixin(object):
with self.assertRaises(StopIteration):
next(change_stream)
@no_type_check
def _test_update_resume_token(self, expected_rt_getter):
"""ChangeStream must continuously track the last seen resumeToken."""
client, listener = self._client_with_listener("aggregate", "getMore")
@ -536,6 +558,7 @@ class ProseSpecTestsMixin(object):
self._test_raises_error_on_missing_id(InvalidOperation)
# Prose test no. 3
@no_type_check
def test_resume_on_error(self):
with self.change_stream() as change_stream:
self.insert_one_and_check(change_stream, {'_id': 1})
@ -544,6 +567,7 @@ class ProseSpecTestsMixin(object):
self.insert_one_and_check(change_stream, {'_id': 2})
# Prose test no. 4
@no_type_check
@client_context.require_failCommand_fail_point
def test_no_resume_attempt_if_aggregate_command_fails(self):
# Set non-retryable error on aggregate command.
@ -568,6 +592,7 @@ class ProseSpecTestsMixin(object):
# each operation which ensure compliance with this prose test.
# Prose test no. 7
@no_type_check
def test_initial_empty_batch(self):
with self.change_stream() as change_stream:
# The first batch should be empty.
@ -579,6 +604,7 @@ class ProseSpecTestsMixin(object):
self.assertEqual(cursor_id, change_stream._cursor.cursor_id)
# Prose test no. 8
@no_type_check
def test_kill_cursors(self):
def raise_error():
raise ServerSelectionTimeoutError('mock error')
@ -591,6 +617,7 @@ class ProseSpecTestsMixin(object):
self.insert_one_and_check(change_stream, {'_id': 2})
# Prose test no. 9
@no_type_check
@client_context.require_version_min(4, 0, 0)
@client_context.require_version_max(4, 0, 7)
def test_start_at_operation_time_caching(self):
@ -619,6 +646,7 @@ class ProseSpecTestsMixin(object):
# This test is identical to prose test no. 3.
# Prose test no. 11
@no_type_check
@client_context.require_version_min(4, 0, 7)
def test_resumetoken_empty_batch(self):
client, listener = self._client_with_listener("getMore")
@ -631,6 +659,7 @@ class ProseSpecTestsMixin(object):
response["cursor"]["postBatchResumeToken"])
# Prose test no. 11
@no_type_check
@client_context.require_version_min(4, 0, 7)
def test_resumetoken_exhausted_batch(self):
client, listener = self._client_with_listener("getMore")
@ -643,6 +672,7 @@ class ProseSpecTestsMixin(object):
response["cursor"]["postBatchResumeToken"])
# Prose test no. 12
@no_type_check
@client_context.require_version_max(4, 0, 7)
def test_resumetoken_empty_batch_legacy(self):
resume_point = self.get_resume_token()
@ -659,6 +689,7 @@ class ProseSpecTestsMixin(object):
self.assertEqual(resume_token, resume_point)
# Prose test no. 12
@no_type_check
@client_context.require_version_max(4, 0, 7)
def test_resumetoken_exhausted_batch_legacy(self):
# Resume token is _id of last change.
@ -673,6 +704,7 @@ class ProseSpecTestsMixin(object):
self.assertEqual(change_stream.resume_token, change["_id"])
# Prose test no. 13
@no_type_check
def test_resumetoken_partially_iterated_batch(self):
# When batch has been iterated up to but not including the last element.
# Resume token should be _id of previous change document.
@ -686,6 +718,7 @@ class ProseSpecTestsMixin(object):
self.assertEqual(resume_token, change["_id"])
@no_type_check
def _test_resumetoken_uniterated_nonempty_batch(self, resume_option):
# When the batch is not empty and hasn't been iterated at all.
# Resume token should be same as the resume option used.
@ -704,17 +737,20 @@ class ProseSpecTestsMixin(object):
self.assertEqual(resume_token, resume_point)
# Prose test no. 14
@no_type_check
@client_context.require_no_mongos
def test_resumetoken_uniterated_nonempty_batch_resumeafter(self):
self._test_resumetoken_uniterated_nonempty_batch("resume_after")
# Prose test no. 14
@no_type_check
@client_context.require_no_mongos
@client_context.require_version_min(4, 1, 1)
def test_resumetoken_uniterated_nonempty_batch_startafter(self):
self._test_resumetoken_uniterated_nonempty_batch("start_after")
# Prose test no. 17
@no_type_check
@client_context.require_version_min(4, 1, 1)
def test_startafter_resume_uses_startafter_after_empty_getMore(self):
# Resume should use startAfter after no changes have been returned.
@ -735,6 +771,7 @@ class ProseSpecTestsMixin(object):
response.command["pipeline"][0]["$changeStream"].get("startAfter"))
# Prose test no. 18
@no_type_check
@client_context.require_version_min(4, 1, 1)
def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self):
# Resume should use resumeAfter after some changes have been returned.
@ -757,6 +794,8 @@ class ProseSpecTestsMixin(object):
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
dbs: list
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_no_mmap
@ -1045,6 +1084,7 @@ class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin,
class TestAllLegacyScenarios(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener
@classmethod
@client_context.require_connection

View File

@ -28,6 +28,8 @@ import _thread as thread
import threading
import warnings
from typing import no_type_check, Type
sys.path[0:0] = [""]
from bson import encode
@ -99,6 +101,7 @@ from test.utils import (assertRaisesExactly,
class ClientUnitTest(unittest.TestCase):
"""MongoClient tests that don't require a server."""
client: MongoClient
@classmethod
@client_context.require_connection
@ -341,7 +344,7 @@ class ClientUnitTest(unittest.TestCase):
return int(value)
# Ensure codec options are passed in correctly
document_class = SON
document_class: Type[SON] = SON
type_registry = TypeRegistry([MyFloatAsIntEncoder()])
tz_aware = True
uuid_representation_label = 'javaLegacy'
@ -614,7 +617,7 @@ class TestClient(IntegrationTest):
port are not overloaded.
"""
host, port = client_context.host, client_context.port
kwargs = client_context.default_client_options.copy()
kwargs: dict = client_context.default_client_options.copy()
if client_context.auth_enabled:
kwargs['username'] = db_user
kwargs['password'] = db_pwd
@ -1111,6 +1114,7 @@ class TestClient(IntegrationTest):
socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check
def test_tz_aware(self):
self.assertRaises(ValueError, MongoClient, tz_aware='foo')
@ -1140,7 +1144,7 @@ class TestClient(IntegrationTest):
uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port)
if client_context.is_rs:
uri += '/?replicaSet=' + client_context.replica_set_name
uri += '/?replicaSet=' + (client_context.replica_set_name or "")
client = rs_or_single_client_noauth(uri)
client.pymongo_test.test.insert_one({"dummy": "object"})
@ -1379,7 +1383,7 @@ class TestClient(IntegrationTest):
heartbeat_times.append(time.time())
try:
ServerHeartbeatStartedEvent.__init__ = init
ServerHeartbeatStartedEvent.__init__ = init # type: ignore
listener = HeartbeatStartedListener()
uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % (
client_context.host, client_context.port)
@ -1394,7 +1398,7 @@ class TestClient(IntegrationTest):
client.close()
finally:
ServerHeartbeatStartedEvent.__init__ = old_init
ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore
def test_small_heartbeat_frequency_ms(self):
uri = "mongodb://example/?heartbeatFrequencyMS=499"
@ -1847,7 +1851,7 @@ class TestClientLazyConnect(IntegrationTest):
lazy_client_trial(reset, delete_one, test, self._get_client)
def test_find_one(self):
results = []
results: list = []
def reset(collection):
collection.drop()

View File

@ -213,11 +213,11 @@ class TestCMAP(IntegrationTest):
def run_scenario(self, scenario_def, test):
"""Run a CMAP spec test."""
self.logs = []
self.logs: list = []
self.assertEqual(scenario_def['version'], 1)
self.assertIn(scenario_def['style'], ['unit', 'integration'])
self.listener = CMAPListener()
self._ops = []
self._ops: list = []
# Configure the fail point before creating the client.
if 'failPoint' in test:
@ -259,9 +259,9 @@ class TestCMAP(IntegrationTest):
self.pool = list(client._topology._servers.values())[0].pool
# Map of target names to Thread objects.
self.targets = dict()
self.targets: dict = dict()
# Map of label names to Connection objects
self.labels = dict()
self.labels: dict = dict()
def cleanup():
for t in self.targets.values():

View File

@ -17,6 +17,7 @@
"""Tests for the Code wrapper."""
import sys
sys.path[0:0] = [""]
from bson.code import Code
@ -35,7 +36,7 @@ class TestCode(unittest.TestCase):
c = Code("blah")
def set_c():
c.scope = 5
c.scope = 5 # type: ignore
self.assertRaises(AttributeError, set_c)
def test_code(self):

View File

@ -17,6 +17,8 @@
import functools
import warnings
from typing import Any
from pymongo.collation import (
Collation,
CollationCaseFirst, CollationStrength, CollationAlternate,
@ -78,6 +80,10 @@ class TestCollationObject(unittest.TestCase):
class TestCollation(IntegrationTest):
listener: EventListener
warn_context: Any
collation: Collation
@classmethod
@client_context.require_connection
def setUpClass(cls):

View File

@ -20,8 +20,11 @@ import contextlib
import re
import sys
from codecs import utf_8_decode
from codecs import utf_8_decode # type: ignore
from collections import defaultdict
from typing import no_type_check
from pymongo.database import Database
sys.path[0:0] = [""]
@ -66,6 +69,7 @@ from test.utils import (get_pool, is_mongos,
class TestCollectionNoConnect(unittest.TestCase):
"""Test Collection features on a client that does not connect.
"""
db: Database
@classmethod
def setUpClass(cls):
@ -116,11 +120,12 @@ class TestCollectionNoConnect(unittest.TestCase):
class TestCollection(IntegrationTest):
w: int
@classmethod
def setUpClass(cls):
super(TestCollection, cls).setUpClass()
cls.w = client_context.w
cls.w = client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
@ -726,7 +731,7 @@ class TestCollection(IntegrationTest):
db = self.db
db.test.drop()
docs = [{} for _ in range(5)]
docs: list = [{} for _ in range(5)]
result = db.test.insert_many(docs)
self.assertTrue(isinstance(result, InsertManyResult))
self.assertTrue(isinstance(result.inserted_ids, list))
@ -759,7 +764,7 @@ class TestCollection(IntegrationTest):
db = db.client.get_database(db.name,
write_concern=WriteConcern(w=0))
docs = [{} for _ in range(5)]
docs: list = [{} for _ in range(5)]
result = db.test.insert_many(docs)
self.assertTrue(isinstance(result, InsertManyResult))
self.assertFalse(result.acknowledged)
@ -792,11 +797,11 @@ class TestCollection(IntegrationTest):
with self.assertRaisesRegex(
TypeError, "documents must be a non-empty list"):
db.test.insert_many(1)
db.test.insert_many(1) # type: ignore[arg-type]
with self.assertRaisesRegex(
TypeError, "documents must be a non-empty list"):
db.test.insert_many(RawBSONDocument(encode({'_id': 2})))
db.test.insert_many(RawBSONDocument(encode({'_id': 2}))) # type: ignore[arg-type]
def test_delete_one(self):
self.db.test.drop()
@ -1064,7 +1069,7 @@ class TestCollection(IntegrationTest):
db_w0 = self.db.client.get_database(
self.db.name, write_concern=WriteConcern(w=0))
ops = [InsertOne({"a": -10}),
ops: list = [InsertOne({"a": -10}),
InsertOne({"a": -11}),
InsertOne({"a": -12}),
UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}),
@ -1087,7 +1092,7 @@ class TestCollection(IntegrationTest):
def test_find_by_default_dct(self):
db = self.db
db.test.insert_one({'foo': 'bar'})
dct = defaultdict(dict, [('foo', 'bar')])
dct = defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type]
self.assertIsNotNone(db.test.find_one(dct))
self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')]))
@ -1117,6 +1122,7 @@ class TestCollection(IntegrationTest):
doc = next(db.test.find({}, ["mike"]))
self.assertFalse("extra thing" in doc)
@no_type_check
def test_fields_specifier_as_dict(self):
db = self.db
db.test.delete_many({})
@ -1333,7 +1339,7 @@ class TestCollection(IntegrationTest):
self.assertTrue(result.acknowledged)
self.assertEqual(1, db.test.count_documents({"y": 1}))
self.assertEqual(0, db.test.count_documents({"x": 1}))
self.assertEqual(db.test.find_one(id1)["y"], 1)
self.assertEqual(db.test.find_one(id1)["y"], 1) # type: ignore
replacement = RawBSONDocument(encode({"_id": id1, "z": 1}))
result = db.test.replace_one({"y": 1}, replacement, True)
@ -1344,7 +1350,7 @@ class TestCollection(IntegrationTest):
self.assertTrue(result.acknowledged)
self.assertEqual(1, db.test.count_documents({"z": 1}))
self.assertEqual(0, db.test.count_documents({"y": 1}))
self.assertEqual(db.test.find_one(id1)["z"], 1)
self.assertEqual(db.test.find_one(id1)["z"], 1) # type: ignore
result = db.test.replace_one({"x": 2}, {"y": 2}, True)
self.assertTrue(isinstance(result, UpdateResult))
@ -1377,7 +1383,7 @@ class TestCollection(IntegrationTest):
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(db.test.find_one(id1)["x"], 6)
self.assertEqual(db.test.find_one(id1)["x"], 6) # type: ignore
id2 = db.test.insert_one({"x": 1}).inserted_id
result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}})
@ -1386,8 +1392,8 @@ class TestCollection(IntegrationTest):
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(db.test.find_one(id1)["x"], 7)
self.assertEqual(db.test.find_one(id2)["x"], 1)
self.assertEqual(db.test.find_one(id1)["x"], 7) # type: ignore
self.assertEqual(db.test.find_one(id2)["x"], 1) # type: ignore
result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True)
self.assertTrue(isinstance(result, UpdateResult))
@ -1587,12 +1593,12 @@ class TestCollection(IntegrationTest):
# Test that batchSize is handled properly.
cursor = db.test.aggregate([], batchSize=5)
self.assertEqual(5, len(cursor._CommandCursor__data))
self.assertEqual(5, len(cursor._CommandCursor__data)) # type: ignore
# Force a getMore
cursor._CommandCursor__data.clear()
cursor._CommandCursor__data.clear() # type: ignore
next(cursor)
# batchSize - 1
self.assertEqual(4, len(cursor._CommandCursor__data))
self.assertEqual(4, len(cursor._CommandCursor__data)) # type: ignore
# Exhaust the cursor. There shouldn't be any errors.
for doc in cursor:
pass
@ -1679,6 +1685,7 @@ class TestCollection(IntegrationTest):
with self.write_concern_collection() as coll:
coll.rename('foo')
@no_type_check
def test_find_one(self):
db = self.db
db.drop_collection("test")
@ -1973,17 +1980,17 @@ class TestCollection(IntegrationTest):
bad = BadGetAttr([('foo', 'bar')])
c.insert_one({'bad': bad})
self.assertEqual('bar', c.find_one()['bad']['foo'])
self.assertEqual('bar', c.find_one()['bad']['foo']) # type: ignore
def test_array_filters_validation(self):
# array_filters must be a list.
c = self.db.test
with self.assertRaises(TypeError):
c.update_one({}, {'$set': {'a': 1}}, array_filters={})
c.update_one({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type]
with self.assertRaises(TypeError):
c.update_many({}, {'$set': {'a': 1}}, array_filters={})
c.update_many({}, {'$set': {'a': 1}}, array_filters={} ) # type: ignore[arg-type]
with self.assertRaises(TypeError):
c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={})
c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type]
def test_array_filters_unacknowledged(self):
c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0))
@ -2158,7 +2165,7 @@ class TestCollection(IntegrationTest):
c.drop()
c.insert_one({'r': re.compile('.*')})
self.assertTrue(isinstance(c.find_one()['r'], Regex))
self.assertTrue(isinstance(c.find_one()['r'], Regex)) # type: ignore
for doc in c.find():
self.assertTrue(isinstance(doc['r'], Regex))
@ -2189,9 +2196,9 @@ class TestCollection(IntegrationTest):
for helper, args in helpers:
with self.assertRaisesRegex(TypeError,
"let must be an instance of dict"):
helper(*args, let=let)
helper(*args, let=let) # type: ignore
for helper, args in helpers:
helper(*args, let={})
helper(*args, let={}) # type: ignore
if __name__ == "__main__":

View File

@ -43,6 +43,8 @@ def camel_to_snake(camel):
class TestAllScenarios(unittest.TestCase):
listener: EventListener
client: MongoClient
@classmethod
@client_context.require_connection

View File

@ -50,13 +50,13 @@ class TestCommon(IntegrationTest):
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
legacy_opts = coll.codec_options
coll.insert_one({'uu': uu})
self.assertEqual(uu, coll.find_one({'uu': uu})['uu'])
self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) # type: ignore
coll = self.db.get_collection(
"uuid", CodecOptions(uuid_representation=STANDARD))
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
self.assertEqual(None, coll.find_one({'uu': uu}))
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
self.assertEqual(uul, coll.find_one({'uu': uul})['uu'])
self.assertEqual(uul, coll.find_one({'uu': uul})['uu']) # type: ignore
# Test count_documents
self.assertEqual(0, coll.count_documents({'uu': uu}))
@ -81,9 +81,9 @@ class TestCommon(IntegrationTest):
coll.update_one({'_id': uu}, {'$set': {'i': 2}})
coll = self.db.get_collection(
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(1, coll.find_one({'_id': uu})['i'])
self.assertEqual(1, coll.find_one({'_id': uu})['i']) # type: ignore
coll.update_one({'_id': uu}, {'$set': {'i': 2}})
self.assertEqual(2, coll.find_one({'_id': uu})['i'])
self.assertEqual(2, coll.find_one({'_id': uu})['i']) # type: ignore
# Test Cursor.distinct
self.assertEqual([2], coll.find({'_id': uu}).distinct('i'))
@ -98,7 +98,7 @@ class TestCommon(IntegrationTest):
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(2, coll.find_one_and_update({'_id': uu},
{'$set': {'i': 5}})['i'])
self.assertEqual(5, coll.find_one({'_id': uu})['i'])
self.assertEqual(5, coll.find_one({'_id': uu})['i']) # type: ignore
# Test command
self.assertEqual(5, self.db.command(

View File

@ -20,6 +20,7 @@ sys.path[0:0] = [""]
from bson import SON
from pymongo import monitoring
from pymongo.collection import Collection
from pymongo.errors import NotPrimaryError
from pymongo.write_concern import WriteConcern
@ -33,6 +34,9 @@ from test.utils import (CMAPListener,
class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
listener: CMAPListener
coll: Collection
@classmethod
@client_context.require_replica_set
def setUpClass(cls):
@ -111,7 +115,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
# Insert record and verify failure.
with self.assertRaises(NotPrimaryError) as exc:
self.coll.insert_one({"test": 1})
self.assertEqual(exc.exception.details['code'], error_code)
self.assertEqual(exc.exception.details['code'], error_code) # type: ignore
# Retry before CMAPListener assertion if retry_before=True.
if retry:
self.coll.insert_one({"test": 1})

View File

@ -53,7 +53,7 @@ def check_result(self, expected_result, result):
# SPEC-869: Only BulkWriteResult has upserted_count.
if (prop == "upserted_count"
and not isinstance(result, BulkWriteResult)):
if result.upserted_id is not None:
if result.upserted_id is not None: # type: ignore
upserted_count = 1
else:
upserted_count = 0
@ -69,14 +69,14 @@ def check_result(self, expected_result, result):
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, msg)
self.assertEqual(ids, result.inserted_ids, msg) # type: ignore
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, msg)
self.assertEqual(expected_ids, result.upserted_ids, msg) # type: ignore
else:
self.assertEqual(
getattr(result, prop), expected_result[res], msg)

View File

@ -57,7 +57,7 @@ class TestCursor(IntegrationTest):
re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})
cursor2 = copy.deepcopy(cursor)
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) # type: ignore
def test_add_remove_option(self):
cursor = self.db.test.find()
@ -149,9 +149,9 @@ class TestCursor(IntegrationTest):
self.assertRaises(TypeError, coll.find().allow_disk_use, 'baz')
cursor = coll.find().allow_disk_use(True)
self.assertEqual(True, cursor._Cursor__allow_disk_use)
self.assertEqual(True, cursor._Cursor__allow_disk_use) # type: ignore
cursor = coll.find().allow_disk_use(False)
self.assertEqual(False, cursor._Cursor__allow_disk_use)
self.assertEqual(False, cursor._Cursor__allow_disk_use) # type: ignore
def test_max_time_ms(self):
db = self.db
@ -165,15 +165,15 @@ class TestCursor(IntegrationTest):
coll.find().max_time_ms(1)
cursor = coll.find().max_time_ms(999)
self.assertEqual(999, cursor._Cursor__max_time_ms)
self.assertEqual(999, cursor._Cursor__max_time_ms) # type: ignore
cursor = coll.find().max_time_ms(10).max_time_ms(1000)
self.assertEqual(1000, cursor._Cursor__max_time_ms)
self.assertEqual(1000, cursor._Cursor__max_time_ms) # type: ignore
cursor = coll.find().max_time_ms(999)
c2 = cursor.clone()
self.assertEqual(999, c2._Cursor__max_time_ms)
self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())
self.assertEqual(999, c2._Cursor__max_time_ms) # type: ignore
self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec()) # type: ignore
self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec()) # type: ignore
self.assertTrue(coll.find_one(max_time_ms=1000))
@ -889,7 +889,7 @@ class TestCursor(IntegrationTest):
# Every attribute should be the same.
cursor2 = cursor.clone()
self.assertDictEqual(cursor.__dict__, cursor2.__dict__)
self.assertEqual(cursor.__dict__, cursor2.__dict__)
# Shallow copies can so can mutate
cursor2 = copy.copy(cursor)
@ -1025,7 +1025,7 @@ class TestCursor(IntegrationTest):
self.assertEqual(self.db.test, self.db.test.find().collection)
def set_coll():
self.db.test.find().collection = "hello"
self.db.test.find().collection = "hello" # type: ignore
self.assertRaises(AttributeError, set_coll)

View File

@ -21,6 +21,7 @@ import tempfile
from collections import OrderedDict
from decimal import Decimal
from random import random
from typing import Any, Tuple, Type, no_type_check
sys.path[0:0] = [""]
@ -127,6 +128,7 @@ def type_obfuscating_decoder_factory(rt_type):
class CustomBSONTypeTests(object):
@no_type_check
def roundtrip(self, doc):
bsonbytes = encode(doc, codec_options=self.codecopts)
rt_document = decode(bsonbytes, codec_options=self.codecopts)
@ -139,6 +141,7 @@ class CustomBSONTypeTests(object):
self.roundtrip({'average': [[Decimal('56.47')]]})
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
@no_type_check
def test_decode_all(self):
documents = []
for dec in range(3):
@ -151,12 +154,14 @@ class CustomBSONTypeTests(object):
self.assertEqual(
decode_all(bsonstream, self.codecopts), documents)
@no_type_check
def test__bson_to_dict(self):
document = {'average': Decimal('56.47')}
rawbytes = encode(document, codec_options=self.codecopts)
decoded_document = _bson_to_dict(rawbytes, self.codecopts)
self.assertEqual(document, decoded_document)
@no_type_check
def test__dict_to_bson(self):
document = {'average': Decimal('56.47')}
rawbytes = encode(document, codec_options=self.codecopts)
@ -172,12 +177,14 @@ class CustomBSONTypeTests(object):
bsonstream += encode(doc)
return edocs, bsonstream
@no_type_check
def test_decode_iter(self):
expected, bson_data = self._generate_multidocument_bson_stream()
for expected_doc, decoded_doc in zip(
expected, decode_iter(bson_data, self.codecopts)):
self.assertEqual(expected_doc, decoded_doc)
@no_type_check
def test_decode_file_iter(self):
expected, bson_data = self._generate_multidocument_bson_stream()
fileobj = tempfile.TemporaryFile()
@ -293,6 +300,15 @@ class TestBSONTypeEnDeCodecs(unittest.TestCase):
class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
TypeA: Any
TypeB: Any
fallback_encoder_A2B: Any
fallback_encoder_A2BSON: Any
B2BSON: Type[TypeEncoder]
B2A: Type[TypeEncoder]
A2B: Type[TypeEncoder]
@classmethod
def setUpClass(cls):
class TypeA(object):
@ -378,6 +394,10 @@ class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
class TestTypeRegistry(unittest.TestCase):
types: Tuple[object, object]
codecs: Tuple[Type[TypeCodec], Type[TypeCodec]]
fallback_encoder: Any
@classmethod
def setUpClass(cls):
class MyIntType(object):
@ -466,32 +486,32 @@ class TestTypeRegistry(unittest.TestCase):
def transform_bson(self, value):
return self.types[0](value)
codec_instances = [MyIntDecoder(), MyIntEncoder()]
codec_instances: list = [MyIntDecoder(), MyIntEncoder()]
type_registry = TypeRegistry(codec_instances)
self.assertEqual(
type_registry._encoder_map,
{MyIntEncoder.python_type: codec_instances[1].transform_python})
{MyIntEncoder.python_type: codec_instances[1].transform_python}) # type: ignore
self.assertEqual(
type_registry._decoder_map,
{MyIntDecoder.bson_type: codec_instances[0].transform_bson})
{MyIntDecoder.bson_type: codec_instances[0].transform_bson}) # type: ignore
def test_initialize_fail(self):
err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, "
"or TypeCodec, got .* instead")
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry(self.codecs)
TypeRegistry(self.codecs) # type: ignore[arg-type]
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry([type('AnyType', (object,), {})()])
err_msg = "fallback_encoder %r is not a callable" % (True,)
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry([], True)
TypeRegistry([], True) # type: ignore[arg-type]
err_msg = "fallback_encoder %r is not a callable" % ('hello',)
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry(fallback_encoder='hello')
TypeRegistry(fallback_encoder='hello') # type: ignore[arg-type]
def test_type_registry_repr(self):
codec_instances = [codec() for codec in self.codecs]
@ -525,7 +545,7 @@ class TestTypeRegistry(unittest.TestCase):
if pytype in [bool, type(None), RE_TYPE,]:
continue
class MyType(pytype):
class MyType(pytype): # type: ignore
pass
attrs.update({'python_type': MyType,
'transform_python': lambda x: x})
@ -598,7 +618,7 @@ class TestCollectionWCustomType(IntegrationTest):
test = db.get_collection(
'test', codec_options=UNINT_DECODER_CODECOPTS)
pipeline = [
pipeline: list = [
{'$match': {'status': 'complete'}},
{'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},]
result = test.aggregate(pipeline)
@ -680,15 +700,18 @@ class TestGridFileCustomType(IntegrationTest):
class ChangeStreamsWCustomTypesTestMixin(object):
@no_type_check
def change_stream(self, *args, **kwargs):
return self.watched_target.watch(*args, **kwargs)
@no_type_check
def insert_and_check(self, change_stream, insert_doc,
expected_doc):
self.input_target.insert_one(insert_doc)
change = next(change_stream)
self.assertEqual(change['fullDocument'], expected_doc)
@no_type_check
def kill_change_stream_cursor(self, change_stream):
# Cause a cursor not found error on the next getMore.
cursor = change_stream._cursor
@ -696,6 +719,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
client = self.input_target.database.client
client._close_cursor_now(cursor.cursor_id, address)
@no_type_check
def test_simple(self):
codecopts = CodecOptions(type_registry=TypeRegistry([
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
@ -718,6 +742,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
@no_type_check
def test_custom_type_in_pipeline(self):
codecopts = CodecOptions(type_registry=TypeRegistry([
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
@ -741,6 +766,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, input_docs[2], expected_docs[1])
@no_type_check
def test_break_resume_token(self):
# Get one document from a change stream to determine resumeToken type.
self.create_targets()
@ -766,6 +792,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, docs[2], docs[2])
@no_type_check
def test_document_class(self):
def run_test(doc_cls):
codecopts = CodecOptions(type_registry=TypeRegistry([

View File

@ -17,6 +17,7 @@
import datetime
import re
import sys
from typing import Any, List, Mapping
sys.path[0:0] = [""]
@ -57,6 +58,7 @@ from test.test_custom_types import DECIMAL_CODECOPTS
class TestDatabaseNoConnect(unittest.TestCase):
"""Test Database features on a client that does not connect.
"""
client: MongoClient
@classmethod
def setUpClass(cls):
@ -143,7 +145,7 @@ class TestDatabase(IntegrationTest):
test = db.create_collection("test")
self.assertTrue("test" in db.list_collection_names())
test.insert_one({"hello": "world"})
self.assertEqual(db.test.find_one()["hello"], "world")
self.assertEqual(db.test.find_one()["hello"], "world") # type: ignore
db.drop_collection("test.foo")
db.create_collection("test.foo")
@ -198,6 +200,7 @@ class TestDatabase(IntegrationTest):
self.assertNotIn("nameOnly", results["started"][0].command)
# Should send nameOnly (except on 2.6).
filter: Any
for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}):
results.clear()
names = db.list_collection_names(filter=filter)
@ -225,7 +228,7 @@ class TestDatabase(IntegrationTest):
self.assertTrue("$" not in coll)
# Duplicate check.
coll_cnt = {}
coll_cnt: dict = {}
for coll in colls:
try:
# Found duplicate.
@ -233,7 +236,7 @@ class TestDatabase(IntegrationTest):
self.assertTrue(False)
except KeyError:
coll_cnt[coll] = 1
coll_cnt = {}
coll_cnt: dict = {}
# Checking if is there any collection which don't exists.
if (len(set(colls) - set(["test","test.mike"])) == 0 or
@ -466,6 +469,7 @@ class TestDatabase(IntegrationTest):
self.assertEqual(None, db.test.find_one({"hello": "test"}))
b = db.test.find_one()
assert b is not None
b["hello"] = "mike"
db.test.replace_one({"_id": b["_id"]}, b)
@ -482,12 +486,12 @@ class TestDatabase(IntegrationTest):
db = self.client.pymongo_test
db.test.drop()
db.test.insert_one({"x": 9223372036854775807})
retrieved = db.test.find_one()['x']
retrieved = db.test.find_one()['x'] # type: ignore
self.assertEqual(Int64(9223372036854775807), retrieved)
self.assertIsInstance(retrieved, Int64)
db.test.delete_many({})
db.test.insert_one({"x": Int64(1)})
retrieved = db.test.find_one()['x']
retrieved = db.test.find_one()['x'] # type: ignore
self.assertEqual(Int64(1), retrieved)
self.assertIsInstance(retrieved, Int64)
@ -509,8 +513,8 @@ class TestDatabase(IntegrationTest):
length += 1
self.assertEqual(length, 2)
db.test.delete_one(db.test.find_one())
db.test.delete_one(db.test.find_one())
db.test.delete_one(db.test.find_one()) # type: ignore[arg-type]
db.test.delete_one(db.test.find_one()) # type: ignore[arg-type]
self.assertEqual(db.test.find_one(), None)
db.test.insert_one({"x": 1})
@ -625,7 +629,7 @@ class TestDatabase(IntegrationTest):
'read_preference': ReadPreference.PRIMARY,
'write_concern': WriteConcern(w=1),
'read_concern': ReadConcern(level="local")}
db2 = db1.with_options(**newopts)
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
for opt in newopts:
self.assertEqual(
getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
@ -633,7 +637,7 @@ class TestDatabase(IntegrationTest):
class TestDatabaseAggregation(IntegrationTest):
def setUp(self):
self.pipeline = [{"$listLocalSessions": {}},
self.pipeline: List[Mapping[str, Any]] = [{"$listLocalSessions": {}},
{"$limit": 1},
{"$addFields": {"dummy": "dummy field"}},
{"$project": {"_id": 0, "dummy": 1}}]
@ -648,6 +652,7 @@ class TestDatabaseAggregation(IntegrationTest):
@client_context.require_no_mongos
def test_database_aggregation_fake_cursor(self):
coll_name = "test_output"
write_stage: dict
if client_context.version < (4, 3):
db_name = "admin"
write_stage = {"$out": coll_name}

View File

@ -16,6 +16,7 @@
import pickle
import sys
from typing import Any
sys.path[0:0] = [""]
from bson import encode, decode
@ -44,10 +45,10 @@ class TestDBRef(unittest.TestCase):
a = DBRef("coll", ObjectId())
def foo():
a.collection = "blah"
a.collection = "blah" # type: ignore[misc]
def bar():
a.id = "aoeu"
a.id = "aoeu" # type: ignore[misc]
self.assertEqual("coll", a.collection)
a.id
@ -136,6 +137,7 @@ class TestDBRef(unittest.TestCase):
# https://github.com/mongodb/specifications/blob/master/source/dbref.rst#test-plan
class TestDBRefSpec(unittest.TestCase):
def test_decoding_1_2_3(self):
doc: Any
for doc in [
# 1, Valid documents MUST be decoded to a DBRef:
{"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")},
@ -183,6 +185,7 @@ class TestDBRefSpec(unittest.TestCase):
self.assertIsInstance(dbref, dict)
def test_encoding_1_2(self):
doc: Any
for doc in [
# 1, Encoding DBRefs with basic fields:
{"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")},

View File

@ -35,6 +35,7 @@ class TestDecimal128(unittest.TestCase):
b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0')
coll.insert_one({'dec128': dec128})
doc = coll.find_one({'dec128': dec128})
assert doc is not None
self.assertIsNotNone(doc)
self.assertEqual(doc['dec128'], dec128)

View File

@ -364,10 +364,12 @@ class TestIntegration(SpecRunner):
def marked_unknown(e):
return (isinstance(e, monitoring.ServerDescriptionChangedEvent)
and not e.new_description.is_server_type_known)
assert self.server_listener is not None
return len(self.server_listener.matching(marked_unknown))
# Only support CMAP events for now.
self.assertTrue(event.startswith('Pool') or event.startswith('Conn'))
event_type = getattr(monitoring, event)
assert self.pool_listener is not None
return self.pool_listener.event_count(event_type)
def assert_event_count(self, event, count):

View File

@ -25,6 +25,10 @@ import textwrap
import traceback
import uuid
from typing import Any
from pymongo.collection import Collection
sys.path[0:0] = [""]
from bson import encode, json_util
@ -126,6 +130,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
with self.assertRaisesRegex(
TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
AutoEncryptionOpts({}, 'k.d', kms_tls_options={'kmip': 1})
tls_opts: Any
for tls_opts in [
{'kmip': {'tls': True, 'tlsInsecure': True}},
{'kmip': {'tls': True, 'tlsAllowInvalidCertificates': True}},
@ -138,6 +143,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
AutoEncryptionOpts({}, 'k.d', kms_tls_options={
'kmip': {'tlsCAFile': 'does-not-exist'}})
# Success cases:
tls_opts: Any
for tls_opts in [None, {}]:
opts = AutoEncryptionOpts({}, 'k.d', kms_tls_options=tls_opts)
self.assertEqual(opts._kms_ssl_contexts, {})
@ -432,14 +438,14 @@ class TestExplicitSimple(EncryptionIntegrationTest):
msg = 'value to decrypt must be a bson.binary.Binary with subtype 6'
with self.assertRaisesRegex(TypeError, msg):
client_encryption.decrypt('str')
client_encryption.decrypt('str') # type: ignore[arg-type]
with self.assertRaisesRegex(TypeError, msg):
client_encryption.decrypt(Binary(b'123'))
msg = 'key_id must be a bson.binary.Binary with subtype 4'
algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
with self.assertRaisesRegex(TypeError, msg):
client_encryption.encrypt('str', algo, key_id=uuid.uuid4())
client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) # type: ignore[arg-type]
with self.assertRaisesRegex(TypeError, msg):
client_encryption.encrypt('str', algo, key_id=Binary(b'123'))
@ -459,7 +465,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
def test_codec_options(self):
with self.assertRaisesRegex(TypeError, 'codec_options must be'):
ClientEncryption(
KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None)
KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) # type: ignore[arg-type]
opts = CodecOptions(uuid_representation=JAVA_LEGACY)
client_encryption_legacy = ClientEncryption(
@ -708,6 +714,10 @@ def create_key_vault(vault, *data_keys):
class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
client_encrypted: MongoClient
client_encryption: ClientEncryption
listener: OvertCommandListener
vault: Any
KMS_PROVIDERS = ALL_KMS_PROVIDERS
@ -776,7 +786,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
def run_test(self, provider_name):
# Create data key.
master_key = self.MASTER_KEYS[provider_name]
master_key: Any = self.MASTER_KEYS[provider_name]
datakey_id = self.client_encryption.create_data_key(
provider_name, master_key=master_key,
key_alt_names=['%s_altname' % (provider_name,)])
@ -798,7 +808,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
{'_id': provider_name, 'value': encrypted})
doc_decrypted = self.client_encrypted.db.coll.find_one(
{'_id': provider_name})
self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,))
self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,)) # type: ignore
# Encrypt by key_alt_name.
encrypted_altname = self.client_encryption.encrypt(
@ -985,7 +995,7 @@ class TestCorpus(EncryptionIntegrationTest):
self.addCleanup(client_encryption.close)
corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json'))
corpus_copied = SON()
corpus_copied: SON = SON()
for key, value in corpus.items():
corpus_copied[key] = copy.deepcopy(value)
if key in ('_id', 'altname_aws', 'altname_azure', 'altname_gcp',
@ -1021,7 +1031,7 @@ class TestCorpus(EncryptionIntegrationTest):
try:
encrypted_val = client_encryption.encrypt(
value['value'], algo, **kwargs)
value['value'], algo, **kwargs) # type: ignore[arg-type]
if not value['allowed']:
self.fail('encrypt should have failed: %r: %r' % (
key, value))
@ -1082,6 +1092,10 @@ _16_MiB = 16777216
class TestBsonSizeBatches(EncryptionIntegrationTest):
"""Prose tests for BSON size limits and batch splitting."""
coll: Collection
coll_encrypted: Collection
client_encrypted: MongoClient
listener: OvertCommandListener
@classmethod
def setUpClass(cls):
@ -1397,6 +1411,7 @@ class AzureGCPEncryptionTestMixin(object):
KMS_PROVIDER_MAP = None
KEYVAULT_DB = 'keyvault'
KEYVAULT_COLL = 'datakeys'
client: MongoClient
def setUp(self):
keyvault = self.client.get_database(
@ -1406,7 +1421,7 @@ class AzureGCPEncryptionTestMixin(object):
def _test_explicit(self, expectation):
client_encryption = ClientEncryption(
self.KMS_PROVIDER_MAP,
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
'.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
client_context.client,
OPTS)
@ -1426,7 +1441,7 @@ class AzureGCPEncryptionTestMixin(object):
keyvault_namespace = '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
encryption_opts = AutoEncryptionOpts(
self.KMS_PROVIDER_MAP,
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
keyvault_namespace,
schema_map=self.SCHEMA_MAP)
@ -1818,7 +1833,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
def setUp(self):
super(TestKmsTLSOptions, self).setUp()
# 1, create client with only tlsCAFile.
providers = copy.deepcopy(ALL_KMS_PROVIDERS)
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8002'
providers['gcp']['endpoint'] = '127.0.0.1:8002'
kms_tls_opts_ca_only = {
@ -1840,7 +1855,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
kms_tls_options=kms_tls_opts)
self.addCleanup(self.client_encryption_with_tls.close)
# 3, update endpoints to expired host.
providers = copy.deepcopy(providers)
providers: dict = copy.deepcopy(providers)
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8000'
providers['gcp']['endpoint'] = '127.0.0.1:8000'
providers['kmip']['endpoint'] = '127.0.0.1:8000'
@ -1849,7 +1864,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
kms_tls_options=kms_tls_opts_ca_only)
self.addCleanup(self.client_encryption_expired.close)
# 3, update endpoints to invalid host.
providers = copy.deepcopy(providers)
providers: dict = copy.deepcopy(providers)
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8001'
providers['gcp']['endpoint'] = '127.0.0.1:8001'
providers['kmip']['endpoint'] = '127.0.0.1:8001'

View File

@ -890,6 +890,7 @@ class TestTransactionExamples(IntegrationTest):
update_employee_info(session)
employee = employees.find_one({"employee": 3})
assert employee is not None
self.assertIsNotNone(employee)
self.assertEqual(employee['status'], 'Inactive')
@ -916,6 +917,7 @@ class TestTransactionExamples(IntegrationTest):
run_transaction_with_retry(update_employee_info, session)
employee = employees.find_one({"employee": 3})
assert employee is not None
self.assertIsNotNone(employee)
self.assertEqual(employee['status'], 'Inactive')
@ -954,6 +956,7 @@ class TestTransactionExamples(IntegrationTest):
run_transaction_with_retry(_insert_employee_retry_commit, session)
employee = employees.find_one({"employee": 4})
assert employee is not None
self.assertIsNotNone(employee)
self.assertEqual(employee['status'], 'Active')
@ -1021,6 +1024,7 @@ class TestTransactionExamples(IntegrationTest):
# End Transactions Retry Example 3
employee = employees.find_one({"employee": 3})
assert employee is not None
self.assertIsNotNone(employee)
self.assertEqual(employee['status'], 'Inactive')
@ -1089,6 +1093,9 @@ class TestCausalConsistencyExamples(IntegrationTest):
'start': current_date}, session=s1)
# End Causal Consistency Example 1
assert s1.cluster_time is not None
assert s1.operation_time is not None
# Start Causal Consistency Example 2
with client.start_session(causal_consistency=True) as s2:
s2.advance_cluster_time(s1.cluster_time)

View File

@ -24,6 +24,8 @@ import zipfile
from io import BytesIO
from pymongo.database import Database
sys.path[0:0] = [""]
from bson.objectid import ObjectId
@ -47,6 +49,7 @@ from test.utils import rs_or_single_client, EventListener
class TestGridFileNoConnect(unittest.TestCase):
"""Test GridFile features on a client that does not connect.
"""
db: Database
@classmethod
def setUpClass(cls):

View File

@ -27,6 +27,7 @@ from io import BytesIO
sys.path[0:0] = [""]
from bson.binary import Binary
from pymongo.database import Database
from pymongo.mongo_client import MongoClient
from pymongo.errors import (ConfigurationError,
NotPrimaryError,
@ -78,6 +79,7 @@ class JustRead(threading.Thread):
class TestGridfsNoConnect(unittest.TestCase):
db: Database
@classmethod
def setUpClass(cls):
@ -89,6 +91,8 @@ class TestGridfsNoConnect(unittest.TestCase):
class TestGridfs(IntegrationTest):
fs: gridfs.GridFS
alt: gridfs.GridFS
@classmethod
def setUpClass(cls):
@ -152,6 +156,7 @@ class TestGridfs(IntegrationTest):
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
raw = self.db.fs.files.find_one()
assert raw is not None
self.assertEqual(0, raw["length"])
self.assertEqual(oid, raw["_id"])
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
@ -213,7 +218,7 @@ class TestGridfs(IntegrationTest):
self.fs.put(b"hello", _id="test")
threads = []
results = []
results: list = []
for i in range(10):
threads.append(JustRead(self.fs, 10, results))
threads[i].start()
@ -396,6 +401,7 @@ class TestGridfs(IntegrationTest):
# Test fix that guards against PHP-237
self.fs.put(b"", filename="empty")
doc = self.db.fs.files.find_one({"filename": "empty"})
assert doc is not None
doc.pop("length")
self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
f = self.fs.get_last_version(filename="empty")
@ -447,23 +453,32 @@ class TestGridfs(IntegrationTest):
# but will still call __del__.
cursor = GridOutCursor.__new__(GridOutCursor) # Skip calling __init__
with self.assertRaises(TypeError):
cursor.__init__(self.db.fs.files, {}, {"_id": True})
cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore
cursor.__del__() # no error
def test_gridfs_find_one(self):
self.assertEqual(None, self.fs.find_one())
id1 = self.fs.put(b'test1', filename='file1')
self.assertEqual(b'test1', self.fs.find_one().read())
res = self.fs.find_one()
assert res is not None
self.assertEqual(b'test1', res.read())
id2 = self.fs.put(b'test2', filename='file2', meta='data')
self.assertEqual(b'test1', self.fs.find_one(id1).read())
self.assertEqual(b'test2', self.fs.find_one(id2).read())
res1 = self.fs.find_one(id1)
assert res1 is not None
self.assertEqual(b'test1', res1.read())
res2 = self.fs.find_one(id2)
assert res2 is not None
self.assertEqual(b'test2', res2.read())
self.assertEqual(b'test1',
self.fs.find_one({'filename': 'file1'}).read())
res3 = self.fs.find_one({'filename': 'file1'})
assert res3 is not None
self.assertEqual(b'test1', res3.read())
self.assertEqual('data', self.fs.find_one(id2).meta)
res4 = self.fs.find_one(id2)
assert res4 is not None
self.assertEqual('data', res4.meta)
def test_grid_in_non_int_chunksize(self):
# Lua, and perhaps other buggy GridFS clients, store size as a float.

View File

@ -77,6 +77,8 @@ class JustRead(threading.Thread):
class TestGridfs(IntegrationTest):
fs: gridfs.GridFSBucket
alt: gridfs.GridFSBucket
@classmethod
def setUpClass(cls):
@ -123,6 +125,7 @@ class TestGridfs(IntegrationTest):
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
raw = self.db.fs.files.find_one()
assert raw is not None
self.assertEqual(0, raw["length"])
self.assertEqual(oid, raw["_id"])
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
@ -208,7 +211,7 @@ class TestGridfs(IntegrationTest):
self.fs.upload_from_stream("test", b"hello")
threads = []
results = []
results: list = []
for i in range(10):
threads.append(JustRead(self.fs, 10, results))
threads[i].start()
@ -322,6 +325,7 @@ class TestGridfs(IntegrationTest):
# Test fix that guards against PHP-237
self.fs.upload_from_stream("empty", b"")
doc = self.db.fs.files.find_one({"filename": "empty"})
assert doc is not None
doc.pop("length")
self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
fstr = self.fs.open_download_stream_by_name("empty")

View File

@ -55,6 +55,9 @@ def camel_to_snake(camel):
class TestAllScenarios(IntegrationTest):
fs: gridfs.GridFSBucket
str_to_cmd: dict
@classmethod
def setUpClass(cls):
super(TestAllScenarios, cls).setUpClass()

View File

@ -20,6 +20,8 @@ import re
import sys
import uuid
from typing import Any, List, MutableMapping
sys.path[0:0] = [""]
from bson import json_util, EPOCH_AWARE, EPOCH_NAIVE, SON
@ -466,7 +468,7 @@ class TestJsonUtilRoundtrip(IntegrationTest):
db = self.db
db.drop_collection("test")
docs = [
docs: List[MutableMapping[str, Any]] = [
{'foo': [1, 2]},
{'bar': {'hello': 'world'}},
{'code': Code("function x() { return 1; }")},

View File

@ -35,7 +35,7 @@ _TEST_PATH = os.path.join(
'max_staleness')
class TestAllScenarios(create_selection_tests(_TEST_PATH)):
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
pass

View File

@ -59,7 +59,7 @@ class TestMonitor(IntegrationTest):
# Each executor stores a weakref to itself in _EXECUTORS.
executor_refs = [
(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] # type: ignore
del executors
del client

View File

@ -16,6 +16,7 @@ import copy
import datetime
import sys
import time
from typing import Any
import warnings
sys.path[0:0] = [""]
@ -43,6 +44,7 @@ from test.utils import (EventListener,
class TestCommandMonitoring(IntegrationTest):
listener: EventListener
@classmethod
@client_context.require_connection
@ -754,7 +756,7 @@ class TestCommandMonitoring(IntegrationTest):
# delete_one
self.listener.results.clear()
res = coll.delete_one({'x': 3})
res2 = coll.delete_one({'x': 3})
results = self.listener.results
started = results['started'][0]
succeeded = results['succeeded'][0]
@ -1091,6 +1093,8 @@ class TestCommandMonitoring(IntegrationTest):
class TestGlobalListener(IntegrationTest):
listener: EventListener
saved_listeners: Any
@classmethod
@client_context.require_connection
@ -1167,13 +1171,13 @@ class TestEventClasses(unittest.TestCase):
"<ServerHeartbeatStartedEvent ('localhost', 27017)>")
delta = 0.1
event = monitoring.ServerHeartbeatSucceededEvent(
delta, {'ok': 1}, connection_id)
delta, {'ok': 1}, connection_id) # type: ignore[arg-type]
self.assertEqual(
repr(event),
"<ServerHeartbeatSucceededEvent ('localhost', 27017) "
"duration: 0.1, awaited: False, reply: {'ok': 1}>")
event = monitoring.ServerHeartbeatFailedEvent(
delta, 'ERROR', connection_id)
delta, 'ERROR', connection_id) # type: ignore[arg-type]
self.assertEqual(
repr(event),
"<ServerHeartbeatFailedEvent ('localhost', 27017) "
@ -1188,7 +1192,7 @@ class TestEventClasses(unittest.TestCase):
"<ServerOpeningEvent ('localhost', 27017) "
"topology_id: 000000000000000000000001>")
event = monitoring.ServerDescriptionChangedEvent(
'PREV', 'NEW', server_address, topology_id)
'PREV', 'NEW', server_address, topology_id) # type: ignore[arg-type]
self.assertEqual(
repr(event),
"<ServerDescriptionChangedEvent ('localhost', 27017) "
@ -1206,7 +1210,7 @@ class TestEventClasses(unittest.TestCase):
repr(event),
"<TopologyOpenedEvent topology_id: 000000000000000000000001>")
event = monitoring.TopologyDescriptionChangedEvent(
'PREV', 'NEW', topology_id)
'PREV', 'NEW', topology_id) # type: ignore[arg-type]
self.assertEqual(
repr(event),
"<TopologyDescriptionChangedEvent "

View File

@ -106,7 +106,9 @@ class TestObjectId(unittest.TestCase):
aware = datetime.datetime(1993, 4, 4, 2,
tzinfo=FixedOffset(555, "SomeZone"))
as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc)
offset = aware.utcoffset()
assert offset is not None
as_utc = (aware - offset).replace(tzinfo=utc)
oid = ObjectId.from_datetime(aware)
self.assertEqual(as_utc, oid.generation_time)

View File

@ -21,6 +21,8 @@ import random
import sys
from time import sleep
from typing import Any
sys.path[0:0] = [""]
from pymongo.ocsp_cache import _OCSPCache
@ -28,14 +30,18 @@ from test import unittest
class TestOcspCache(unittest.TestCase):
MockHashAlgorithm: Any
MockOcspRequest: Any
MockOcspResponse: Any
@classmethod
def setUpClass(cls):
cls.MockHashAlgorithm = namedtuple(
cls.MockHashAlgorithm = namedtuple( # type: ignore
"MockHashAlgorithm", ['name'])
cls.MockOcspRequest = namedtuple(
cls.MockOcspRequest = namedtuple( # type: ignore
"MockOcspRequest", ['hash_algorithm', 'issuer_name_hash',
'issuer_key_hash', 'serial_number'])
cls.MockOcspResponse = namedtuple(
cls.MockOcspResponse = namedtuple( # type: ignore
"MockOcspResponse", ["this_update", "next_update"])
def setUp(self):

View File

@ -82,6 +82,7 @@ class TestRawBSONDocument(IntegrationTest):
codec_options=CodecOptions(document_class=RawBSONDocument))
db.test_raw.insert_one(self.document)
result = db.test_raw.find_one(self.document['_id'])
assert result is not None
self.assertIsInstance(result, RawBSONDocument)
self.assertEqual(dict(self.document.items()), dict(result.items()))
@ -146,6 +147,7 @@ class TestRawBSONDocument(IntegrationTest):
db = self.client.pymongo_test
db.test_raw.insert_one(doc)
result = db.test_raw.find_one()
assert result is not None
self.assertEqual(decode(self.document.raw), result['embedded'])
# Make sure that CodecOptions are preserved.
@ -169,6 +171,7 @@ class TestRawBSONDocument(IntegrationTest):
db.test_raw.insert_one(rbd)
result = db.get_collection('test_raw', codec_options=CodecOptions(
uuid_representation=JAVA_LEGACY)).find_one()
assert result is not None
self.assertEqual(rbd['embedded'][0]['_id'],
result['embedded'][0]['_id'])

View File

@ -23,6 +23,7 @@ from test.utils import single_client, rs_or_single_client, OvertCommandListener
class TestReadConcern(IntegrationTest):
listener: OvertCommandListener
@classmethod
@client_context.require_connection

View File

@ -225,7 +225,7 @@ class TestReadPreferences(TestReadPreferencesBase):
localthresholdms=-1)
def test_zero_latency(self):
ping_times = set()
ping_times: set = set()
# Generate unique ping times.
while len(ping_times) < len(self.client.nodes):
ping_times.add(random.random())
@ -278,7 +278,7 @@ class TestReadPreferences(TestReadPreferencesBase):
# far, and keep reading until we've used all the members or give up.
# Chance of using only 2 of 3 members 10k times if there's no bug =
# 3 * (2/3)**10000, very low.
used = set()
used: set = set()
i = 0
while data_members.difference(used) and i < 10000:
address = self.read_from_which_host(c)
@ -335,6 +335,8 @@ _PREF_MAP = [
class TestCommandAndReadPreference(IntegrationTest):
c: ReadPrefTester
client_version: Version
@classmethod
@client_context.require_secondaries_count(1)
@ -378,6 +380,7 @@ class TestCommandAndReadPreference(IntegrationTest):
# Success
break
assert self.c.primary is not None
unused = self.c.secondaries.union(
set([self.c.primary])
).difference(used)
@ -445,11 +448,11 @@ class TestMovingAverage(unittest.TestCase):
avg = MovingAverage()
self.assertIsNone(avg.get())
avg.add_sample(10)
self.assertAlmostEqual(10, avg.get())
self.assertAlmostEqual(10, avg.get()) # type: ignore
avg.add_sample(20)
self.assertAlmostEqual(12, avg.get())
self.assertAlmostEqual(12, avg.get()) # type: ignore
avg.add_sample(30)
self.assertAlmostEqual(15.6, avg.get())
self.assertAlmostEqual(15.6, avg.get()) # type: ignore
class TestMongosAndReadPreference(IntegrationTest):
@ -526,7 +529,8 @@ class TestMongosAndReadPreference(IntegrationTest):
'maxStalenessSeconds': 30})
with self.assertRaises(TypeError):
Nearest(max_staleness=1.5) # Float is prohibited.
# Float is prohibited.
Nearest(max_staleness=1.5) # type: ignore
with self.assertRaises(ValueError):
Nearest(max_staleness=0)
@ -543,7 +547,7 @@ class TestMongosAndReadPreference(IntegrationTest):
}
for mode, cls in cases.items():
with self.assertRaises(TypeError):
cls(hedge=[])
cls(hedge=[]) # type: ignore
pref = cls(hedge={})
self.assertEqual(pref.document, {'mode': mode})
@ -668,7 +672,9 @@ class TestMongosAndReadPreference(IntegrationTest):
@client_context.require_mongos
def test_mongos(self):
shard = client_context.client.config.shards.find_one()['host']
res = client_context.client.config.shards.find_one()
assert res is not None
shard = res['host']
num_members = shard.count(',') + 1
if num_members == 1:
raise SkipTest("Need a replica set shard to test.")

View File

@ -149,12 +149,14 @@ class TestReadWriteConcernSpec(IntegrationTest):
f()
if expected == BulkWriteError:
bulk_result = cm.exception.details
assert bulk_result is not None
wc_errors = bulk_result['writeConcernErrors']
self.assertTrue(wc_errors)
@client_context.require_replica_set
def test_raise_write_concern_error(self):
self.addCleanup(client_context.client.drop_database, 'pymongo_test')
assert client_context.w is not None
self.assertWriteOpsRaise(
WriteConcern(w=client_context.w+1, wtimeout=1), WriteConcernError)
@ -219,6 +221,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
db.test.insert_one({'x': 1})
self.assertEqual(ctx.exception.code, 121)
self.assertIsNotNone(ctx.exception.details)
assert ctx.exception.details is not None
self.assertIsNotNone(ctx.exception.details.get('errInfo'))
for event in listener.results['succeeded']:
if event.command_name == 'insert':
@ -290,19 +293,19 @@ def create_document_test(test_case):
WriteConcern,
**normalized)
else:
concern = WriteConcern(**normalized)
write_concern = WriteConcern(**normalized)
self.assertEqual(
concern.document, test_case['writeConcernDocument'])
write_concern.document, test_case['writeConcernDocument'])
self.assertEqual(
concern.acknowledged, test_case['isAcknowledged'])
write_concern.acknowledged, test_case['isAcknowledged'])
self.assertEqual(
concern.is_server_default, test_case['isServerDefault'])
write_concern.is_server_default, test_case['isServerDefault'])
if 'readConcern' in test_case:
# Any string for 'level' is equaly valid
concern = ReadConcern(**test_case['readConcern'])
self.assertEqual(concern.document, test_case['readConcernDocument'])
read_concern = ReadConcern(**test_case['readConcern'])
self.assertEqual(read_concern.document, test_case['readConcernDocument'])
self.assertEqual(
not bool(concern.level), test_case['isServerDefault'])
not bool(read_concern.level), test_case['isServerDefault'])
return run_test

View File

@ -135,6 +135,7 @@ def non_retryable_single_statement_ops(coll):
class IgnoreDeprecationsTest(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter
@classmethod
def setUpClass(cls):
@ -148,6 +149,7 @@ class IgnoreDeprecationsTest(IntegrationTest):
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
def setUpClass(cls):
@ -180,6 +182,8 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener
knobs: client_knobs
@classmethod
@client_context.require_no_mmap
@ -426,6 +430,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
class TestWriteConcernError(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
fail_insert: dict
@classmethod
@client_context.require_replica_set

View File

@ -24,6 +24,7 @@ sys.path[0:0] = [""]
from pymongo import MongoClient
from bson.json_util import object_hook
from pymongo import monitoring
from pymongo.collection import Collection
from pymongo.common import clean_node
from pymongo.errors import (ConnectionFailure,
NotPrimaryError)
@ -253,6 +254,10 @@ create_tests()
class TestSdamMonitoring(IntegrationTest):
knobs: client_knobs
listener: ServerAndTopologyEventListener
test_client: MongoClient
coll: Collection
@classmethod
@client_context.require_failCommand_fail_point

View File

@ -52,7 +52,7 @@ class SelectionStoreSelector(object):
class TestAllScenarios(create_selection_tests(_TEST_PATH)):
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
pass
@ -125,7 +125,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
def test_latency_threshold_application(self):
selector = SelectionStoreSelector()
scenario_def = {
scenario_def: dict = {
'topology_description': {
'type': 'ReplicaSetWithPrimary', 'servers': [
{'address': 'b:27017',
@ -160,6 +160,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
# Invoke server selection and assert no filtering based on latency
# prior to custom server selection logic kicking in.
server = topology.select_server(ReadPreference.NEAREST)
assert selector.selection is not None
self.assertEqual(
len(selector.selection),
len(topology.description.server_descriptions()))

View File

@ -116,7 +116,7 @@ class TestProse(IntegrationTest):
self.assertEqual(len(events), N_FINDS * N_THREADS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0 for address in nodes}
freqs = {address: 0.0 for address in nodes}
for event in events:
freqs[event.connection_id] += 1
for address in freqs:

View File

@ -20,6 +20,9 @@ import sys
import time
from io import BytesIO
from typing import Set
from pymongo.mongo_client import MongoClient
sys.path[0:0] = [""]
@ -64,6 +67,8 @@ def session_ids(client):
class TestSession(IntegrationTest):
client2: MongoClient
sensitive_commands: Set[str]
@classmethod
@client_context.require_sessions
@ -231,7 +236,7 @@ class TestSession(IntegrationTest):
def test_client(self):
client = self.client
ops = [
ops: list = [
(client.server_info, [], {}),
(client.list_database_names, [], {}),
(client.drop_database, ['pymongo_test'], {}),
@ -242,7 +247,7 @@ class TestSession(IntegrationTest):
def test_database(self):
client = self.client
db = client.pymongo_test
ops = [
ops: list = [
(db.command, ['ping'], {}),
(db.create_collection, ['collection'], {}),
(db.list_collection_names, [], {}),
@ -493,6 +498,7 @@ class TestSession(IntegrationTest):
# Explicit session.
with client.start_session() as s:
cursor = bucket.find(session=s)
assert cursor.session is not None
s = cursor.session
files = list(cursor)
cursor.__del__()
@ -680,7 +686,7 @@ class TestSession(IntegrationTest):
self.addCleanup(client.close)
db = client.pymongo_test
coll = db.test_unacked_writes
ops = [
ops: list = [
(client.drop_database, [db.name], {}),
(db.create_collection, ['collection'], {}),
(db.drop_collection, ['collection'], {}),
@ -722,6 +728,8 @@ class TestSession(IntegrationTest):
self.assertRaises(TypeError, lambda: copy.copy(s))
class TestCausalConsistency(unittest.TestCase):
listener: SessionTestListener
client: MongoClient
@classmethod
def setUpClass(cls):
@ -778,6 +786,8 @@ class TestCausalConsistency(unittest.TestCase):
self.assertRaises(ValueError, sess2.advance_cluster_time, {})
self.assertRaises(TypeError, sess2.advance_operation_time, 1)
# No error
assert sess.cluster_time is not None
assert sess.operation_time is not None
sess2.advance_cluster_time(sess.cluster_time)
sess2.advance_operation_time(sess.operation_time)
self.assertEqual(sess.cluster_time, sess2.cluster_time)

View File

@ -17,6 +17,7 @@
import sys
from time import sleep
from typing import Any
sys.path[0:0] = [""]
@ -54,6 +55,7 @@ class SrvPollingKnobs(object):
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
def mock_get_hosts_and_min_ttl(resolver, *args):
assert self.old_dns_resolver_response is not None
nodes, ttl = self.old_dns_resolver_response(resolver)
if self.nodelist_callback is not None:
nodes = self.nodelist_callback()
@ -61,20 +63,22 @@ class SrvPollingKnobs(object):
ttl = self.ttl_time
return nodes, ttl
patch_func: Any
if self.count_resolver_calls:
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
self.old_dns_resolver_response
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
self.old_dns_resolver_response # type: ignore
)
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
@ -128,7 +132,7 @@ class TestSrvPolling(unittest.TestCase):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
1, "resolver was never called")
return True
@ -196,7 +200,7 @@ class TestSrvPolling(unittest.TestCase):
self.run_scenario(response_callback, False)
def test_dns_record_lookup_empty(self):
response = []
response: list = []
self.run_scenario(response, False)
def _test_recover_from_initial(self, initial_callback):

View File

@ -17,6 +17,7 @@
import os
import socket
import sys
from typing import Any
sys.path[0:0] = [""]
@ -52,7 +53,7 @@ try:
from pymongo.ocsp_support import _load_trusted_ca_certs
_HAVE_PYOPENSSL = True
except ImportError:
_load_trusted_ca_certs = None
_load_trusted_ca_certs = None # type: ignore
if HAVE_SSL:
@ -148,6 +149,7 @@ class TestClientSSL(unittest.TestCase):
class TestSSL(IntegrationTest):
saved_port: int
def assertClientWorks(self, client):
coll = client.pymongo_test.ssl_test.with_options(
@ -200,13 +202,13 @@ class TestSSL(IntegrationTest):
tlsCertificateKeyFilePassword="qwerty",
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=5000,
**self.credentials))
**self.credentials)) # type: ignore
uri_fmt = ("mongodb://localhost/?ssl=true"
"&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty"
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000")
connected(MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM),
**self.credentials))
**self.credentials)) # type: ignore
@client_context.require_tlsCertificateKeyFile
@client_context.require_no_auth
@ -313,7 +315,7 @@ class TestSSL(IntegrationTest):
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=500,
**self.credentials))
**self.credentials)) # type: ignore
connected(MongoClient('server',
ssl=True,
@ -322,7 +324,7 @@ class TestSSL(IntegrationTest):
tlsCAFile=CA_PEM,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=500,
**self.credentials))
**self.credentials)) # type: ignore
if 'setName' in response:
with self.assertRaises(ConnectionFailure):
@ -333,7 +335,7 @@ class TestSSL(IntegrationTest):
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=500,
**self.credentials))
**self.credentials)) # type: ignore
connected(MongoClient('server',
replicaSet=response['setName'],
@ -343,7 +345,7 @@ class TestSSL(IntegrationTest):
tlsCAFile=CA_PEM,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=500,
**self.credentials))
**self.credentials)) # type: ignore
@client_context.require_tlsCertificateKeyFile
@ignore_deprecations
@ -362,7 +364,7 @@ class TestSSL(IntegrationTest):
ssl=True,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=100,
**self.credentials))
**self.credentials)) # type: ignore
with self.assertRaises(ConnectionFailure):
connected(MongoClient('localhost',
@ -370,18 +372,18 @@ class TestSSL(IntegrationTest):
tlsCAFile=CA_PEM,
tlsCRLFile=CRL_PEM,
serverSelectionTimeoutMS=100,
**self.credentials))
**self.credentials)) # type: ignore
uri_fmt = ("mongodb://localhost/?ssl=true&"
"tlsCAFile=%s&serverSelectionTimeoutMS=100")
connected(MongoClient(uri_fmt % (CA_PEM,),
**self.credentials))
**self.credentials)) # type: ignore
uri_fmt = ("mongodb://localhost/?ssl=true&tlsCRLFile=%s"
"&tlsCAFile=%s&serverSelectionTimeoutMS=100")
with self.assertRaises(ConnectionFailure):
connected(MongoClient(uri_fmt % (CRL_PEM, CA_PEM),
**self.credentials))
**self.credentials)) # type: ignore
@client_context.require_tlsCertificateKeyFile
@client_context.require_server_resolvable
@ -399,26 +401,26 @@ class TestSSL(IntegrationTest):
connected(MongoClient('server',
ssl=True,
serverSelectionTimeoutMS=100,
**self.credentials))
**self.credentials)) # type: ignore
# Server cert is verified. Disable hostname matching.
connected(MongoClient('server',
ssl=True,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=100,
**self.credentials))
**self.credentials)) # type: ignore
# Server cert and hostname are verified.
connected(MongoClient('localhost',
ssl=True,
serverSelectionTimeoutMS=100,
**self.credentials))
**self.credentials)) # type: ignore
# Server cert and hostname are verified.
connected(
MongoClient(
'mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=100',
**self.credentials))
**self.credentials)) # type: ignore
def test_system_certs_config_error(self):
ctx = get_ssl_context(None, None, None, None, True, True, False)
@ -428,6 +430,7 @@ class TestSSL(IntegrationTest):
raise SkipTest(
"Can't test when system CA certificates are loadable.")
ssl_support: Any
have_certifi = ssl_support.HAVE_CERTIFI
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
# Force the test regardless of environment.
@ -446,6 +449,7 @@ class TestSSL(IntegrationTest):
# with SSLContext and SSLContext provides no information
# about ca_certs.
raise SkipTest("Can't test when SSLContext available.")
ssl_support: Any
if not ssl_support.HAVE_CERTIFI:
raise SkipTest("Need certifi to test certifi support.")

View File

@ -87,7 +87,7 @@ class TestStreamingProtocol(IntegrationTest):
# 1-15 millisecond resolution. We need to delay the initial hello
# to ensure that RTT is never zero.
name = 'streamingRttTest'
delay_hello = {
delay_hello: dict = {
'configureFailPoint': 'failCommand',
'mode': {'times': 1000},
'data': {

View File

@ -90,19 +90,19 @@ class TestTransactions(TransactionsBase):
read_preference=ReadPreference.PRIMARY,
max_commit_time_ms=10000)
with self.assertRaisesRegex(TypeError, "read_concern must be "):
TransactionOptions(read_concern={})
TransactionOptions(read_concern={}) # type: ignore
with self.assertRaisesRegex(TypeError, "write_concern must be "):
TransactionOptions(write_concern={})
TransactionOptions(write_concern={}) # type: ignore
with self.assertRaisesRegex(
ConfigurationError,
"transactions do not support unacknowledged write concern"):
TransactionOptions(write_concern=WriteConcern(w=0))
with self.assertRaisesRegex(
TypeError, "is not valid for read_preference"):
TransactionOptions(read_preference={})
TransactionOptions(read_preference={}) # type: ignore
with self.assertRaisesRegex(
TypeError, "max_commit_time_ms must be an integer or None"):
TransactionOptions(max_commit_time_ms="10000")
TransactionOptions(max_commit_time_ms="10000") # type: ignore
@client_context.require_transactions
def test_transaction_write_concern_override(self):
@ -131,7 +131,7 @@ class TestTransactions(TransactionsBase):
coll.find_one_and_replace({}, {}, session=s)
coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s)
unsupported_txn_writes = [
unsupported_txn_writes: list = [
(client.drop_database, [db.name], {}),
(db.drop_collection, ['collection'], {}),
(coll.drop, [], {}),
@ -284,7 +284,7 @@ class TestTransactions(TransactionsBase):
InvalidOperation,
'GridFS does not support multi-document transactions',
):
op(*args, session=s)
op(*args, session=s) # type: ignore
# Require 4.2+ for large (16MB+) transactions.
@client_context.require_version_min(4, 2)
@ -360,11 +360,11 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.db.test.insert_one({})
def callback(session):
def callback2(session):
self.db.test.insert_one({}, session=session)
return 'Foo'
with self.client.start_session() as s:
self.assertEqual(s.with_transaction(callback), 'Foo')
self.assertEqual(s.with_transaction(callback2), 'Foo')
@client_context.require_transactions
def test_callback_not_retried_after_timeout(self):
@ -375,7 +375,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
def callback(session):
coll.insert_one({}, session=session)
err = {
err: dict = {
'ok': 0,
'errmsg': 'Transaction 7819 has been aborted.',
'code': 251,

View File

@ -186,7 +186,7 @@ class TestURI(unittest.TestCase):
self.assertRaises(ValueError,
parse_uri, "mongodb://::1", 27017)
orig = {
orig: dict = {
'nodelist': [("localhost", 27017)],
'username': None,
'password': None,
@ -196,7 +196,7 @@ class TestURI(unittest.TestCase):
'fqdn': None
}
res = copy.deepcopy(orig)
res: dict = copy.deepcopy(orig)
self.assertEqual(res, parse_uri("mongodb://localhost"))
res.update({'username': 'fred', 'password': 'foobar'})

View File

@ -66,7 +66,7 @@ class TestWriteConcern(unittest.TestCase):
self.assertNotEqual(WriteConcern(wtimeout=42), _FakeWriteConcern(wtimeout=2000))
def test_equality_incompatible_type(self):
_fake_type = collections.namedtuple('NotAWriteConcern', ['document'])
_fake_type = collections.namedtuple('NotAWriteConcern', ['document']) # type: ignore
self.assertNotEqual(WriteConcern(j=True), _fake_type({'j': True}))

View File

@ -27,6 +27,7 @@ import time
import types
from collections import abc
from typing import Any
from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey
from bson.binary import Binary
@ -296,7 +297,7 @@ class EntityMapUtil(object):
entity_type, spec = next(iter(entity_spec.items()))
if entity_type == 'client':
kwargs = {}
kwargs: dict = {}
observe_events = spec.get('observeEvents', [])
ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
observe_sensitive_commands = spec.get(
@ -691,6 +692,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
SCHEMA_VERSION = Version.from_string('1.5')
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
TEST_SPEC: Any
@staticmethod
def should_run_on(run_on_spec):
@ -1213,6 +1215,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
class UnifiedSpecTestMeta(type):
"""Metaclass for generating test classes."""
TEST_SPEC: Any
EXPECTED_FAILURES: Any
def __init__(cls, *args, **kwargs):
super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs)
@ -1258,7 +1263,7 @@ def generate_test_classes(test_path, module=__name__, class_name_prefix='',
"""Utility that creates the base class to use for test generation.
This is needed to ensure that cls.TEST_SPEC is appropriately set when
the metaclass __init__ is invoked."""
class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)):
class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore
TEST_SPEC = test_spec
EXPECTED_FAILURES = expected_failures
return SpecTestBase

View File

@ -226,7 +226,7 @@ class ServerEventListener(_ServerEventListener,
"""Listens to Server events."""
class ServerAndTopologyEventListener(ServerEventListener,
class ServerAndTopologyEventListener(ServerEventListener, # type: ignore
monitoring.TopologyListener):
"""Listens to Server and Topology events."""
@ -519,7 +519,7 @@ def _mongo_client(host, port, authenticate=True, directConnection=None,
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options = client_context.default_client_options.copy()
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options['replicaSet'] = client_context.replica_set_name
if directConnection is not None:
@ -678,7 +678,7 @@ def server_started_with_auth(client):
try:
command_line = get_command_line(client)
except OperationFailure as e:
msg = e.details.get('errmsg', '')
msg = e.details.get('errmsg', '') # type: ignore
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
# Unauthorized.
return True
@ -818,8 +818,8 @@ class DeprecationFilter(object):
def stop(self):
"""Stop filtering deprecations."""
self.warn_context.__exit__()
self.warn_context = None
self.warn_context.__exit__() # type: ignore
self.warn_context = None # type: ignore
def get_pool(client):
@ -862,23 +862,13 @@ def run_threads(collection, target):
@contextlib.contextmanager
def frequent_thread_switches():
"""Make concurrency bugs more likely to manifest."""
interval = None
if not sys.platform.startswith('java'):
if hasattr(sys, 'getswitchinterval'):
interval = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
else:
interval = sys.getcheckinterval()
sys.setcheckinterval(1)
interval = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
try:
yield
finally:
if not sys.platform.startswith('java'):
if hasattr(sys, 'setswitchinterval'):
sys.setswitchinterval(interval)
else:
sys.setcheckinterval(interval)
sys.setswitchinterval(interval)
def lazy_client_trial(reset, target, test, get_client):
@ -994,6 +984,7 @@ def assertion_context(msg):
except AssertionError as exc:
msg = '%s (%s)' % (exc, msg)
exc_type, exc_val, exc_tb = sys.exc_info()
assert exc_type is not None
raise exc_type(exc_val).with_traceback(exc_tb)

View File

@ -18,6 +18,7 @@ import functools
import threading
from collections import abc
from typing import List
from bson import decode, encode
from bson.binary import Binary
@ -40,7 +41,7 @@ from pymongo.write_concern import WriteConcern
from test import (client_context,
client_knobs,
IntegrationTest)
from test.utils import (camel_to_snake,
from test.utils import (EventListener, camel_to_snake,
camel_to_snake_args,
CompareType,
CMAPListener,
@ -86,6 +87,9 @@ class SpecRunnerThread(threading.Thread):
class SpecRunner(IntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
@classmethod
def setUpClass(cls):
@ -105,7 +109,7 @@ class SpecRunner(IntegrationTest):
def setUp(self):
super(SpecRunner, self).setUp()
self.targets = {}
self.listener = None
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
@ -219,6 +223,7 @@ class SpecRunner(IntegrationTest):
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.