PYTHON-3064 Add typings to test package (#844)
This commit is contained in:
parent
561ee7cf77
commit
f4cef37328
3
.github/workflows/test-python.yml
vendored
3
.github/workflows/test-python.yml
vendored
@ -46,4 +46,5 @@ jobs:
|
||||
pip install -e ".[zstd, srv]"
|
||||
- name: Run mypy
|
||||
run: |
|
||||
mypy --install-types --non-interactive bson gridfs tools
|
||||
mypy --install-types --non-interactive bson gridfs tools pymongo
|
||||
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test
|
||||
|
||||
@ -28,7 +28,7 @@ from typing import (Any, Dict, Iterable, Iterator, List, Mapping,
|
||||
# This is essentially the same as re._pattern_type
|
||||
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))
|
||||
|
||||
_Key = TypeVar("_Key", bound=str)
|
||||
_Key = TypeVar("_Key")
|
||||
_Value = TypeVar("_Value")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
10
mypy.ini
10
mypy.ini
@ -11,6 +11,9 @@ warn_unused_configs = true
|
||||
warn_unused_ignores = true
|
||||
warn_redundant_casts = true
|
||||
|
||||
[mypy-gevent.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-kerberos.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@ -29,5 +32,12 @@ ignore_missing_imports = True
|
||||
[mypy-snappy.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-test.*]
|
||||
allow_redefinition = true
|
||||
allow_untyped_globals = true
|
||||
|
||||
[mypy-winkerberos.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-xmlrunner.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@ -16,9 +16,8 @@
|
||||
|
||||
import errno
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
|
||||
# https://bugs.jython.org/issue2900
|
||||
@ -43,7 +42,7 @@ class SocketChecker(object):
|
||||
else:
|
||||
self._poller = None
|
||||
|
||||
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
|
||||
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
|
||||
"""Select for reads or writes with a timeout in seconds (or None).
|
||||
|
||||
Returns True if the socket is readable/writable, False on timeout.
|
||||
|
||||
@ -39,7 +39,7 @@ def maybe_decode(text):
|
||||
def _resolve(*args, **kwargs):
|
||||
if hasattr(resolver, 'resolve'):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs) # type: ignore
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
|
||||
"""Type aliases used by PyMongo"""
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
|
||||
Tuple, Type, TypeVar, Union)
|
||||
Sequence, Tuple, Type, TypeVar, Union)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
@ -25,5 +25,5 @@ if TYPE_CHECKING:
|
||||
_Address = Tuple[str, Optional[int]]
|
||||
_CollationIn = Union[Mapping[str, Any], "Collation"]
|
||||
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
|
||||
_Pipeline = List[Mapping[str, Any]]
|
||||
_Pipeline = Sequence[Mapping[str, Any]]
|
||||
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])
|
||||
|
||||
@ -40,6 +40,7 @@ except ImportError:
|
||||
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Dict, no_type_check
|
||||
from unittest import SkipTest
|
||||
|
||||
import pymongo
|
||||
@ -48,7 +49,9 @@ import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo import common, message
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.database import Database
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl
|
||||
from pymongo.uri_parser import parse_uri
|
||||
@ -86,7 +89,7 @@ CLIENT_PEM = os.environ.get('CLIENT_PEM',
|
||||
os.path.join(CERT_PATH, 'client.pem'))
|
||||
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
|
||||
|
||||
TLS_OPTIONS = dict(tls=True)
|
||||
TLS_OPTIONS: Dict = dict(tls=True)
|
||||
if CLIENT_PEM:
|
||||
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
|
||||
if CA_PEM:
|
||||
@ -102,13 +105,13 @@ if TEST_LOADBALANCER:
|
||||
# Remove after PYTHON-2712
|
||||
from pymongo import pool
|
||||
pool._MOCK_SERVICE_ID = True
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI)
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||
host, port = res['nodelist'][0]
|
||||
db_user = res['username'] or db_user
|
||||
db_pwd = res['password'] or db_pwd
|
||||
elif TEST_SERVERLESS:
|
||||
TEST_LOADBALANCER = True
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI)
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||
host, port = res['nodelist'][0]
|
||||
db_user = res['username'] or db_user
|
||||
db_pwd = res['password'] or db_pwd
|
||||
@ -184,6 +187,7 @@ class client_knobs(object):
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
@no_type_check
|
||||
def disable(self):
|
||||
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
|
||||
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
|
||||
@ -224,6 +228,8 @@ def _all_users(db):
|
||||
|
||||
|
||||
class ClientContext(object):
|
||||
client: MongoClient
|
||||
|
||||
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
|
||||
|
||||
def __init__(self):
|
||||
@ -247,9 +253,9 @@ class ClientContext(object):
|
||||
self.tls = False
|
||||
self.tlsCertificateKeyFile = False
|
||||
self.server_is_resolvable = is_server_resolvable()
|
||||
self.default_client_options = {}
|
||||
self.default_client_options: Dict = {}
|
||||
self.sessions_enabled = False
|
||||
self.client = None
|
||||
self.client = None # type: ignore
|
||||
self.conn_lock = threading.Lock()
|
||||
self.is_data_lake = False
|
||||
self.load_balancer = TEST_LOADBALANCER
|
||||
@ -340,6 +346,7 @@ class ClientContext(object):
|
||||
try:
|
||||
self.cmd_line = self.client.admin.command('getCmdLineOpts')
|
||||
except pymongo.errors.OperationFailure as e:
|
||||
assert e.details is not None
|
||||
msg = e.details.get('errmsg', '')
|
||||
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
|
||||
# Unauthorized.
|
||||
@ -418,6 +425,7 @@ class ClientContext(object):
|
||||
else:
|
||||
self.server_parameters = self.client.admin.command(
|
||||
'getParameter', '*')
|
||||
assert self.cmd_line is not None
|
||||
if 'enableTestCommands=1' in self.cmd_line['argv']:
|
||||
self.test_commands_enabled = True
|
||||
elif 'parsed' in self.cmd_line:
|
||||
@ -436,7 +444,8 @@ class ClientContext(object):
|
||||
self.mongoses.append(address)
|
||||
if not self.serverless:
|
||||
# Check for another mongos on the next port.
|
||||
next_address = address[0], address[1] + 1
|
||||
assert address is not None
|
||||
next_address = address[0], address[1] + 1
|
||||
mongos_client = self._connect(
|
||||
*next_address, **self.default_client_options)
|
||||
if mongos_client:
|
||||
@ -496,6 +505,7 @@ class ClientContext(object):
|
||||
try:
|
||||
return db_user in _all_users(client.admin)
|
||||
except pymongo.errors.OperationFailure as e:
|
||||
assert e.details is not None
|
||||
msg = e.details.get('errmsg', '')
|
||||
if e.code == 18 or 'auth fails' in msg:
|
||||
# Auth failed.
|
||||
@ -505,6 +515,7 @@ class ClientContext(object):
|
||||
|
||||
def _server_started_with_auth(self):
|
||||
# MongoDB >= 2.0
|
||||
assert self.cmd_line is not None
|
||||
if 'parsed' in self.cmd_line:
|
||||
parsed = self.cmd_line['parsed']
|
||||
# MongoDB >= 2.6
|
||||
@ -525,6 +536,7 @@ class ClientContext(object):
|
||||
if not socket.has_ipv6:
|
||||
return False
|
||||
|
||||
assert self.cmd_line is not None
|
||||
if 'parsed' in self.cmd_line:
|
||||
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
|
||||
return False
|
||||
@ -932,6 +944,9 @@ class PyMongoTestCase(unittest.TestCase):
|
||||
|
||||
class IntegrationTest(PyMongoTestCase):
|
||||
"""Base class for TestCases that need a connection to MongoDB to pass."""
|
||||
client: MongoClient
|
||||
db: Database
|
||||
credentials: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
@ -1073,7 +1088,7 @@ class PymongoTestRunner(unittest.TextTestRunner):
|
||||
|
||||
|
||||
if HAVE_XML:
|
||||
class PymongoXMLTestRunner(XMLTestRunner):
|
||||
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
|
||||
def run(self, test):
|
||||
setup()
|
||||
result = super(PymongoXMLTestRunner, self).run(test)
|
||||
|
||||
@ -26,6 +26,7 @@ from pymongo.uri_parser import parse_uri
|
||||
|
||||
|
||||
class TestAuthAWS(unittest.TestCase):
|
||||
uri: str
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -21,6 +21,9 @@ import unittest
|
||||
|
||||
|
||||
class TestCursorNamespace(unittest.TestCase):
|
||||
server: MockupDB
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
|
||||
@ -69,6 +72,9 @@ class TestCursorNamespace(unittest.TestCase):
|
||||
|
||||
|
||||
class TestKillCursorsNamespace(unittest.TestCase):
|
||||
server: MockupDB
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
|
||||
|
||||
@ -27,7 +27,7 @@ class TestGetmoreSharded(unittest.TestCase):
|
||||
servers = [MockupDB(), MockupDB()]
|
||||
|
||||
# Collect queries to either server in one queue.
|
||||
q = Queue()
|
||||
q: Queue = Queue()
|
||||
for server in servers:
|
||||
server.subscribe(q.put)
|
||||
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',
|
||||
|
||||
@ -48,7 +48,7 @@ def test_hello_with_option(self, protocol, **kwargs):
|
||||
ServerApiVersion.V1))}
|
||||
client = MongoClient("mongodb://"+primary.address_string,
|
||||
appname='my app', # For _check_handshake_data()
|
||||
**dict([k_map.get((k, v), (k, v)) for k, v
|
||||
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
|
||||
in kwargs.items()]))
|
||||
|
||||
self.addCleanup(client.close)
|
||||
@ -58,7 +58,7 @@ def test_hello_with_option(self, protocol, **kwargs):
|
||||
|
||||
# We do this checking here rather than in the autoresponder `respond()`
|
||||
# because it runs in another Python thread so there are some funky things
|
||||
# with error handling within that thread, and we want to be able to use
|
||||
# with error handling within that thread, and we want to be able to use
|
||||
# self.assertRaises().
|
||||
self.handshake_req.assert_matches(protocol(hello, **kwargs))
|
||||
_check_handshake_data(self.handshake_req)
|
||||
|
||||
@ -30,7 +30,7 @@ class TestMixedVersionSharded(unittest.TestCase):
|
||||
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()
|
||||
|
||||
# Collect queries to either server in one queue.
|
||||
self.q = Queue()
|
||||
self.q: Queue = Queue()
|
||||
for server in self.mongos_old, self.mongos_new:
|
||||
server.subscribe(self.q.put)
|
||||
server.autoresponds('getlasterror')
|
||||
@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
|
||||
def test(self):
|
||||
self.setup_server(upgrade)
|
||||
start = time.time()
|
||||
servers_used = set()
|
||||
servers_used: set = set()
|
||||
while len(servers_used) < 2:
|
||||
go(upgrade.function, self.client)
|
||||
request = self.q.get(timeout=1)
|
||||
|
||||
@ -233,6 +233,8 @@ operations_312 = [
|
||||
|
||||
|
||||
class TestOpMsg(unittest.TestCase):
|
||||
server: MockupDB
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
from mockupdb import MockupDB, going, CommandBase
|
||||
from pymongo import MongoClient, ReadPreference
|
||||
@ -27,6 +28,8 @@ from operations import operations
|
||||
|
||||
class OpMsgReadPrefBase(unittest.TestCase):
|
||||
single_mongod = False
|
||||
primary: MockupDB
|
||||
secondary: MockupDB
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -142,7 +145,7 @@ def create_op_msg_read_mode_test(mode, operation):
|
||||
tag_sets=None)
|
||||
|
||||
client = self.setup_client(read_preference=pref)
|
||||
|
||||
expected_pref: Any
|
||||
if operation.op_type == 'always-use-secondary':
|
||||
expected_server = self.secondary
|
||||
expected_pref = ReadPreference.SECONDARY
|
||||
|
||||
@ -27,7 +27,7 @@ from operations import operations
|
||||
class TestResetAndRequestCheck(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
|
||||
self.ismaster_time = 0
|
||||
self.ismaster_time = 0.0
|
||||
self.client = None
|
||||
self.server = None
|
||||
|
||||
@ -45,7 +45,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
kwargs = {'socketTimeoutMS': 100}
|
||||
# Disable retryable reads when pymongo supports it.
|
||||
kwargs['retryReads'] = False
|
||||
self.client = MongoClient(self.server.uri, **kwargs)
|
||||
self.client = MongoClient(self.server.uri, **kwargs) # type: ignore
|
||||
wait_until(lambda: self.client.nodes, 'connect to standalone')
|
||||
|
||||
def tearDown(self):
|
||||
@ -56,6 +56,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
# Application operation fails. Test that client resets server
|
||||
# description and does *not* schedule immediate check.
|
||||
self.setup_server()
|
||||
assert self.server is not None
|
||||
assert self.client is not None
|
||||
|
||||
# Network error on application operation.
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
@ -81,6 +83,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
# Application operation times out. Test that client does *not* reset
|
||||
# server description and does *not* schedule immediate check.
|
||||
self.setup_server()
|
||||
assert self.server is not None
|
||||
assert self.client is not None
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
with going(operation.function, self.client):
|
||||
@ -91,6 +95,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
# Server is *not* Unknown.
|
||||
topology = self.client._topology
|
||||
server = topology.select_server_by_address(self.server.address, 0)
|
||||
assert server is not None
|
||||
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
|
||||
|
||||
after = self.ismaster_time
|
||||
@ -99,6 +104,8 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
def _test_not_master(self, operation):
|
||||
# Application operation gets a "not master" error.
|
||||
self.setup_server()
|
||||
assert self.server is not None
|
||||
assert self.client is not None
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
with going(operation.function, self.client):
|
||||
@ -110,6 +117,7 @@ class TestResetAndRequestCheck(unittest.TestCase):
|
||||
# Server is rediscovered.
|
||||
topology = self.client._topology
|
||||
server = topology.select_server_by_address(self.server.address, 0)
|
||||
assert server is not None
|
||||
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
|
||||
|
||||
after = self.ismaster_time
|
||||
|
||||
@ -37,7 +37,7 @@ class TestSlaveOkaySharded(unittest.TestCase):
|
||||
self.mongos1, self.mongos2 = MockupDB(), MockupDB()
|
||||
|
||||
# Collect queries to either server in one queue.
|
||||
self.q = Queue()
|
||||
self.q: Queue = Queue()
|
||||
for server in self.mongos1, self.mongos2:
|
||||
server.subscribe(self.q.put)
|
||||
server.run()
|
||||
|
||||
@ -67,6 +67,10 @@ class Timer(object):
|
||||
|
||||
|
||||
class PerformanceTest(object):
|
||||
dataset: Any
|
||||
data_size: Any
|
||||
do_task: Any
|
||||
fail: Any
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -386,6 +390,7 @@ def mp_map(map_func, files):
|
||||
|
||||
|
||||
def insert_json_file(filename):
|
||||
assert proc_client is not None
|
||||
with open(filename, 'r') as data:
|
||||
coll = proc_client.perftest.corpus
|
||||
coll.insert_many([json.loads(line) for line in data])
|
||||
@ -398,11 +403,13 @@ def insert_json_file_with_file_id(filename):
|
||||
doc = json.loads(line)
|
||||
doc['file'] = filename
|
||||
documents.append(doc)
|
||||
assert proc_client is not None
|
||||
coll = proc_client.perftest.corpus
|
||||
coll.insert_many(documents)
|
||||
|
||||
|
||||
def read_json_file(filename):
|
||||
assert proc_client is not None
|
||||
coll = proc_client.perftest.corpus
|
||||
temp = tempfile.TemporaryFile(mode='w')
|
||||
try:
|
||||
@ -414,6 +421,7 @@ def read_json_file(filename):
|
||||
|
||||
|
||||
def insert_gridfs_file(filename):
|
||||
assert proc_client is not None
|
||||
bucket = GridFSBucket(proc_client.perftest)
|
||||
|
||||
with open(filename, 'rb') as gfile:
|
||||
@ -421,6 +429,7 @@ def insert_gridfs_file(filename):
|
||||
|
||||
|
||||
def read_gridfs_file(filename):
|
||||
assert proc_client is not None
|
||||
bucket = GridFSBucket(proc_client.perftest)
|
||||
|
||||
temp = tempfile.TemporaryFile()
|
||||
|
||||
@ -76,6 +76,8 @@ class AutoAuthenticateThread(threading.Thread):
|
||||
|
||||
|
||||
class TestGSSAPI(unittest.TestCase):
|
||||
mech_properties: str
|
||||
service_realm_required: bool
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -116,6 +118,7 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
@ignore_deprecations
|
||||
def test_gssapi_simple(self):
|
||||
assert GSSAPI_PRINCIPAL is not None
|
||||
if GSSAPI_PASS is not None:
|
||||
uri = ('mongodb://%s:%s@%s:%d/?authMechanism='
|
||||
'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL),
|
||||
@ -264,6 +267,8 @@ class TestSASLPlain(unittest.TestCase):
|
||||
authMechanism='PLAIN')
|
||||
client.ldap.test.find_one()
|
||||
|
||||
assert SASL_USER is not None
|
||||
assert SASL_PASS is not None
|
||||
uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;'
|
||||
'authSource=%s' % (quote_plus(SASL_USER),
|
||||
quote_plus(SASL_PASS),
|
||||
@ -540,7 +545,6 @@ class TestSCRAM(IntegrationTest):
|
||||
self.assertIsInstance(iterations, int)
|
||||
|
||||
def test_scram_threaded(self):
|
||||
|
||||
coll = client_context.client.db.test
|
||||
coll.drop()
|
||||
coll.insert_one({'_id': 1})
|
||||
|
||||
@ -41,6 +41,8 @@ from test.utils import ignore_deprecations
|
||||
|
||||
|
||||
class TestBinary(unittest.TestCase):
|
||||
csharp_data: bytes
|
||||
java_data: bytes
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -354,6 +356,8 @@ class TestBinary(unittest.TestCase):
|
||||
|
||||
|
||||
class TestUuidSpecExplicitCoding(unittest.TestCase):
|
||||
uuid: uuid.UUID
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestUuidSpecExplicitCoding, cls).setUpClass()
|
||||
@ -457,6 +461,8 @@ class TestUuidSpecExplicitCoding(unittest.TestCase):
|
||||
|
||||
|
||||
class TestUuidSpecImplicitCoding(IntegrationTest):
|
||||
uuid: uuid.UUID
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestUuidSpecImplicitCoding, cls).setUpClass()
|
||||
|
||||
@ -186,7 +186,7 @@ class TestBSON(unittest.TestCase):
|
||||
decoder=lambda *args: BSON(args[0]).decode(*args[1:]))
|
||||
|
||||
def test_encoding_defaultdict(self):
|
||||
dct = collections.defaultdict(dict, [('foo', 'bar')])
|
||||
dct = collections.defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type]
|
||||
encode(dct)
|
||||
self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
|
||||
|
||||
@ -302,7 +302,7 @@ class TestBSON(unittest.TestCase):
|
||||
|
||||
def test_decode_all_buffer_protocol(self):
|
||||
docs = [{'foo': 'bar'}, {}]
|
||||
bs = b"".join(map(encode, docs))
|
||||
bs = b"".join(map(encode, docs)) # type: ignore[arg-type]
|
||||
self.assertEqual(docs, decode_all(bytearray(bs)))
|
||||
self.assertEqual(docs, decode_all(memoryview(bs)))
|
||||
self.assertEqual(docs, decode_all(memoryview(b'1' + bs + b'1')[1:-1]))
|
||||
@ -530,7 +530,9 @@ class TestBSON(unittest.TestCase):
|
||||
def test_aware_datetime(self):
|
||||
aware = datetime.datetime(1993, 4, 4, 2,
|
||||
tzinfo=FixedOffset(555, "SomeZone"))
|
||||
as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc)
|
||||
offset = aware.utcoffset()
|
||||
assert offset is not None
|
||||
as_utc = (aware - offset).replace(tzinfo=utc)
|
||||
self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc),
|
||||
as_utc)
|
||||
after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[
|
||||
@ -591,7 +593,9 @@ class TestBSON(unittest.TestCase):
|
||||
def test_naive_decode(self):
|
||||
aware = datetime.datetime(1993, 4, 4, 2,
|
||||
tzinfo=FixedOffset(555, "SomeZone"))
|
||||
naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None)
|
||||
offset = aware.utcoffset()
|
||||
assert offset is not None
|
||||
naive_utc = (aware - offset).replace(tzinfo=None)
|
||||
self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc)
|
||||
after = decode(encode({"date": aware}))["date"]
|
||||
self.assertEqual(None, after.tzinfo)
|
||||
@ -603,9 +607,9 @@ class TestBSON(unittest.TestCase):
|
||||
|
||||
@unittest.skip('Disabled due to http://bugs.python.org/issue25222')
|
||||
def test_bad_encode(self):
|
||||
evil_list = {'a': []}
|
||||
evil_list: dict = {'a': []}
|
||||
evil_list['a'].append(evil_list)
|
||||
evil_dict = {}
|
||||
evil_dict: dict = {}
|
||||
evil_dict['a'] = evil_dict
|
||||
for evil_data in [evil_dict, evil_list]:
|
||||
self.assertRaises(Exception, encode, evil_data)
|
||||
@ -1039,8 +1043,8 @@ class TestCodecOptions(unittest.TestCase):
|
||||
|
||||
def test_regex_pickling(self):
|
||||
reg = Regex(".?")
|
||||
pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n'
|
||||
b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}'
|
||||
pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n'
|
||||
b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}'
|
||||
b'\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag'
|
||||
b's\x94K\x00ub.')
|
||||
self.round_trip_pickle(reg, pickled_with_3)
|
||||
@ -1083,8 +1087,8 @@ class TestCodecOptions(unittest.TestCase):
|
||||
|
||||
def test_maxkey_pickling(self):
|
||||
maxk = MaxKey()
|
||||
pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c'
|
||||
b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)'
|
||||
pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c'
|
||||
b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)'
|
||||
b'\x81\x94.')
|
||||
|
||||
self.round_trip_pickle(maxk, pickled_with_3)
|
||||
|
||||
@ -16,13 +16,15 @@
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from bson.binary import UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import Binary
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.errors import (BulkWriteError,
|
||||
ConfigurationError,
|
||||
@ -40,6 +42,8 @@ from test.utils import (remove_all_users,
|
||||
|
||||
|
||||
class BulkTestBase(IntegrationTest):
|
||||
coll: Collection
|
||||
coll_w0: Collection
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -280,6 +284,7 @@ class TestBulk(BulkTestBase):
|
||||
upsert=True)])
|
||||
self.assertEqualResponse(expected, result.bulk_api_result)
|
||||
self.assertEqual(1, result.upserted_count)
|
||||
assert result.upserted_ids is not None
|
||||
self.assertEqual(1, len(result.upserted_ids))
|
||||
self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId))
|
||||
|
||||
@ -341,11 +346,11 @@ class TestBulk(BulkTestBase):
|
||||
# The requests argument must be a list.
|
||||
generator = (InsertOne({}) for _ in range(10))
|
||||
with self.assertRaises(TypeError):
|
||||
self.coll.bulk_write(generator)
|
||||
self.coll.bulk_write(generator) # type: ignore[arg-type]
|
||||
|
||||
# Document is not wrapped in a bulk write operation.
|
||||
with self.assertRaises(TypeError):
|
||||
self.coll.bulk_write([{}])
|
||||
self.coll.bulk_write([{}]) # type: ignore[list-item]
|
||||
|
||||
def test_upsert_large(self):
|
||||
big = 'a' * (client_context.max_bson_size - 37)
|
||||
@ -425,7 +430,7 @@ class TestBulk(BulkTestBase):
|
||||
def test_upsert_uuid_standard_subdocuments(self):
|
||||
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
coll = self.coll.with_options(codec_options=options)
|
||||
ids = [
|
||||
ids: list = [
|
||||
{'f': Binary(bytes(i)), 'f2': uuid.uuid4()}
|
||||
for i in range(3)
|
||||
]
|
||||
@ -472,7 +477,7 @@ class TestBulk(BulkTestBase):
|
||||
def test_single_error_ordered_batch(self):
|
||||
self.coll.create_index('a', unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [('a', 1)])
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'b': 1, 'a': 1}),
|
||||
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
|
||||
InsertOne({'b': 3, 'a': 2}),
|
||||
@ -506,7 +511,7 @@ class TestBulk(BulkTestBase):
|
||||
def test_multiple_error_ordered_batch(self):
|
||||
self.coll.create_index('a', unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [('a', 1)])
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'b': 1, 'a': 1}),
|
||||
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
|
||||
UpdateOne({'b': 3}, {'$set': {'a': 2}}, upsert=True),
|
||||
@ -542,7 +547,7 @@ class TestBulk(BulkTestBase):
|
||||
result)
|
||||
|
||||
def test_single_unordered_batch(self):
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'a': 1}),
|
||||
UpdateOne({'a': 1}, {'$set': {'b': 1}}),
|
||||
UpdateOne({'a': 2}, {'$set': {'b': 2}}, upsert=True),
|
||||
@ -564,7 +569,7 @@ class TestBulk(BulkTestBase):
|
||||
def test_single_error_unordered_batch(self):
|
||||
self.coll.create_index('a', unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [('a', 1)])
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'b': 1, 'a': 1}),
|
||||
UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True),
|
||||
InsertOne({'b': 3, 'a': 2}),
|
||||
@ -599,7 +604,7 @@ class TestBulk(BulkTestBase):
|
||||
def test_multiple_error_unordered_batch(self):
|
||||
self.coll.create_index('a', unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [('a', 1)])
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'b': 1, 'a': 1}),
|
||||
UpdateOne({'b': 2}, {'$set': {'a': 3}}, upsert=True),
|
||||
UpdateOne({'b': 3}, {'$set': {'a': 4}}, upsert=True),
|
||||
@ -662,7 +667,7 @@ class TestBulk(BulkTestBase):
|
||||
self.coll.delete_many({})
|
||||
|
||||
big = 'x' * (1024 * 1024 * 4)
|
||||
result = self.coll.bulk_write([
|
||||
write_result = self.coll.bulk_write([
|
||||
InsertOne({'a': 1, 'big': big}),
|
||||
InsertOne({'a': 2, 'big': big}),
|
||||
InsertOne({'a': 3, 'big': big}),
|
||||
@ -671,7 +676,7 @@ class TestBulk(BulkTestBase):
|
||||
InsertOne({'a': 6, 'big': big}),
|
||||
])
|
||||
|
||||
self.assertEqual(6, result.inserted_count)
|
||||
self.assertEqual(6, write_result.inserted_count)
|
||||
self.assertEqual(6, self.coll.count_documents({}))
|
||||
|
||||
def test_large_inserts_unordered(self):
|
||||
@ -685,12 +690,12 @@ class TestBulk(BulkTestBase):
|
||||
try:
|
||||
self.coll.bulk_write(requests, ordered=False)
|
||||
except BulkWriteError as exc:
|
||||
result = exc.details
|
||||
details = exc.details
|
||||
self.assertEqual(exc.code, 65)
|
||||
else:
|
||||
self.fail("Error not raised")
|
||||
|
||||
self.assertEqual(2, result['nInserted'])
|
||||
self.assertEqual(2, details['nInserted'])
|
||||
|
||||
self.coll.delete_many({})
|
||||
|
||||
@ -741,7 +746,7 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
self.coll.delete_many({})
|
||||
|
||||
def test_no_results_ordered_success(self):
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'a': 1}),
|
||||
UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True),
|
||||
InsertOne({'a': 2}),
|
||||
@ -755,7 +760,7 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
'removed {"_id": 1}')
|
||||
|
||||
def test_no_results_ordered_failure(self):
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'_id': 1}),
|
||||
UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True),
|
||||
InsertOne({'_id': 2}),
|
||||
@ -771,7 +776,7 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
self.assertEqual({'_id': 1}, self.coll.find_one({'_id': 1}))
|
||||
|
||||
def test_no_results_unordered_success(self):
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'a': 1}),
|
||||
UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True),
|
||||
InsertOne({'a': 2}),
|
||||
@ -785,7 +790,7 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
'removed {"_id": 1}')
|
||||
|
||||
def test_no_results_unordered_failure(self):
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'_id': 1}),
|
||||
UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True),
|
||||
InsertOne({'_id': 2}),
|
||||
@ -832,13 +837,15 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
|
||||
|
||||
class TestBulkWriteConcern(BulkTestBase):
|
||||
w: Optional[int]
|
||||
secondary: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestBulkWriteConcern, cls).setUpClass()
|
||||
cls.w = client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w > 1:
|
||||
if cls.w is not None and cls.w > 1:
|
||||
for member in client_context.hello['hosts']:
|
||||
if member != client_context.hello['primary']:
|
||||
cls.secondary = single_client(*partition_node(member))
|
||||
@ -886,7 +893,7 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
try:
|
||||
self.cause_wtimeout(requests, ordered=True)
|
||||
except BulkWriteError as exc:
|
||||
result = exc.details
|
||||
details = exc.details
|
||||
self.assertEqual(exc.code, 65)
|
||||
else:
|
||||
self.fail("Error not raised")
|
||||
@ -899,13 +906,13 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
'nRemoved': 0,
|
||||
'upserted': [],
|
||||
'writeErrors': []},
|
||||
result)
|
||||
details)
|
||||
|
||||
# When talking to legacy servers there will be a
|
||||
# write concern error for each operation.
|
||||
self.assertTrue(len(result['writeConcernErrors']) > 0)
|
||||
self.assertTrue(len(details['writeConcernErrors']) > 0)
|
||||
|
||||
failed = result['writeConcernErrors'][0]
|
||||
failed = details['writeConcernErrors'][0]
|
||||
self.assertEqual(64, failed['code'])
|
||||
self.assertTrue(isinstance(failed['errmsg'], str))
|
||||
|
||||
@ -924,7 +931,7 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
try:
|
||||
self.cause_wtimeout(requests, ordered=True)
|
||||
except BulkWriteError as exc:
|
||||
result = exc.details
|
||||
details = exc.details
|
||||
self.assertEqual(exc.code, 65)
|
||||
else:
|
||||
self.fail("Error not raised")
|
||||
@ -941,10 +948,10 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
'code': 11000,
|
||||
'errmsg': '...',
|
||||
'op': {'_id': '...', 'a': 1}}]},
|
||||
result)
|
||||
details)
|
||||
|
||||
self.assertTrue(len(result['writeConcernErrors']) > 1)
|
||||
failed = result['writeErrors'][0]
|
||||
self.assertTrue(len(details['writeConcernErrors']) > 1)
|
||||
failed = details['writeErrors'][0]
|
||||
self.assertTrue("duplicate" in failed['errmsg'])
|
||||
|
||||
@client_context.require_replica_set
|
||||
@ -966,17 +973,17 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
try:
|
||||
self.cause_wtimeout(requests, ordered=False)
|
||||
except BulkWriteError as exc:
|
||||
result = exc.details
|
||||
details = exc.details
|
||||
self.assertEqual(exc.code, 65)
|
||||
else:
|
||||
self.fail("Error not raised")
|
||||
|
||||
self.assertEqual(2, result['nInserted'])
|
||||
self.assertEqual(1, result['nUpserted'])
|
||||
self.assertEqual(0, len(result['writeErrors']))
|
||||
self.assertEqual(2, details['nInserted'])
|
||||
self.assertEqual(1, details['nUpserted'])
|
||||
self.assertEqual(0, len(details['writeErrors']))
|
||||
# When talking to legacy servers there will be a
|
||||
# write concern error for each operation.
|
||||
self.assertTrue(len(result['writeConcernErrors']) > 1)
|
||||
self.assertTrue(len(details['writeConcernErrors']) > 1)
|
||||
|
||||
self.coll.delete_many({})
|
||||
self.coll.create_index('a', unique=True)
|
||||
@ -984,7 +991,7 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
|
||||
# Fail due to write concern support as well
|
||||
# as duplicate key error on unordered batch.
|
||||
requests = [
|
||||
requests: list = [
|
||||
InsertOne({'a': 1}),
|
||||
UpdateOne({'a': 3}, {'$set': {'a': 3, 'b': 1}}, upsert=True),
|
||||
InsertOne({'a': 1}),
|
||||
@ -993,29 +1000,29 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
try:
|
||||
self.cause_wtimeout(requests, ordered=False)
|
||||
except BulkWriteError as exc:
|
||||
result = exc.details
|
||||
details = exc.details
|
||||
self.assertEqual(exc.code, 65)
|
||||
else:
|
||||
self.fail("Error not raised")
|
||||
|
||||
self.assertEqual(2, result['nInserted'])
|
||||
self.assertEqual(1, result['nUpserted'])
|
||||
self.assertEqual(1, len(result['writeErrors']))
|
||||
self.assertEqual(2, details['nInserted'])
|
||||
self.assertEqual(1, details['nUpserted'])
|
||||
self.assertEqual(1, len(details['writeErrors']))
|
||||
# When talking to legacy servers there will be a
|
||||
# write concern error for each operation.
|
||||
self.assertTrue(len(result['writeConcernErrors']) > 1)
|
||||
self.assertTrue(len(details['writeConcernErrors']) > 1)
|
||||
|
||||
failed = result['writeErrors'][0]
|
||||
failed = details['writeErrors'][0]
|
||||
self.assertEqual(2, failed['index'])
|
||||
self.assertEqual(11000, failed['code'])
|
||||
self.assertTrue(isinstance(failed['errmsg'], str))
|
||||
self.assertEqual(1, failed['op']['a'])
|
||||
|
||||
failed = result['writeConcernErrors'][0]
|
||||
failed = details['writeConcernErrors'][0]
|
||||
self.assertEqual(64, failed['code'])
|
||||
self.assertTrue(isinstance(failed['errmsg'], str))
|
||||
|
||||
upserts = result['upserted']
|
||||
upserts = details['upserted']
|
||||
self.assertEqual(1, len(upserts))
|
||||
self.assertEqual(1, upserts[0]['index'])
|
||||
self.assertTrue(upserts[0].get('_id'))
|
||||
|
||||
@ -24,6 +24,7 @@ import time
|
||||
import uuid
|
||||
|
||||
from itertools import product
|
||||
from typing import no_type_check
|
||||
|
||||
sys.path[0:0] = ['']
|
||||
|
||||
@ -121,6 +122,7 @@ class TestChangeStreamBase(IntegrationTest):
|
||||
|
||||
|
||||
class APITestsMixin(object):
|
||||
@no_type_check
|
||||
def test_watch(self):
|
||||
with self.change_stream(
|
||||
[{'$project': {'foo': 0}}], full_document='updateLookup',
|
||||
@ -145,6 +147,7 @@ class APITestsMixin(object):
|
||||
with self.change_stream(resume_after=resume_token):
|
||||
pass
|
||||
|
||||
@no_type_check
|
||||
def test_try_next(self):
|
||||
# ChangeStreams only read majority committed data so use w:majority.
|
||||
coll = self.watched_collection().with_options(
|
||||
@ -161,6 +164,7 @@ class APITestsMixin(object):
|
||||
wait_until(lambda: stream.try_next() is not None,
|
||||
"get change from try_next")
|
||||
|
||||
@no_type_check
|
||||
def test_try_next_runs_one_getmore(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
@ -216,6 +220,7 @@ class APITestsMixin(object):
|
||||
set(["getMore"]))
|
||||
self.assertIsNone(stream.try_next())
|
||||
|
||||
@no_type_check
|
||||
def test_batch_size_is_honored(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
@ -245,6 +250,7 @@ class APITestsMixin(object):
|
||||
self.assertEqual(expected[key], cmd[key])
|
||||
|
||||
# $changeStream.startAtOperationTime was added in 4.0.0.
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
def test_start_at_operation_time(self):
|
||||
optime = self.get_start_at_operation_time()
|
||||
@ -258,6 +264,7 @@ class APITestsMixin(object):
|
||||
for i in range(ndocs):
|
||||
cs.next()
|
||||
|
||||
@no_type_check
|
||||
def _test_full_pipeline(self, expected_cs_stage):
|
||||
client, listener = self.client_with_listener("aggregate")
|
||||
results = listener.results
|
||||
@ -273,12 +280,14 @@ class APITestsMixin(object):
|
||||
{'$project': {'foo': 0}}],
|
||||
command.command['pipeline'])
|
||||
|
||||
@no_type_check
|
||||
def test_full_pipeline(self):
|
||||
"""$changeStream must be the first stage in a change stream pipeline
|
||||
sent to the server.
|
||||
"""
|
||||
self._test_full_pipeline({})
|
||||
|
||||
@no_type_check
|
||||
def test_iteration(self):
|
||||
with self.change_stream(batch_size=2) as change_stream:
|
||||
num_inserted = 10
|
||||
@ -292,6 +301,7 @@ class APITestsMixin(object):
|
||||
break
|
||||
self._test_invalidate_stops_iteration(change_stream)
|
||||
|
||||
@no_type_check
|
||||
def _test_next_blocks(self, change_stream):
|
||||
inserted_doc = {'_id': ObjectId()}
|
||||
changes = []
|
||||
@ -311,18 +321,21 @@ class APITestsMixin(object):
|
||||
self.assertEqual(changes[0]['operationType'], 'insert')
|
||||
self.assertEqual(changes[0]['fullDocument'], inserted_doc)
|
||||
|
||||
@no_type_check
|
||||
def test_next_blocks(self):
|
||||
"""Test that next blocks until a change is readable"""
|
||||
# Use a short await time to speed up the test.
|
||||
with self.change_stream(max_await_time_ms=250) as change_stream:
|
||||
self._test_next_blocks(change_stream)
|
||||
|
||||
@no_type_check
|
||||
def test_aggregate_cursor_blocks(self):
|
||||
"""Test that an aggregate cursor blocks until a change is readable."""
|
||||
with self.watched_collection().aggregate(
|
||||
[{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream:
|
||||
self._test_next_blocks(change_stream)
|
||||
|
||||
@no_type_check
|
||||
def test_concurrent_close(self):
|
||||
"""Ensure a ChangeStream can be closed from another thread."""
|
||||
# Use a short await time to speed up the test.
|
||||
@ -338,6 +351,7 @@ class APITestsMixin(object):
|
||||
t.join(3)
|
||||
self.assertFalse(t.is_alive())
|
||||
|
||||
@no_type_check
|
||||
def test_unknown_full_document(self):
|
||||
"""Must rely on the server to raise an error on unknown fullDocument.
|
||||
"""
|
||||
@ -347,6 +361,7 @@ class APITestsMixin(object):
|
||||
except OperationFailure:
|
||||
pass
|
||||
|
||||
@no_type_check
|
||||
def test_change_operations(self):
|
||||
"""Test each operation type."""
|
||||
expected_ns = {'db': self.watched_collection().database.name,
|
||||
@ -393,6 +408,7 @@ class APITestsMixin(object):
|
||||
# Invalidate.
|
||||
self._test_get_invalidate_event(change_stream)
|
||||
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_start_after(self):
|
||||
resume_token = self.get_resume_token(invalidate=True)
|
||||
@ -408,6 +424,7 @@ class APITestsMixin(object):
|
||||
self.assertEqual(change['operationType'], 'insert')
|
||||
self.assertEqual(change['fullDocument'], {'_id': 2})
|
||||
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_start_after_resume_process_with_changes(self):
|
||||
resume_token = self.get_resume_token(invalidate=True)
|
||||
@ -427,6 +444,7 @@ class APITestsMixin(object):
|
||||
self.assertEqual(change['operationType'], 'insert')
|
||||
self.assertEqual(change['fullDocument'], {'_id': 3})
|
||||
|
||||
@no_type_check
|
||||
@client_context.require_no_mongos # Remove after SERVER-41196
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_start_after_resume_process_without_changes(self):
|
||||
@ -444,12 +462,14 @@ class APITestsMixin(object):
|
||||
|
||||
|
||||
class ProseSpecTestsMixin(object):
|
||||
@no_type_check
|
||||
def _client_with_listener(self, *commands):
|
||||
listener = AllowListEventListener(*commands)
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
return client, listener
|
||||
|
||||
@no_type_check
|
||||
def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3):
|
||||
self.watched_collection().insert_many(
|
||||
[{"data": k} for k in range(batch_size)])
|
||||
@ -485,6 +505,7 @@ class ProseSpecTestsMixin(object):
|
||||
response = listener.results['succeeded'][-1].reply
|
||||
return response['cursor']['postBatchResumeToken']
|
||||
|
||||
@no_type_check
|
||||
def _test_raises_error_on_missing_id(self, expected_exception):
|
||||
"""ChangeStream will raise an exception if the server response is
|
||||
missing the resume token.
|
||||
@ -497,6 +518,7 @@ class ProseSpecTestsMixin(object):
|
||||
with self.assertRaises(StopIteration):
|
||||
next(change_stream)
|
||||
|
||||
@no_type_check
|
||||
def _test_update_resume_token(self, expected_rt_getter):
|
||||
"""ChangeStream must continuously track the last seen resumeToken."""
|
||||
client, listener = self._client_with_listener("aggregate", "getMore")
|
||||
@ -536,6 +558,7 @@ class ProseSpecTestsMixin(object):
|
||||
self._test_raises_error_on_missing_id(InvalidOperation)
|
||||
|
||||
# Prose test no. 3
|
||||
@no_type_check
|
||||
def test_resume_on_error(self):
|
||||
with self.change_stream() as change_stream:
|
||||
self.insert_one_and_check(change_stream, {'_id': 1})
|
||||
@ -544,6 +567,7 @@ class ProseSpecTestsMixin(object):
|
||||
self.insert_one_and_check(change_stream, {'_id': 2})
|
||||
|
||||
# Prose test no. 4
|
||||
@no_type_check
|
||||
@client_context.require_failCommand_fail_point
|
||||
def test_no_resume_attempt_if_aggregate_command_fails(self):
|
||||
# Set non-retryable error on aggregate command.
|
||||
@ -568,6 +592,7 @@ class ProseSpecTestsMixin(object):
|
||||
# each operation which ensure compliance with this prose test.
|
||||
|
||||
# Prose test no. 7
|
||||
@no_type_check
|
||||
def test_initial_empty_batch(self):
|
||||
with self.change_stream() as change_stream:
|
||||
# The first batch should be empty.
|
||||
@ -579,6 +604,7 @@ class ProseSpecTestsMixin(object):
|
||||
self.assertEqual(cursor_id, change_stream._cursor.cursor_id)
|
||||
|
||||
# Prose test no. 8
|
||||
@no_type_check
|
||||
def test_kill_cursors(self):
|
||||
def raise_error():
|
||||
raise ServerSelectionTimeoutError('mock error')
|
||||
@ -591,6 +617,7 @@ class ProseSpecTestsMixin(object):
|
||||
self.insert_one_and_check(change_stream, {'_id': 2})
|
||||
|
||||
# Prose test no. 9
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_version_max(4, 0, 7)
|
||||
def test_start_at_operation_time_caching(self):
|
||||
@ -619,6 +646,7 @@ class ProseSpecTestsMixin(object):
|
||||
# This test is identical to prose test no. 3.
|
||||
|
||||
# Prose test no. 11
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 0, 7)
|
||||
def test_resumetoken_empty_batch(self):
|
||||
client, listener = self._client_with_listener("getMore")
|
||||
@ -631,6 +659,7 @@ class ProseSpecTestsMixin(object):
|
||||
response["cursor"]["postBatchResumeToken"])
|
||||
|
||||
# Prose test no. 11
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 0, 7)
|
||||
def test_resumetoken_exhausted_batch(self):
|
||||
client, listener = self._client_with_listener("getMore")
|
||||
@ -643,6 +672,7 @@ class ProseSpecTestsMixin(object):
|
||||
response["cursor"]["postBatchResumeToken"])
|
||||
|
||||
# Prose test no. 12
|
||||
@no_type_check
|
||||
@client_context.require_version_max(4, 0, 7)
|
||||
def test_resumetoken_empty_batch_legacy(self):
|
||||
resume_point = self.get_resume_token()
|
||||
@ -659,6 +689,7 @@ class ProseSpecTestsMixin(object):
|
||||
self.assertEqual(resume_token, resume_point)
|
||||
|
||||
# Prose test no. 12
|
||||
@no_type_check
|
||||
@client_context.require_version_max(4, 0, 7)
|
||||
def test_resumetoken_exhausted_batch_legacy(self):
|
||||
# Resume token is _id of last change.
|
||||
@ -673,6 +704,7 @@ class ProseSpecTestsMixin(object):
|
||||
self.assertEqual(change_stream.resume_token, change["_id"])
|
||||
|
||||
# Prose test no. 13
|
||||
@no_type_check
|
||||
def test_resumetoken_partially_iterated_batch(self):
|
||||
# When batch has been iterated up to but not including the last element.
|
||||
# Resume token should be _id of previous change document.
|
||||
@ -686,6 +718,7 @@ class ProseSpecTestsMixin(object):
|
||||
|
||||
self.assertEqual(resume_token, change["_id"])
|
||||
|
||||
@no_type_check
|
||||
def _test_resumetoken_uniterated_nonempty_batch(self, resume_option):
|
||||
# When the batch is not empty and hasn't been iterated at all.
|
||||
# Resume token should be same as the resume option used.
|
||||
@ -704,17 +737,20 @@ class ProseSpecTestsMixin(object):
|
||||
self.assertEqual(resume_token, resume_point)
|
||||
|
||||
# Prose test no. 14
|
||||
@no_type_check
|
||||
@client_context.require_no_mongos
|
||||
def test_resumetoken_uniterated_nonempty_batch_resumeafter(self):
|
||||
self._test_resumetoken_uniterated_nonempty_batch("resume_after")
|
||||
|
||||
# Prose test no. 14
|
||||
@no_type_check
|
||||
@client_context.require_no_mongos
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_resumetoken_uniterated_nonempty_batch_startafter(self):
|
||||
self._test_resumetoken_uniterated_nonempty_batch("start_after")
|
||||
|
||||
# Prose test no. 17
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_startafter_resume_uses_startafter_after_empty_getMore(self):
|
||||
# Resume should use startAfter after no changes have been returned.
|
||||
@ -735,6 +771,7 @@ class ProseSpecTestsMixin(object):
|
||||
response.command["pipeline"][0]["$changeStream"].get("startAfter"))
|
||||
|
||||
# Prose test no. 18
|
||||
@no_type_check
|
||||
@client_context.require_version_min(4, 1, 1)
|
||||
def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self):
|
||||
# Resume should use resumeAfter after some changes have been returned.
|
||||
@ -757,6 +794,8 @@ class ProseSpecTestsMixin(object):
|
||||
|
||||
|
||||
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
dbs: list
|
||||
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0, -1)
|
||||
@client_context.require_no_mmap
|
||||
@ -1045,6 +1084,7 @@ class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin,
|
||||
|
||||
class TestAllLegacyScenarios(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
listener: AllowListEventListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
|
||||
@ -28,6 +28,8 @@ import _thread as thread
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from typing import no_type_check, Type
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import encode
|
||||
@ -99,6 +101,7 @@ from test.utils import (assertRaisesExactly,
|
||||
|
||||
class ClientUnitTest(unittest.TestCase):
|
||||
"""MongoClient tests that don't require a server."""
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
@ -341,7 +344,7 @@ class ClientUnitTest(unittest.TestCase):
|
||||
return int(value)
|
||||
|
||||
# Ensure codec options are passed in correctly
|
||||
document_class = SON
|
||||
document_class: Type[SON] = SON
|
||||
type_registry = TypeRegistry([MyFloatAsIntEncoder()])
|
||||
tz_aware = True
|
||||
uuid_representation_label = 'javaLegacy'
|
||||
@ -614,7 +617,7 @@ class TestClient(IntegrationTest):
|
||||
port are not overloaded.
|
||||
"""
|
||||
host, port = client_context.host, client_context.port
|
||||
kwargs = client_context.default_client_options.copy()
|
||||
kwargs: dict = client_context.default_client_options.copy()
|
||||
if client_context.auth_enabled:
|
||||
kwargs['username'] = db_user
|
||||
kwargs['password'] = db_pwd
|
||||
@ -1111,6 +1114,7 @@ class TestClient(IntegrationTest):
|
||||
socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
|
||||
@no_type_check
|
||||
def test_tz_aware(self):
|
||||
self.assertRaises(ValueError, MongoClient, tz_aware='foo')
|
||||
|
||||
@ -1140,7 +1144,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port)
|
||||
if client_context.is_rs:
|
||||
uri += '/?replicaSet=' + client_context.replica_set_name
|
||||
uri += '/?replicaSet=' + (client_context.replica_set_name or "")
|
||||
|
||||
client = rs_or_single_client_noauth(uri)
|
||||
client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||
@ -1379,7 +1383,7 @@ class TestClient(IntegrationTest):
|
||||
heartbeat_times.append(time.time())
|
||||
|
||||
try:
|
||||
ServerHeartbeatStartedEvent.__init__ = init
|
||||
ServerHeartbeatStartedEvent.__init__ = init # type: ignore
|
||||
listener = HeartbeatStartedListener()
|
||||
uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % (
|
||||
client_context.host, client_context.port)
|
||||
@ -1394,7 +1398,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
client.close()
|
||||
finally:
|
||||
ServerHeartbeatStartedEvent.__init__ = old_init
|
||||
ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore
|
||||
|
||||
def test_small_heartbeat_frequency_ms(self):
|
||||
uri = "mongodb://example/?heartbeatFrequencyMS=499"
|
||||
@ -1847,7 +1851,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
lazy_client_trial(reset, delete_one, test, self._get_client)
|
||||
|
||||
def test_find_one(self):
|
||||
results = []
|
||||
results: list = []
|
||||
|
||||
def reset(collection):
|
||||
collection.drop()
|
||||
|
||||
@ -213,11 +213,11 @@ class TestCMAP(IntegrationTest):
|
||||
|
||||
def run_scenario(self, scenario_def, test):
|
||||
"""Run a CMAP spec test."""
|
||||
self.logs = []
|
||||
self.logs: list = []
|
||||
self.assertEqual(scenario_def['version'], 1)
|
||||
self.assertIn(scenario_def['style'], ['unit', 'integration'])
|
||||
self.listener = CMAPListener()
|
||||
self._ops = []
|
||||
self._ops: list = []
|
||||
|
||||
# Configure the fail point before creating the client.
|
||||
if 'failPoint' in test:
|
||||
@ -259,9 +259,9 @@ class TestCMAP(IntegrationTest):
|
||||
self.pool = list(client._topology._servers.values())[0].pool
|
||||
|
||||
# Map of target names to Thread objects.
|
||||
self.targets = dict()
|
||||
self.targets: dict = dict()
|
||||
# Map of label names to Connection objects
|
||||
self.labels = dict()
|
||||
self.labels: dict = dict()
|
||||
|
||||
def cleanup():
|
||||
for t in self.targets.values():
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""Tests for the Code wrapper."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.code import Code
|
||||
@ -35,7 +36,7 @@ class TestCode(unittest.TestCase):
|
||||
c = Code("blah")
|
||||
|
||||
def set_c():
|
||||
c.scope = 5
|
||||
c.scope = 5 # type: ignore
|
||||
self.assertRaises(AttributeError, set_c)
|
||||
|
||||
def test_code(self):
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pymongo.collation import (
|
||||
Collation,
|
||||
CollationCaseFirst, CollationStrength, CollationAlternate,
|
||||
@ -78,6 +80,10 @@ class TestCollationObject(unittest.TestCase):
|
||||
|
||||
|
||||
class TestCollation(IntegrationTest):
|
||||
listener: EventListener
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -20,8 +20,11 @@ import contextlib
|
||||
import re
|
||||
import sys
|
||||
|
||||
from codecs import utf_8_decode
|
||||
from codecs import utf_8_decode # type: ignore
|
||||
from collections import defaultdict
|
||||
from typing import no_type_check
|
||||
|
||||
from pymongo.database import Database
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -66,6 +69,7 @@ from test.utils import (get_pool, is_mongos,
|
||||
class TestCollectionNoConnect(unittest.TestCase):
|
||||
"""Test Collection features on a client that does not connect.
|
||||
"""
|
||||
db: Database
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -116,11 +120,12 @@ class TestCollectionNoConnect(unittest.TestCase):
|
||||
|
||||
|
||||
class TestCollection(IntegrationTest):
|
||||
w: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestCollection, cls).setUpClass()
|
||||
cls.w = client_context.w
|
||||
cls.w = client_context.w # type: ignore
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@ -726,7 +731,7 @@ class TestCollection(IntegrationTest):
|
||||
db = self.db
|
||||
db.test.drop()
|
||||
|
||||
docs = [{} for _ in range(5)]
|
||||
docs: list = [{} for _ in range(5)]
|
||||
result = db.test.insert_many(docs)
|
||||
self.assertTrue(isinstance(result, InsertManyResult))
|
||||
self.assertTrue(isinstance(result.inserted_ids, list))
|
||||
@ -759,7 +764,7 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
db = db.client.get_database(db.name,
|
||||
write_concern=WriteConcern(w=0))
|
||||
docs = [{} for _ in range(5)]
|
||||
docs: list = [{} for _ in range(5)]
|
||||
result = db.test.insert_many(docs)
|
||||
self.assertTrue(isinstance(result, InsertManyResult))
|
||||
self.assertFalse(result.acknowledged)
|
||||
@ -792,11 +797,11 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "documents must be a non-empty list"):
|
||||
db.test.insert_many(1)
|
||||
db.test.insert_many(1) # type: ignore[arg-type]
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "documents must be a non-empty list"):
|
||||
db.test.insert_many(RawBSONDocument(encode({'_id': 2})))
|
||||
db.test.insert_many(RawBSONDocument(encode({'_id': 2}))) # type: ignore[arg-type]
|
||||
|
||||
def test_delete_one(self):
|
||||
self.db.test.drop()
|
||||
@ -1064,7 +1069,7 @@ class TestCollection(IntegrationTest):
|
||||
db_w0 = self.db.client.get_database(
|
||||
self.db.name, write_concern=WriteConcern(w=0))
|
||||
|
||||
ops = [InsertOne({"a": -10}),
|
||||
ops: list = [InsertOne({"a": -10}),
|
||||
InsertOne({"a": -11}),
|
||||
InsertOne({"a": -12}),
|
||||
UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}),
|
||||
@ -1087,7 +1092,7 @@ class TestCollection(IntegrationTest):
|
||||
def test_find_by_default_dct(self):
|
||||
db = self.db
|
||||
db.test.insert_one({'foo': 'bar'})
|
||||
dct = defaultdict(dict, [('foo', 'bar')])
|
||||
dct = defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type]
|
||||
self.assertIsNotNone(db.test.find_one(dct))
|
||||
self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')]))
|
||||
|
||||
@ -1117,6 +1122,7 @@ class TestCollection(IntegrationTest):
|
||||
doc = next(db.test.find({}, ["mike"]))
|
||||
self.assertFalse("extra thing" in doc)
|
||||
|
||||
@no_type_check
|
||||
def test_fields_specifier_as_dict(self):
|
||||
db = self.db
|
||||
db.test.delete_many({})
|
||||
@ -1333,7 +1339,7 @@ class TestCollection(IntegrationTest):
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(1, db.test.count_documents({"y": 1}))
|
||||
self.assertEqual(0, db.test.count_documents({"x": 1}))
|
||||
self.assertEqual(db.test.find_one(id1)["y"], 1)
|
||||
self.assertEqual(db.test.find_one(id1)["y"], 1) # type: ignore
|
||||
|
||||
replacement = RawBSONDocument(encode({"_id": id1, "z": 1}))
|
||||
result = db.test.replace_one({"y": 1}, replacement, True)
|
||||
@ -1344,7 +1350,7 @@ class TestCollection(IntegrationTest):
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(1, db.test.count_documents({"z": 1}))
|
||||
self.assertEqual(0, db.test.count_documents({"y": 1}))
|
||||
self.assertEqual(db.test.find_one(id1)["z"], 1)
|
||||
self.assertEqual(db.test.find_one(id1)["z"], 1) # type: ignore
|
||||
|
||||
result = db.test.replace_one({"x": 2}, {"y": 2}, True)
|
||||
self.assertTrue(isinstance(result, UpdateResult))
|
||||
@ -1377,7 +1383,7 @@ class TestCollection(IntegrationTest):
|
||||
self.assertTrue(result.modified_count in (None, 1))
|
||||
self.assertIsNone(result.upserted_id)
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(db.test.find_one(id1)["x"], 6)
|
||||
self.assertEqual(db.test.find_one(id1)["x"], 6) # type: ignore
|
||||
|
||||
id2 = db.test.insert_one({"x": 1}).inserted_id
|
||||
result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}})
|
||||
@ -1386,8 +1392,8 @@ class TestCollection(IntegrationTest):
|
||||
self.assertTrue(result.modified_count in (None, 1))
|
||||
self.assertIsNone(result.upserted_id)
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(db.test.find_one(id1)["x"], 7)
|
||||
self.assertEqual(db.test.find_one(id2)["x"], 1)
|
||||
self.assertEqual(db.test.find_one(id1)["x"], 7) # type: ignore
|
||||
self.assertEqual(db.test.find_one(id2)["x"], 1) # type: ignore
|
||||
|
||||
result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True)
|
||||
self.assertTrue(isinstance(result, UpdateResult))
|
||||
@ -1587,12 +1593,12 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
# Test that batchSize is handled properly.
|
||||
cursor = db.test.aggregate([], batchSize=5)
|
||||
self.assertEqual(5, len(cursor._CommandCursor__data))
|
||||
self.assertEqual(5, len(cursor._CommandCursor__data)) # type: ignore
|
||||
# Force a getMore
|
||||
cursor._CommandCursor__data.clear()
|
||||
cursor._CommandCursor__data.clear() # type: ignore
|
||||
next(cursor)
|
||||
# batchSize - 1
|
||||
self.assertEqual(4, len(cursor._CommandCursor__data))
|
||||
self.assertEqual(4, len(cursor._CommandCursor__data)) # type: ignore
|
||||
# Exhaust the cursor. There shouldn't be any errors.
|
||||
for doc in cursor:
|
||||
pass
|
||||
@ -1679,6 +1685,7 @@ class TestCollection(IntegrationTest):
|
||||
with self.write_concern_collection() as coll:
|
||||
coll.rename('foo')
|
||||
|
||||
@no_type_check
|
||||
def test_find_one(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
@ -1973,17 +1980,17 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
bad = BadGetAttr([('foo', 'bar')])
|
||||
c.insert_one({'bad': bad})
|
||||
self.assertEqual('bar', c.find_one()['bad']['foo'])
|
||||
self.assertEqual('bar', c.find_one()['bad']['foo']) # type: ignore
|
||||
|
||||
def test_array_filters_validation(self):
|
||||
# array_filters must be a list.
|
||||
c = self.db.test
|
||||
with self.assertRaises(TypeError):
|
||||
c.update_one({}, {'$set': {'a': 1}}, array_filters={})
|
||||
c.update_one({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type]
|
||||
with self.assertRaises(TypeError):
|
||||
c.update_many({}, {'$set': {'a': 1}}, array_filters={})
|
||||
c.update_many({}, {'$set': {'a': 1}}, array_filters={} ) # type: ignore[arg-type]
|
||||
with self.assertRaises(TypeError):
|
||||
c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={})
|
||||
c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type]
|
||||
|
||||
def test_array_filters_unacknowledged(self):
|
||||
c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0))
|
||||
@ -2158,7 +2165,7 @@ class TestCollection(IntegrationTest):
|
||||
c.drop()
|
||||
c.insert_one({'r': re.compile('.*')})
|
||||
|
||||
self.assertTrue(isinstance(c.find_one()['r'], Regex))
|
||||
self.assertTrue(isinstance(c.find_one()['r'], Regex)) # type: ignore
|
||||
for doc in c.find():
|
||||
self.assertTrue(isinstance(doc['r'], Regex))
|
||||
|
||||
@ -2189,9 +2196,9 @@ class TestCollection(IntegrationTest):
|
||||
for helper, args in helpers:
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"let must be an instance of dict"):
|
||||
helper(*args, let=let)
|
||||
helper(*args, let=let) # type: ignore
|
||||
for helper, args in helpers:
|
||||
helper(*args, let={})
|
||||
helper(*args, let={}) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -43,6 +43,8 @@ def camel_to_snake(camel):
|
||||
|
||||
|
||||
class TestAllScenarios(unittest.TestCase):
|
||||
listener: EventListener
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
|
||||
@ -50,13 +50,13 @@ class TestCommon(IntegrationTest):
|
||||
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
legacy_opts = coll.codec_options
|
||||
coll.insert_one({'uu': uu})
|
||||
self.assertEqual(uu, coll.find_one({'uu': uu})['uu'])
|
||||
self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) # type: ignore
|
||||
coll = self.db.get_collection(
|
||||
"uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
|
||||
self.assertEqual(None, coll.find_one({'uu': uu}))
|
||||
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
|
||||
self.assertEqual(uul, coll.find_one({'uu': uul})['uu'])
|
||||
self.assertEqual(uul, coll.find_one({'uu': uul})['uu']) # type: ignore
|
||||
|
||||
# Test count_documents
|
||||
self.assertEqual(0, coll.count_documents({'uu': uu}))
|
||||
@ -81,9 +81,9 @@ class TestCommon(IntegrationTest):
|
||||
coll.update_one({'_id': uu}, {'$set': {'i': 2}})
|
||||
coll = self.db.get_collection(
|
||||
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(1, coll.find_one({'_id': uu})['i'])
|
||||
self.assertEqual(1, coll.find_one({'_id': uu})['i']) # type: ignore
|
||||
coll.update_one({'_id': uu}, {'$set': {'i': 2}})
|
||||
self.assertEqual(2, coll.find_one({'_id': uu})['i'])
|
||||
self.assertEqual(2, coll.find_one({'_id': uu})['i']) # type: ignore
|
||||
|
||||
# Test Cursor.distinct
|
||||
self.assertEqual([2], coll.find({'_id': uu}).distinct('i'))
|
||||
@ -98,7 +98,7 @@ class TestCommon(IntegrationTest):
|
||||
"uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(2, coll.find_one_and_update({'_id': uu},
|
||||
{'$set': {'i': 5}})['i'])
|
||||
self.assertEqual(5, coll.find_one({'_id': uu})['i'])
|
||||
self.assertEqual(5, coll.find_one({'_id': uu})['i']) # type: ignore
|
||||
|
||||
# Test command
|
||||
self.assertEqual(5, self.db.command(
|
||||
|
||||
@ -20,6 +20,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from bson import SON
|
||||
from pymongo import monitoring
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.errors import NotPrimaryError
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -33,6 +34,9 @@ from test.utils import (CMAPListener,
|
||||
|
||||
|
||||
class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
listener: CMAPListener
|
||||
coll: Collection
|
||||
|
||||
@classmethod
|
||||
@client_context.require_replica_set
|
||||
def setUpClass(cls):
|
||||
@ -111,7 +115,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
# Insert record and verify failure.
|
||||
with self.assertRaises(NotPrimaryError) as exc:
|
||||
self.coll.insert_one({"test": 1})
|
||||
self.assertEqual(exc.exception.details['code'], error_code)
|
||||
self.assertEqual(exc.exception.details['code'], error_code) # type: ignore
|
||||
# Retry before CMAPListener assertion if retry_before=True.
|
||||
if retry:
|
||||
self.coll.insert_one({"test": 1})
|
||||
|
||||
@ -53,7 +53,7 @@ def check_result(self, expected_result, result):
|
||||
# SPEC-869: Only BulkWriteResult has upserted_count.
|
||||
if (prop == "upserted_count"
|
||||
and not isinstance(result, BulkWriteResult)):
|
||||
if result.upserted_id is not None:
|
||||
if result.upserted_id is not None: # type: ignore
|
||||
upserted_count = 1
|
||||
else:
|
||||
upserted_count = 0
|
||||
@ -69,14 +69,14 @@ def check_result(self, expected_result, result):
|
||||
ids = expected_result[res]
|
||||
if isinstance(ids, dict):
|
||||
ids = [ids[str(i)] for i in range(len(ids))]
|
||||
self.assertEqual(ids, result.inserted_ids, msg)
|
||||
self.assertEqual(ids, result.inserted_ids, msg) # type: ignore
|
||||
elif prop == "upserted_ids":
|
||||
# Convert indexes from strings to integers.
|
||||
ids = expected_result[res]
|
||||
expected_ids = {}
|
||||
for str_index in ids:
|
||||
expected_ids[int(str_index)] = ids[str_index]
|
||||
self.assertEqual(expected_ids, result.upserted_ids, msg)
|
||||
self.assertEqual(expected_ids, result.upserted_ids, msg) # type: ignore
|
||||
else:
|
||||
self.assertEqual(
|
||||
getattr(result, prop), expected_result[res], msg)
|
||||
|
||||
@ -57,7 +57,7 @@ class TestCursor(IntegrationTest):
|
||||
re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})
|
||||
|
||||
cursor2 = copy.deepcopy(cursor)
|
||||
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)
|
||||
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) # type: ignore
|
||||
|
||||
def test_add_remove_option(self):
|
||||
cursor = self.db.test.find()
|
||||
@ -149,9 +149,9 @@ class TestCursor(IntegrationTest):
|
||||
self.assertRaises(TypeError, coll.find().allow_disk_use, 'baz')
|
||||
|
||||
cursor = coll.find().allow_disk_use(True)
|
||||
self.assertEqual(True, cursor._Cursor__allow_disk_use)
|
||||
self.assertEqual(True, cursor._Cursor__allow_disk_use) # type: ignore
|
||||
cursor = coll.find().allow_disk_use(False)
|
||||
self.assertEqual(False, cursor._Cursor__allow_disk_use)
|
||||
self.assertEqual(False, cursor._Cursor__allow_disk_use) # type: ignore
|
||||
|
||||
def test_max_time_ms(self):
|
||||
db = self.db
|
||||
@ -165,15 +165,15 @@ class TestCursor(IntegrationTest):
|
||||
coll.find().max_time_ms(1)
|
||||
|
||||
cursor = coll.find().max_time_ms(999)
|
||||
self.assertEqual(999, cursor._Cursor__max_time_ms)
|
||||
self.assertEqual(999, cursor._Cursor__max_time_ms) # type: ignore
|
||||
cursor = coll.find().max_time_ms(10).max_time_ms(1000)
|
||||
self.assertEqual(1000, cursor._Cursor__max_time_ms)
|
||||
self.assertEqual(1000, cursor._Cursor__max_time_ms) # type: ignore
|
||||
|
||||
cursor = coll.find().max_time_ms(999)
|
||||
c2 = cursor.clone()
|
||||
self.assertEqual(999, c2._Cursor__max_time_ms)
|
||||
self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
|
||||
self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())
|
||||
self.assertEqual(999, c2._Cursor__max_time_ms) # type: ignore
|
||||
self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec()) # type: ignore
|
||||
self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec()) # type: ignore
|
||||
|
||||
self.assertTrue(coll.find_one(max_time_ms=1000))
|
||||
|
||||
@ -889,7 +889,7 @@ class TestCursor(IntegrationTest):
|
||||
|
||||
# Every attribute should be the same.
|
||||
cursor2 = cursor.clone()
|
||||
self.assertDictEqual(cursor.__dict__, cursor2.__dict__)
|
||||
self.assertEqual(cursor.__dict__, cursor2.__dict__)
|
||||
|
||||
# Shallow copies can so can mutate
|
||||
cursor2 = copy.copy(cursor)
|
||||
@ -1025,7 +1025,7 @@ class TestCursor(IntegrationTest):
|
||||
self.assertEqual(self.db.test, self.db.test.find().collection)
|
||||
|
||||
def set_coll():
|
||||
self.db.test.find().collection = "hello"
|
||||
self.db.test.find().collection = "hello" # type: ignore
|
||||
|
||||
self.assertRaises(AttributeError, set_coll)
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import tempfile
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
from random import random
|
||||
from typing import Any, Tuple, Type, no_type_check
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -127,6 +128,7 @@ def type_obfuscating_decoder_factory(rt_type):
|
||||
|
||||
|
||||
class CustomBSONTypeTests(object):
|
||||
@no_type_check
|
||||
def roundtrip(self, doc):
|
||||
bsonbytes = encode(doc, codec_options=self.codecopts)
|
||||
rt_document = decode(bsonbytes, codec_options=self.codecopts)
|
||||
@ -139,6 +141,7 @@ class CustomBSONTypeTests(object):
|
||||
self.roundtrip({'average': [[Decimal('56.47')]]})
|
||||
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
|
||||
|
||||
@no_type_check
|
||||
def test_decode_all(self):
|
||||
documents = []
|
||||
for dec in range(3):
|
||||
@ -151,12 +154,14 @@ class CustomBSONTypeTests(object):
|
||||
self.assertEqual(
|
||||
decode_all(bsonstream, self.codecopts), documents)
|
||||
|
||||
@no_type_check
|
||||
def test__bson_to_dict(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
rawbytes = encode(document, codec_options=self.codecopts)
|
||||
decoded_document = _bson_to_dict(rawbytes, self.codecopts)
|
||||
self.assertEqual(document, decoded_document)
|
||||
|
||||
@no_type_check
|
||||
def test__dict_to_bson(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
rawbytes = encode(document, codec_options=self.codecopts)
|
||||
@ -172,12 +177,14 @@ class CustomBSONTypeTests(object):
|
||||
bsonstream += encode(doc)
|
||||
return edocs, bsonstream
|
||||
|
||||
@no_type_check
|
||||
def test_decode_iter(self):
|
||||
expected, bson_data = self._generate_multidocument_bson_stream()
|
||||
for expected_doc, decoded_doc in zip(
|
||||
expected, decode_iter(bson_data, self.codecopts)):
|
||||
self.assertEqual(expected_doc, decoded_doc)
|
||||
|
||||
@no_type_check
|
||||
def test_decode_file_iter(self):
|
||||
expected, bson_data = self._generate_multidocument_bson_stream()
|
||||
fileobj = tempfile.TemporaryFile()
|
||||
@ -293,6 +300,15 @@ class TestBSONTypeEnDeCodecs(unittest.TestCase):
|
||||
|
||||
|
||||
class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
|
||||
|
||||
TypeA: Any
|
||||
TypeB: Any
|
||||
fallback_encoder_A2B: Any
|
||||
fallback_encoder_A2BSON: Any
|
||||
B2BSON: Type[TypeEncoder]
|
||||
B2A: Type[TypeEncoder]
|
||||
A2B: Type[TypeEncoder]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
class TypeA(object):
|
||||
@ -378,6 +394,10 @@ class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
|
||||
|
||||
|
||||
class TestTypeRegistry(unittest.TestCase):
|
||||
types: Tuple[object, object]
|
||||
codecs: Tuple[Type[TypeCodec], Type[TypeCodec]]
|
||||
fallback_encoder: Any
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
class MyIntType(object):
|
||||
@ -466,32 +486,32 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
def transform_bson(self, value):
|
||||
return self.types[0](value)
|
||||
|
||||
codec_instances = [MyIntDecoder(), MyIntEncoder()]
|
||||
codec_instances: list = [MyIntDecoder(), MyIntEncoder()]
|
||||
type_registry = TypeRegistry(codec_instances)
|
||||
|
||||
self.assertEqual(
|
||||
type_registry._encoder_map,
|
||||
{MyIntEncoder.python_type: codec_instances[1].transform_python})
|
||||
{MyIntEncoder.python_type: codec_instances[1].transform_python}) # type: ignore
|
||||
self.assertEqual(
|
||||
type_registry._decoder_map,
|
||||
{MyIntDecoder.bson_type: codec_instances[0].transform_bson})
|
||||
{MyIntDecoder.bson_type: codec_instances[0].transform_bson}) # type: ignore
|
||||
|
||||
def test_initialize_fail(self):
|
||||
err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, "
|
||||
"or TypeCodec, got .* instead")
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(self.codecs)
|
||||
TypeRegistry(self.codecs) # type: ignore[arg-type]
|
||||
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry([type('AnyType', (object,), {})()])
|
||||
|
||||
err_msg = "fallback_encoder %r is not a callable" % (True,)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry([], True)
|
||||
TypeRegistry([], True) # type: ignore[arg-type]
|
||||
|
||||
err_msg = "fallback_encoder %r is not a callable" % ('hello',)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(fallback_encoder='hello')
|
||||
TypeRegistry(fallback_encoder='hello') # type: ignore[arg-type]
|
||||
|
||||
def test_type_registry_repr(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
@ -525,7 +545,7 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
if pytype in [bool, type(None), RE_TYPE,]:
|
||||
continue
|
||||
|
||||
class MyType(pytype):
|
||||
class MyType(pytype): # type: ignore
|
||||
pass
|
||||
attrs.update({'python_type': MyType,
|
||||
'transform_python': lambda x: x})
|
||||
@ -598,7 +618,7 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
test = db.get_collection(
|
||||
'test', codec_options=UNINT_DECODER_CODECOPTS)
|
||||
|
||||
pipeline = [
|
||||
pipeline: list = [
|
||||
{'$match': {'status': 'complete'}},
|
||||
{'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},]
|
||||
result = test.aggregate(pipeline)
|
||||
@ -680,15 +700,18 @@ class TestGridFileCustomType(IntegrationTest):
|
||||
|
||||
|
||||
class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
@no_type_check
|
||||
def change_stream(self, *args, **kwargs):
|
||||
return self.watched_target.watch(*args, **kwargs)
|
||||
|
||||
@no_type_check
|
||||
def insert_and_check(self, change_stream, insert_doc,
|
||||
expected_doc):
|
||||
self.input_target.insert_one(insert_doc)
|
||||
change = next(change_stream)
|
||||
self.assertEqual(change['fullDocument'], expected_doc)
|
||||
|
||||
@no_type_check
|
||||
def kill_change_stream_cursor(self, change_stream):
|
||||
# Cause a cursor not found error on the next getMore.
|
||||
cursor = change_stream._cursor
|
||||
@ -696,6 +719,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
client = self.input_target.database.client
|
||||
client._close_cursor_now(cursor.cursor_id, address)
|
||||
|
||||
@no_type_check
|
||||
def test_simple(self):
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
|
||||
@ -718,6 +742,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
|
||||
|
||||
@no_type_check
|
||||
def test_custom_type_in_pipeline(self):
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
|
||||
@ -741,6 +766,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, input_docs[2], expected_docs[1])
|
||||
|
||||
@no_type_check
|
||||
def test_break_resume_token(self):
|
||||
# Get one document from a change stream to determine resumeToken type.
|
||||
self.create_targets()
|
||||
@ -766,6 +792,7 @@ class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, docs[2], docs[2])
|
||||
|
||||
@no_type_check
|
||||
def test_document_class(self):
|
||||
def run_test(doc_cls):
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import datetime
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -57,6 +58,7 @@ from test.test_custom_types import DECIMAL_CODECOPTS
|
||||
class TestDatabaseNoConnect(unittest.TestCase):
|
||||
"""Test Database features on a client that does not connect.
|
||||
"""
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -143,7 +145,7 @@ class TestDatabase(IntegrationTest):
|
||||
test = db.create_collection("test")
|
||||
self.assertTrue("test" in db.list_collection_names())
|
||||
test.insert_one({"hello": "world"})
|
||||
self.assertEqual(db.test.find_one()["hello"], "world")
|
||||
self.assertEqual(db.test.find_one()["hello"], "world") # type: ignore
|
||||
|
||||
db.drop_collection("test.foo")
|
||||
db.create_collection("test.foo")
|
||||
@ -198,6 +200,7 @@ class TestDatabase(IntegrationTest):
|
||||
self.assertNotIn("nameOnly", results["started"][0].command)
|
||||
|
||||
# Should send nameOnly (except on 2.6).
|
||||
filter: Any
|
||||
for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}):
|
||||
results.clear()
|
||||
names = db.list_collection_names(filter=filter)
|
||||
@ -225,7 +228,7 @@ class TestDatabase(IntegrationTest):
|
||||
self.assertTrue("$" not in coll)
|
||||
|
||||
# Duplicate check.
|
||||
coll_cnt = {}
|
||||
coll_cnt: dict = {}
|
||||
for coll in colls:
|
||||
try:
|
||||
# Found duplicate.
|
||||
@ -233,7 +236,7 @@ class TestDatabase(IntegrationTest):
|
||||
self.assertTrue(False)
|
||||
except KeyError:
|
||||
coll_cnt[coll] = 1
|
||||
coll_cnt = {}
|
||||
coll_cnt: dict = {}
|
||||
|
||||
# Checking if is there any collection which don't exists.
|
||||
if (len(set(colls) - set(["test","test.mike"])) == 0 or
|
||||
@ -466,6 +469,7 @@ class TestDatabase(IntegrationTest):
|
||||
self.assertEqual(None, db.test.find_one({"hello": "test"}))
|
||||
|
||||
b = db.test.find_one()
|
||||
assert b is not None
|
||||
b["hello"] = "mike"
|
||||
db.test.replace_one({"_id": b["_id"]}, b)
|
||||
|
||||
@ -482,12 +486,12 @@ class TestDatabase(IntegrationTest):
|
||||
db = self.client.pymongo_test
|
||||
db.test.drop()
|
||||
db.test.insert_one({"x": 9223372036854775807})
|
||||
retrieved = db.test.find_one()['x']
|
||||
retrieved = db.test.find_one()['x'] # type: ignore
|
||||
self.assertEqual(Int64(9223372036854775807), retrieved)
|
||||
self.assertIsInstance(retrieved, Int64)
|
||||
db.test.delete_many({})
|
||||
db.test.insert_one({"x": Int64(1)})
|
||||
retrieved = db.test.find_one()['x']
|
||||
retrieved = db.test.find_one()['x'] # type: ignore
|
||||
self.assertEqual(Int64(1), retrieved)
|
||||
self.assertIsInstance(retrieved, Int64)
|
||||
|
||||
@ -509,8 +513,8 @@ class TestDatabase(IntegrationTest):
|
||||
length += 1
|
||||
self.assertEqual(length, 2)
|
||||
|
||||
db.test.delete_one(db.test.find_one())
|
||||
db.test.delete_one(db.test.find_one())
|
||||
db.test.delete_one(db.test.find_one()) # type: ignore[arg-type]
|
||||
db.test.delete_one(db.test.find_one()) # type: ignore[arg-type]
|
||||
self.assertEqual(db.test.find_one(), None)
|
||||
|
||||
db.test.insert_one({"x": 1})
|
||||
@ -625,7 +629,7 @@ class TestDatabase(IntegrationTest):
|
||||
'read_preference': ReadPreference.PRIMARY,
|
||||
'write_concern': WriteConcern(w=1),
|
||||
'read_concern': ReadConcern(level="local")}
|
||||
db2 = db1.with_options(**newopts)
|
||||
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
|
||||
for opt in newopts:
|
||||
self.assertEqual(
|
||||
getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
|
||||
@ -633,7 +637,7 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
class TestDatabaseAggregation(IntegrationTest):
|
||||
def setUp(self):
|
||||
self.pipeline = [{"$listLocalSessions": {}},
|
||||
self.pipeline: List[Mapping[str, Any]] = [{"$listLocalSessions": {}},
|
||||
{"$limit": 1},
|
||||
{"$addFields": {"dummy": "dummy field"}},
|
||||
{"$project": {"_id": 0, "dummy": 1}}]
|
||||
@ -648,6 +652,7 @@ class TestDatabaseAggregation(IntegrationTest):
|
||||
@client_context.require_no_mongos
|
||||
def test_database_aggregation_fake_cursor(self):
|
||||
coll_name = "test_output"
|
||||
write_stage: dict
|
||||
if client_context.version < (4, 3):
|
||||
db_name = "admin"
|
||||
write_stage = {"$out": coll_name}
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
from typing import Any
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import encode, decode
|
||||
@ -44,10 +45,10 @@ class TestDBRef(unittest.TestCase):
|
||||
a = DBRef("coll", ObjectId())
|
||||
|
||||
def foo():
|
||||
a.collection = "blah"
|
||||
a.collection = "blah" # type: ignore[misc]
|
||||
|
||||
def bar():
|
||||
a.id = "aoeu"
|
||||
a.id = "aoeu" # type: ignore[misc]
|
||||
|
||||
self.assertEqual("coll", a.collection)
|
||||
a.id
|
||||
@ -136,6 +137,7 @@ class TestDBRef(unittest.TestCase):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/dbref.rst#test-plan
|
||||
class TestDBRefSpec(unittest.TestCase):
|
||||
def test_decoding_1_2_3(self):
|
||||
doc: Any
|
||||
for doc in [
|
||||
# 1, Valid documents MUST be decoded to a DBRef:
|
||||
{"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")},
|
||||
@ -183,6 +185,7 @@ class TestDBRefSpec(unittest.TestCase):
|
||||
self.assertIsInstance(dbref, dict)
|
||||
|
||||
def test_encoding_1_2(self):
|
||||
doc: Any
|
||||
for doc in [
|
||||
# 1, Encoding DBRefs with basic fields:
|
||||
{"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")},
|
||||
|
||||
@ -35,6 +35,7 @@ class TestDecimal128(unittest.TestCase):
|
||||
b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0')
|
||||
coll.insert_one({'dec128': dec128})
|
||||
doc = coll.find_one({'dec128': dec128})
|
||||
assert doc is not None
|
||||
self.assertIsNotNone(doc)
|
||||
self.assertEqual(doc['dec128'], dec128)
|
||||
|
||||
|
||||
@ -364,10 +364,12 @@ class TestIntegration(SpecRunner):
|
||||
def marked_unknown(e):
|
||||
return (isinstance(e, monitoring.ServerDescriptionChangedEvent)
|
||||
and not e.new_description.is_server_type_known)
|
||||
assert self.server_listener is not None
|
||||
return len(self.server_listener.matching(marked_unknown))
|
||||
# Only support CMAP events for now.
|
||||
self.assertTrue(event.startswith('Pool') or event.startswith('Conn'))
|
||||
event_type = getattr(monitoring, event)
|
||||
assert self.pool_listener is not None
|
||||
return self.pool_listener.event_count(event_type)
|
||||
|
||||
def assert_event_count(self, event, count):
|
||||
|
||||
@ -25,6 +25,10 @@ import textwrap
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pymongo.collection import Collection
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import encode, json_util
|
||||
@ -126,6 +130,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
|
||||
AutoEncryptionOpts({}, 'k.d', kms_tls_options={'kmip': 1})
|
||||
tls_opts: Any
|
||||
for tls_opts in [
|
||||
{'kmip': {'tls': True, 'tlsInsecure': True}},
|
||||
{'kmip': {'tls': True, 'tlsAllowInvalidCertificates': True}},
|
||||
@ -138,6 +143,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
AutoEncryptionOpts({}, 'k.d', kms_tls_options={
|
||||
'kmip': {'tlsCAFile': 'does-not-exist'}})
|
||||
# Success cases:
|
||||
tls_opts: Any
|
||||
for tls_opts in [None, {}]:
|
||||
opts = AutoEncryptionOpts({}, 'k.d', kms_tls_options=tls_opts)
|
||||
self.assertEqual(opts._kms_ssl_contexts, {})
|
||||
@ -432,14 +438,14 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
|
||||
msg = 'value to decrypt must be a bson.binary.Binary with subtype 6'
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
client_encryption.decrypt('str')
|
||||
client_encryption.decrypt('str') # type: ignore[arg-type]
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
client_encryption.decrypt(Binary(b'123'))
|
||||
|
||||
msg = 'key_id must be a bson.binary.Binary with subtype 4'
|
||||
algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
client_encryption.encrypt('str', algo, key_id=uuid.uuid4())
|
||||
client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) # type: ignore[arg-type]
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
client_encryption.encrypt('str', algo, key_id=Binary(b'123'))
|
||||
|
||||
@ -459,7 +465,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
def test_codec_options(self):
|
||||
with self.assertRaisesRegex(TypeError, 'codec_options must be'):
|
||||
ClientEncryption(
|
||||
KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None)
|
||||
KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) # type: ignore[arg-type]
|
||||
|
||||
opts = CodecOptions(uuid_representation=JAVA_LEGACY)
|
||||
client_encryption_legacy = ClientEncryption(
|
||||
@ -708,6 +714,10 @@ def create_key_vault(vault, *data_keys):
|
||||
|
||||
|
||||
class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
client_encrypted: MongoClient
|
||||
client_encryption: ClientEncryption
|
||||
listener: OvertCommandListener
|
||||
vault: Any
|
||||
|
||||
KMS_PROVIDERS = ALL_KMS_PROVIDERS
|
||||
|
||||
@ -776,7 +786,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
|
||||
def run_test(self, provider_name):
|
||||
# Create data key.
|
||||
master_key = self.MASTER_KEYS[provider_name]
|
||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||
datakey_id = self.client_encryption.create_data_key(
|
||||
provider_name, master_key=master_key,
|
||||
key_alt_names=['%s_altname' % (provider_name,)])
|
||||
@ -798,7 +808,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
{'_id': provider_name, 'value': encrypted})
|
||||
doc_decrypted = self.client_encrypted.db.coll.find_one(
|
||||
{'_id': provider_name})
|
||||
self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,))
|
||||
self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,)) # type: ignore
|
||||
|
||||
# Encrypt by key_alt_name.
|
||||
encrypted_altname = self.client_encryption.encrypt(
|
||||
@ -985,7 +995,7 @@ class TestCorpus(EncryptionIntegrationTest):
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json'))
|
||||
corpus_copied = SON()
|
||||
corpus_copied: SON = SON()
|
||||
for key, value in corpus.items():
|
||||
corpus_copied[key] = copy.deepcopy(value)
|
||||
if key in ('_id', 'altname_aws', 'altname_azure', 'altname_gcp',
|
||||
@ -1021,7 +1031,7 @@ class TestCorpus(EncryptionIntegrationTest):
|
||||
|
||||
try:
|
||||
encrypted_val = client_encryption.encrypt(
|
||||
value['value'], algo, **kwargs)
|
||||
value['value'], algo, **kwargs) # type: ignore[arg-type]
|
||||
if not value['allowed']:
|
||||
self.fail('encrypt should have failed: %r: %r' % (
|
||||
key, value))
|
||||
@ -1082,6 +1092,10 @@ _16_MiB = 16777216
|
||||
|
||||
class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
"""Prose tests for BSON size limits and batch splitting."""
|
||||
coll: Collection
|
||||
coll_encrypted: Collection
|
||||
client_encrypted: MongoClient
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -1397,6 +1411,7 @@ class AzureGCPEncryptionTestMixin(object):
|
||||
KMS_PROVIDER_MAP = None
|
||||
KEYVAULT_DB = 'keyvault'
|
||||
KEYVAULT_COLL = 'datakeys'
|
||||
client: MongoClient
|
||||
|
||||
def setUp(self):
|
||||
keyvault = self.client.get_database(
|
||||
@ -1406,7 +1421,7 @@ class AzureGCPEncryptionTestMixin(object):
|
||||
|
||||
def _test_explicit(self, expectation):
|
||||
client_encryption = ClientEncryption(
|
||||
self.KMS_PROVIDER_MAP,
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
'.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
client_context.client,
|
||||
OPTS)
|
||||
@ -1426,7 +1441,7 @@ class AzureGCPEncryptionTestMixin(object):
|
||||
keyvault_namespace = '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||
|
||||
encryption_opts = AutoEncryptionOpts(
|
||||
self.KMS_PROVIDER_MAP,
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
keyvault_namespace,
|
||||
schema_map=self.SCHEMA_MAP)
|
||||
|
||||
@ -1818,7 +1833,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
super(TestKmsTLSOptions, self).setUp()
|
||||
# 1, create client with only tlsCAFile.
|
||||
providers = copy.deepcopy(ALL_KMS_PROVIDERS)
|
||||
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
|
||||
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8002'
|
||||
providers['gcp']['endpoint'] = '127.0.0.1:8002'
|
||||
kms_tls_opts_ca_only = {
|
||||
@ -1840,7 +1855,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
kms_tls_options=kms_tls_opts)
|
||||
self.addCleanup(self.client_encryption_with_tls.close)
|
||||
# 3, update endpoints to expired host.
|
||||
providers = copy.deepcopy(providers)
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8000'
|
||||
providers['gcp']['endpoint'] = '127.0.0.1:8000'
|
||||
providers['kmip']['endpoint'] = '127.0.0.1:8000'
|
||||
@ -1849,7 +1864,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
kms_tls_options=kms_tls_opts_ca_only)
|
||||
self.addCleanup(self.client_encryption_expired.close)
|
||||
# 3, update endpoints to invalid host.
|
||||
providers = copy.deepcopy(providers)
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8001'
|
||||
providers['gcp']['endpoint'] = '127.0.0.1:8001'
|
||||
providers['kmip']['endpoint'] = '127.0.0.1:8001'
|
||||
|
||||
@ -890,6 +890,7 @@ class TestTransactionExamples(IntegrationTest):
|
||||
update_employee_info(session)
|
||||
|
||||
employee = employees.find_one({"employee": 3})
|
||||
assert employee is not None
|
||||
self.assertIsNotNone(employee)
|
||||
self.assertEqual(employee['status'], 'Inactive')
|
||||
|
||||
@ -916,6 +917,7 @@ class TestTransactionExamples(IntegrationTest):
|
||||
run_transaction_with_retry(update_employee_info, session)
|
||||
|
||||
employee = employees.find_one({"employee": 3})
|
||||
assert employee is not None
|
||||
self.assertIsNotNone(employee)
|
||||
self.assertEqual(employee['status'], 'Inactive')
|
||||
|
||||
@ -954,6 +956,7 @@ class TestTransactionExamples(IntegrationTest):
|
||||
run_transaction_with_retry(_insert_employee_retry_commit, session)
|
||||
|
||||
employee = employees.find_one({"employee": 4})
|
||||
assert employee is not None
|
||||
self.assertIsNotNone(employee)
|
||||
self.assertEqual(employee['status'], 'Active')
|
||||
|
||||
@ -1021,6 +1024,7 @@ class TestTransactionExamples(IntegrationTest):
|
||||
# End Transactions Retry Example 3
|
||||
|
||||
employee = employees.find_one({"employee": 3})
|
||||
assert employee is not None
|
||||
self.assertIsNotNone(employee)
|
||||
self.assertEqual(employee['status'], 'Inactive')
|
||||
|
||||
@ -1089,6 +1093,9 @@ class TestCausalConsistencyExamples(IntegrationTest):
|
||||
'start': current_date}, session=s1)
|
||||
# End Causal Consistency Example 1
|
||||
|
||||
assert s1.cluster_time is not None
|
||||
assert s1.operation_time is not None
|
||||
|
||||
# Start Causal Consistency Example 2
|
||||
with client.start_session(causal_consistency=True) as s2:
|
||||
s2.advance_cluster_time(s1.cluster_time)
|
||||
|
||||
@ -24,6 +24,8 @@ import zipfile
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from pymongo.database import Database
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
@ -47,6 +49,7 @@ from test.utils import rs_or_single_client, EventListener
|
||||
class TestGridFileNoConnect(unittest.TestCase):
|
||||
"""Test GridFile features on a client that does not connect.
|
||||
"""
|
||||
db: Database
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -27,6 +27,7 @@ from io import BytesIO
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.database import Database
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
NotPrimaryError,
|
||||
@ -78,6 +79,7 @@ class JustRead(threading.Thread):
|
||||
|
||||
|
||||
class TestGridfsNoConnect(unittest.TestCase):
|
||||
db: Database
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -89,6 +91,8 @@ class TestGridfsNoConnect(unittest.TestCase):
|
||||
|
||||
|
||||
class TestGridfs(IntegrationTest):
|
||||
fs: gridfs.GridFS
|
||||
alt: gridfs.GridFS
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -152,6 +156,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
raw = self.db.fs.files.find_one()
|
||||
assert raw is not None
|
||||
self.assertEqual(0, raw["length"])
|
||||
self.assertEqual(oid, raw["_id"])
|
||||
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
|
||||
@ -213,7 +218,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.put(b"hello", _id="test")
|
||||
|
||||
threads = []
|
||||
results = []
|
||||
results: list = []
|
||||
for i in range(10):
|
||||
threads.append(JustRead(self.fs, 10, results))
|
||||
threads[i].start()
|
||||
@ -396,6 +401,7 @@ class TestGridfs(IntegrationTest):
|
||||
# Test fix that guards against PHP-237
|
||||
self.fs.put(b"", filename="empty")
|
||||
doc = self.db.fs.files.find_one({"filename": "empty"})
|
||||
assert doc is not None
|
||||
doc.pop("length")
|
||||
self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
|
||||
f = self.fs.get_last_version(filename="empty")
|
||||
@ -447,23 +453,32 @@ class TestGridfs(IntegrationTest):
|
||||
# but will still call __del__.
|
||||
cursor = GridOutCursor.__new__(GridOutCursor) # Skip calling __init__
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.__init__(self.db.fs.files, {}, {"_id": True})
|
||||
cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore
|
||||
cursor.__del__() # no error
|
||||
|
||||
def test_gridfs_find_one(self):
|
||||
self.assertEqual(None, self.fs.find_one())
|
||||
|
||||
id1 = self.fs.put(b'test1', filename='file1')
|
||||
self.assertEqual(b'test1', self.fs.find_one().read())
|
||||
res = self.fs.find_one()
|
||||
assert res is not None
|
||||
self.assertEqual(b'test1', res.read())
|
||||
|
||||
id2 = self.fs.put(b'test2', filename='file2', meta='data')
|
||||
self.assertEqual(b'test1', self.fs.find_one(id1).read())
|
||||
self.assertEqual(b'test2', self.fs.find_one(id2).read())
|
||||
res1 = self.fs.find_one(id1)
|
||||
assert res1 is not None
|
||||
self.assertEqual(b'test1', res1.read())
|
||||
res2 = self.fs.find_one(id2)
|
||||
assert res2 is not None
|
||||
self.assertEqual(b'test2', res2.read())
|
||||
|
||||
self.assertEqual(b'test1',
|
||||
self.fs.find_one({'filename': 'file1'}).read())
|
||||
res3 = self.fs.find_one({'filename': 'file1'})
|
||||
assert res3 is not None
|
||||
self.assertEqual(b'test1', res3.read())
|
||||
|
||||
self.assertEqual('data', self.fs.find_one(id2).meta)
|
||||
res4 = self.fs.find_one(id2)
|
||||
assert res4 is not None
|
||||
self.assertEqual('data', res4.meta)
|
||||
|
||||
def test_grid_in_non_int_chunksize(self):
|
||||
# Lua, and perhaps other buggy GridFS clients, store size as a float.
|
||||
|
||||
@ -77,6 +77,8 @@ class JustRead(threading.Thread):
|
||||
|
||||
|
||||
class TestGridfs(IntegrationTest):
|
||||
fs: gridfs.GridFSBucket
|
||||
alt: gridfs.GridFSBucket
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -123,6 +125,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
raw = self.db.fs.files.find_one()
|
||||
assert raw is not None
|
||||
self.assertEqual(0, raw["length"])
|
||||
self.assertEqual(oid, raw["_id"])
|
||||
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
|
||||
@ -208,7 +211,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.upload_from_stream("test", b"hello")
|
||||
|
||||
threads = []
|
||||
results = []
|
||||
results: list = []
|
||||
for i in range(10):
|
||||
threads.append(JustRead(self.fs, 10, results))
|
||||
threads[i].start()
|
||||
@ -322,6 +325,7 @@ class TestGridfs(IntegrationTest):
|
||||
# Test fix that guards against PHP-237
|
||||
self.fs.upload_from_stream("empty", b"")
|
||||
doc = self.db.fs.files.find_one({"filename": "empty"})
|
||||
assert doc is not None
|
||||
doc.pop("length")
|
||||
self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
|
||||
fstr = self.fs.open_download_stream_by_name("empty")
|
||||
|
||||
@ -55,6 +55,9 @@ def camel_to_snake(camel):
|
||||
|
||||
|
||||
class TestAllScenarios(IntegrationTest):
|
||||
fs: gridfs.GridFSBucket
|
||||
str_to_cmd: dict
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestAllScenarios, cls).setUpClass()
|
||||
|
||||
@ -20,6 +20,8 @@ import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from typing import Any, List, MutableMapping
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import json_util, EPOCH_AWARE, EPOCH_NAIVE, SON
|
||||
@ -466,7 +468,7 @@ class TestJsonUtilRoundtrip(IntegrationTest):
|
||||
db = self.db
|
||||
|
||||
db.drop_collection("test")
|
||||
docs = [
|
||||
docs: List[MutableMapping[str, Any]] = [
|
||||
{'foo': [1, 2]},
|
||||
{'bar': {'hello': 'world'}},
|
||||
{'code': Code("function x() { return 1; }")},
|
||||
|
||||
@ -35,7 +35,7 @@ _TEST_PATH = os.path.join(
|
||||
'max_staleness')
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)):
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ class TestMonitor(IntegrationTest):
|
||||
|
||||
# Each executor stores a weakref to itself in _EXECUTORS.
|
||||
executor_refs = [
|
||||
(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
|
||||
(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] # type: ignore
|
||||
|
||||
del executors
|
||||
del client
|
||||
|
||||
@ -16,6 +16,7 @@ import copy
|
||||
import datetime
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -43,6 +44,7 @@ from test.utils import (EventListener,
|
||||
|
||||
|
||||
class TestCommandMonitoring(IntegrationTest):
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
@ -754,7 +756,7 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
|
||||
# delete_one
|
||||
self.listener.results.clear()
|
||||
res = coll.delete_one({'x': 3})
|
||||
res2 = coll.delete_one({'x': 3})
|
||||
results = self.listener.results
|
||||
started = results['started'][0]
|
||||
succeeded = results['succeeded'][0]
|
||||
@ -1091,6 +1093,8 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
|
||||
|
||||
class TestGlobalListener(IntegrationTest):
|
||||
listener: EventListener
|
||||
saved_listeners: Any
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
@ -1167,13 +1171,13 @@ class TestEventClasses(unittest.TestCase):
|
||||
"<ServerHeartbeatStartedEvent ('localhost', 27017)>")
|
||||
delta = 0.1
|
||||
event = monitoring.ServerHeartbeatSucceededEvent(
|
||||
delta, {'ok': 1}, connection_id)
|
||||
delta, {'ok': 1}, connection_id) # type: ignore[arg-type]
|
||||
self.assertEqual(
|
||||
repr(event),
|
||||
"<ServerHeartbeatSucceededEvent ('localhost', 27017) "
|
||||
"duration: 0.1, awaited: False, reply: {'ok': 1}>")
|
||||
event = monitoring.ServerHeartbeatFailedEvent(
|
||||
delta, 'ERROR', connection_id)
|
||||
delta, 'ERROR', connection_id) # type: ignore[arg-type]
|
||||
self.assertEqual(
|
||||
repr(event),
|
||||
"<ServerHeartbeatFailedEvent ('localhost', 27017) "
|
||||
@ -1188,7 +1192,7 @@ class TestEventClasses(unittest.TestCase):
|
||||
"<ServerOpeningEvent ('localhost', 27017) "
|
||||
"topology_id: 000000000000000000000001>")
|
||||
event = monitoring.ServerDescriptionChangedEvent(
|
||||
'PREV', 'NEW', server_address, topology_id)
|
||||
'PREV', 'NEW', server_address, topology_id) # type: ignore[arg-type]
|
||||
self.assertEqual(
|
||||
repr(event),
|
||||
"<ServerDescriptionChangedEvent ('localhost', 27017) "
|
||||
@ -1206,7 +1210,7 @@ class TestEventClasses(unittest.TestCase):
|
||||
repr(event),
|
||||
"<TopologyOpenedEvent topology_id: 000000000000000000000001>")
|
||||
event = monitoring.TopologyDescriptionChangedEvent(
|
||||
'PREV', 'NEW', topology_id)
|
||||
'PREV', 'NEW', topology_id) # type: ignore[arg-type]
|
||||
self.assertEqual(
|
||||
repr(event),
|
||||
"<TopologyDescriptionChangedEvent "
|
||||
|
||||
@ -106,7 +106,9 @@ class TestObjectId(unittest.TestCase):
|
||||
|
||||
aware = datetime.datetime(1993, 4, 4, 2,
|
||||
tzinfo=FixedOffset(555, "SomeZone"))
|
||||
as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc)
|
||||
offset = aware.utcoffset()
|
||||
assert offset is not None
|
||||
as_utc = (aware - offset).replace(tzinfo=utc)
|
||||
oid = ObjectId.from_datetime(aware)
|
||||
self.assertEqual(as_utc, oid.generation_time)
|
||||
|
||||
|
||||
@ -21,6 +21,8 @@ import random
|
||||
import sys
|
||||
from time import sleep
|
||||
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.ocsp_cache import _OCSPCache
|
||||
@ -28,14 +30,18 @@ from test import unittest
|
||||
|
||||
|
||||
class TestOcspCache(unittest.TestCase):
|
||||
MockHashAlgorithm: Any
|
||||
MockOcspRequest: Any
|
||||
MockOcspResponse: Any
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.MockHashAlgorithm = namedtuple(
|
||||
cls.MockHashAlgorithm = namedtuple( # type: ignore
|
||||
"MockHashAlgorithm", ['name'])
|
||||
cls.MockOcspRequest = namedtuple(
|
||||
cls.MockOcspRequest = namedtuple( # type: ignore
|
||||
"MockOcspRequest", ['hash_algorithm', 'issuer_name_hash',
|
||||
'issuer_key_hash', 'serial_number'])
|
||||
cls.MockOcspResponse = namedtuple(
|
||||
cls.MockOcspResponse = namedtuple( # type: ignore
|
||||
"MockOcspResponse", ["this_update", "next_update"])
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@ -82,6 +82,7 @@ class TestRawBSONDocument(IntegrationTest):
|
||||
codec_options=CodecOptions(document_class=RawBSONDocument))
|
||||
db.test_raw.insert_one(self.document)
|
||||
result = db.test_raw.find_one(self.document['_id'])
|
||||
assert result is not None
|
||||
self.assertIsInstance(result, RawBSONDocument)
|
||||
self.assertEqual(dict(self.document.items()), dict(result.items()))
|
||||
|
||||
@ -146,6 +147,7 @@ class TestRawBSONDocument(IntegrationTest):
|
||||
db = self.client.pymongo_test
|
||||
db.test_raw.insert_one(doc)
|
||||
result = db.test_raw.find_one()
|
||||
assert result is not None
|
||||
self.assertEqual(decode(self.document.raw), result['embedded'])
|
||||
|
||||
# Make sure that CodecOptions are preserved.
|
||||
@ -169,6 +171,7 @@ class TestRawBSONDocument(IntegrationTest):
|
||||
db.test_raw.insert_one(rbd)
|
||||
result = db.get_collection('test_raw', codec_options=CodecOptions(
|
||||
uuid_representation=JAVA_LEGACY)).find_one()
|
||||
assert result is not None
|
||||
self.assertEqual(rbd['embedded'][0]['_id'],
|
||||
result['embedded'][0]['_id'])
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from test.utils import single_client, rs_or_single_client, OvertCommandListener
|
||||
|
||||
|
||||
class TestReadConcern(IntegrationTest):
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
|
||||
@ -225,7 +225,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
localthresholdms=-1)
|
||||
|
||||
def test_zero_latency(self):
|
||||
ping_times = set()
|
||||
ping_times: set = set()
|
||||
# Generate unique ping times.
|
||||
while len(ping_times) < len(self.client.nodes):
|
||||
ping_times.add(random.random())
|
||||
@ -278,7 +278,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
# far, and keep reading until we've used all the members or give up.
|
||||
# Chance of using only 2 of 3 members 10k times if there's no bug =
|
||||
# 3 * (2/3)**10000, very low.
|
||||
used = set()
|
||||
used: set = set()
|
||||
i = 0
|
||||
while data_members.difference(used) and i < 10000:
|
||||
address = self.read_from_which_host(c)
|
||||
@ -335,6 +335,8 @@ _PREF_MAP = [
|
||||
|
||||
|
||||
class TestCommandAndReadPreference(IntegrationTest):
|
||||
c: ReadPrefTester
|
||||
client_version: Version
|
||||
|
||||
@classmethod
|
||||
@client_context.require_secondaries_count(1)
|
||||
@ -378,6 +380,7 @@ class TestCommandAndReadPreference(IntegrationTest):
|
||||
# Success
|
||||
break
|
||||
|
||||
assert self.c.primary is not None
|
||||
unused = self.c.secondaries.union(
|
||||
set([self.c.primary])
|
||||
).difference(used)
|
||||
@ -445,11 +448,11 @@ class TestMovingAverage(unittest.TestCase):
|
||||
avg = MovingAverage()
|
||||
self.assertIsNone(avg.get())
|
||||
avg.add_sample(10)
|
||||
self.assertAlmostEqual(10, avg.get())
|
||||
self.assertAlmostEqual(10, avg.get()) # type: ignore
|
||||
avg.add_sample(20)
|
||||
self.assertAlmostEqual(12, avg.get())
|
||||
self.assertAlmostEqual(12, avg.get()) # type: ignore
|
||||
avg.add_sample(30)
|
||||
self.assertAlmostEqual(15.6, avg.get())
|
||||
self.assertAlmostEqual(15.6, avg.get()) # type: ignore
|
||||
|
||||
|
||||
class TestMongosAndReadPreference(IntegrationTest):
|
||||
@ -526,7 +529,8 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
'maxStalenessSeconds': 30})
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
Nearest(max_staleness=1.5) # Float is prohibited.
|
||||
# Float is prohibited.
|
||||
Nearest(max_staleness=1.5) # type: ignore
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Nearest(max_staleness=0)
|
||||
@ -543,7 +547,7 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
}
|
||||
for mode, cls in cases.items():
|
||||
with self.assertRaises(TypeError):
|
||||
cls(hedge=[])
|
||||
cls(hedge=[]) # type: ignore
|
||||
|
||||
pref = cls(hedge={})
|
||||
self.assertEqual(pref.document, {'mode': mode})
|
||||
@ -668,7 +672,9 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
|
||||
@client_context.require_mongos
|
||||
def test_mongos(self):
|
||||
shard = client_context.client.config.shards.find_one()['host']
|
||||
res = client_context.client.config.shards.find_one()
|
||||
assert res is not None
|
||||
shard = res['host']
|
||||
num_members = shard.count(',') + 1
|
||||
if num_members == 1:
|
||||
raise SkipTest("Need a replica set shard to test.")
|
||||
|
||||
@ -149,12 +149,14 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
f()
|
||||
if expected == BulkWriteError:
|
||||
bulk_result = cm.exception.details
|
||||
assert bulk_result is not None
|
||||
wc_errors = bulk_result['writeConcernErrors']
|
||||
self.assertTrue(wc_errors)
|
||||
|
||||
@client_context.require_replica_set
|
||||
def test_raise_write_concern_error(self):
|
||||
self.addCleanup(client_context.client.drop_database, 'pymongo_test')
|
||||
assert client_context.w is not None
|
||||
self.assertWriteOpsRaise(
|
||||
WriteConcern(w=client_context.w+1, wtimeout=1), WriteConcernError)
|
||||
|
||||
@ -219,6 +221,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
db.test.insert_one({'x': 1})
|
||||
self.assertEqual(ctx.exception.code, 121)
|
||||
self.assertIsNotNone(ctx.exception.details)
|
||||
assert ctx.exception.details is not None
|
||||
self.assertIsNotNone(ctx.exception.details.get('errInfo'))
|
||||
for event in listener.results['succeeded']:
|
||||
if event.command_name == 'insert':
|
||||
@ -290,19 +293,19 @@ def create_document_test(test_case):
|
||||
WriteConcern,
|
||||
**normalized)
|
||||
else:
|
||||
concern = WriteConcern(**normalized)
|
||||
write_concern = WriteConcern(**normalized)
|
||||
self.assertEqual(
|
||||
concern.document, test_case['writeConcernDocument'])
|
||||
write_concern.document, test_case['writeConcernDocument'])
|
||||
self.assertEqual(
|
||||
concern.acknowledged, test_case['isAcknowledged'])
|
||||
write_concern.acknowledged, test_case['isAcknowledged'])
|
||||
self.assertEqual(
|
||||
concern.is_server_default, test_case['isServerDefault'])
|
||||
write_concern.is_server_default, test_case['isServerDefault'])
|
||||
if 'readConcern' in test_case:
|
||||
# Any string for 'level' is equaly valid
|
||||
concern = ReadConcern(**test_case['readConcern'])
|
||||
self.assertEqual(concern.document, test_case['readConcernDocument'])
|
||||
read_concern = ReadConcern(**test_case['readConcern'])
|
||||
self.assertEqual(read_concern.document, test_case['readConcernDocument'])
|
||||
self.assertEqual(
|
||||
not bool(concern.level), test_case['isServerDefault'])
|
||||
not bool(read_concern.level), test_case['isServerDefault'])
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
@ -135,6 +135,7 @@ def non_retryable_single_statement_ops(coll):
|
||||
class IgnoreDeprecationsTest(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -148,6 +149,7 @@ class IgnoreDeprecationsTest(IntegrationTest):
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -180,6 +182,8 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
|
||||
|
||||
class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
listener: OvertCommandListener
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_mmap
|
||||
@ -426,6 +430,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
class TestWriteConcernError(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
fail_insert: dict
|
||||
|
||||
@classmethod
|
||||
@client_context.require_replica_set
|
||||
|
||||
@ -24,6 +24,7 @@ sys.path[0:0] = [""]
|
||||
from pymongo import MongoClient
|
||||
from bson.json_util import object_hook
|
||||
from pymongo import monitoring
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
NotPrimaryError)
|
||||
@ -253,6 +254,10 @@ create_tests()
|
||||
|
||||
|
||||
class TestSdamMonitoring(IntegrationTest):
|
||||
knobs: client_knobs
|
||||
listener: ServerAndTopologyEventListener
|
||||
test_client: MongoClient
|
||||
coll: Collection
|
||||
|
||||
@classmethod
|
||||
@client_context.require_failCommand_fail_point
|
||||
|
||||
@ -52,7 +52,7 @@ class SelectionStoreSelector(object):
|
||||
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)):
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
@ -125,7 +125,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
def test_latency_threshold_application(self):
|
||||
selector = SelectionStoreSelector()
|
||||
|
||||
scenario_def = {
|
||||
scenario_def: dict = {
|
||||
'topology_description': {
|
||||
'type': 'ReplicaSetWithPrimary', 'servers': [
|
||||
{'address': 'b:27017',
|
||||
@ -160,6 +160,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
# Invoke server selection and assert no filtering based on latency
|
||||
# prior to custom server selection logic kicking in.
|
||||
server = topology.select_server(ReadPreference.NEAREST)
|
||||
assert selector.selection is not None
|
||||
self.assertEqual(
|
||||
len(selector.selection),
|
||||
len(topology.description.server_descriptions()))
|
||||
|
||||
@ -116,7 +116,7 @@ class TestProse(IntegrationTest):
|
||||
self.assertEqual(len(events), N_FINDS * N_THREADS)
|
||||
nodes = client.nodes
|
||||
self.assertEqual(len(nodes), 2)
|
||||
freqs = {address: 0 for address in nodes}
|
||||
freqs = {address: 0.0 for address in nodes}
|
||||
for event in events:
|
||||
freqs[event.connection_id] += 1
|
||||
for address in freqs:
|
||||
|
||||
@ -20,6 +20,9 @@ import sys
|
||||
import time
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Set
|
||||
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -64,6 +67,8 @@ def session_ids(client):
|
||||
|
||||
|
||||
class TestSession(IntegrationTest):
|
||||
client2: MongoClient
|
||||
sensitive_commands: Set[str]
|
||||
|
||||
@classmethod
|
||||
@client_context.require_sessions
|
||||
@ -231,7 +236,7 @@ class TestSession(IntegrationTest):
|
||||
|
||||
def test_client(self):
|
||||
client = self.client
|
||||
ops = [
|
||||
ops: list = [
|
||||
(client.server_info, [], {}),
|
||||
(client.list_database_names, [], {}),
|
||||
(client.drop_database, ['pymongo_test'], {}),
|
||||
@ -242,7 +247,7 @@ class TestSession(IntegrationTest):
|
||||
def test_database(self):
|
||||
client = self.client
|
||||
db = client.pymongo_test
|
||||
ops = [
|
||||
ops: list = [
|
||||
(db.command, ['ping'], {}),
|
||||
(db.create_collection, ['collection'], {}),
|
||||
(db.list_collection_names, [], {}),
|
||||
@ -493,6 +498,7 @@ class TestSession(IntegrationTest):
|
||||
# Explicit session.
|
||||
with client.start_session() as s:
|
||||
cursor = bucket.find(session=s)
|
||||
assert cursor.session is not None
|
||||
s = cursor.session
|
||||
files = list(cursor)
|
||||
cursor.__del__()
|
||||
@ -680,7 +686,7 @@ class TestSession(IntegrationTest):
|
||||
self.addCleanup(client.close)
|
||||
db = client.pymongo_test
|
||||
coll = db.test_unacked_writes
|
||||
ops = [
|
||||
ops: list = [
|
||||
(client.drop_database, [db.name], {}),
|
||||
(db.create_collection, ['collection'], {}),
|
||||
(db.drop_collection, ['collection'], {}),
|
||||
@ -722,6 +728,8 @@ class TestSession(IntegrationTest):
|
||||
self.assertRaises(TypeError, lambda: copy.copy(s))
|
||||
|
||||
class TestCausalConsistency(unittest.TestCase):
|
||||
listener: SessionTestListener
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -778,6 +786,8 @@ class TestCausalConsistency(unittest.TestCase):
|
||||
self.assertRaises(ValueError, sess2.advance_cluster_time, {})
|
||||
self.assertRaises(TypeError, sess2.advance_operation_time, 1)
|
||||
# No error
|
||||
assert sess.cluster_time is not None
|
||||
assert sess.operation_time is not None
|
||||
sess2.advance_cluster_time(sess.cluster_time)
|
||||
sess2.advance_operation_time(sess.operation_time)
|
||||
self.assertEqual(sess.cluster_time, sess2.cluster_time)
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import sys
|
||||
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -54,6 +55,7 @@ class SrvPollingKnobs(object):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||
|
||||
def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||
assert self.old_dns_resolver_response is not None
|
||||
nodes, ttl = self.old_dns_resolver_response(resolver)
|
||||
if self.nodelist_callback is not None:
|
||||
nodes = self.nodelist_callback()
|
||||
@ -61,20 +63,22 @@ class SrvPollingKnobs(object):
|
||||
ttl = self.ttl_time
|
||||
return nodes, ttl
|
||||
|
||||
patch_func: Any
|
||||
if self.count_resolver_calls:
|
||||
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
|
||||
else:
|
||||
patch_func = mock_get_hosts_and_min_ttl
|
||||
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
def disable(self):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
|
||||
self.old_dns_resolver_response
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
self.old_dns_resolver_response # type: ignore
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.disable()
|
||||
@ -128,7 +132,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||
self.assertGreaterEqual(
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
1, "resolver was never called")
|
||||
return True
|
||||
|
||||
@ -196,7 +200,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
self.run_scenario(response_callback, False)
|
||||
|
||||
def test_dns_record_lookup_empty(self):
|
||||
response = []
|
||||
response: list = []
|
||||
self.run_scenario(response, False)
|
||||
|
||||
def _test_recover_from_initial(self, initial_callback):
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -52,7 +53,7 @@ try:
|
||||
from pymongo.ocsp_support import _load_trusted_ca_certs
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_load_trusted_ca_certs = None
|
||||
_load_trusted_ca_certs = None # type: ignore
|
||||
|
||||
|
||||
if HAVE_SSL:
|
||||
@ -148,6 +149,7 @@ class TestClientSSL(unittest.TestCase):
|
||||
|
||||
|
||||
class TestSSL(IntegrationTest):
|
||||
saved_port: int
|
||||
|
||||
def assertClientWorks(self, client):
|
||||
coll = client.pymongo_test.ssl_test.with_options(
|
||||
@ -200,13 +202,13 @@ class TestSSL(IntegrationTest):
|
||||
tlsCertificateKeyFilePassword="qwerty",
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=5000,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
uri_fmt = ("mongodb://localhost/?ssl=true"
|
||||
"&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty"
|
||||
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000")
|
||||
connected(MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM),
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@client_context.require_no_auth
|
||||
@ -313,7 +315,7 @@ class TestSSL(IntegrationTest):
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
connected(MongoClient('server',
|
||||
ssl=True,
|
||||
@ -322,7 +324,7 @@ class TestSSL(IntegrationTest):
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
if 'setName' in response:
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
@ -333,7 +335,7 @@ class TestSSL(IntegrationTest):
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
connected(MongoClient('server',
|
||||
replicaSet=response['setName'],
|
||||
@ -343,7 +345,7 @@ class TestSSL(IntegrationTest):
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@ignore_deprecations
|
||||
@ -362,7 +364,7 @@ class TestSSL(IntegrationTest):
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=100,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(MongoClient('localhost',
|
||||
@ -370,18 +372,18 @@ class TestSSL(IntegrationTest):
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsCRLFile=CRL_PEM,
|
||||
serverSelectionTimeoutMS=100,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
uri_fmt = ("mongodb://localhost/?ssl=true&"
|
||||
"tlsCAFile=%s&serverSelectionTimeoutMS=100")
|
||||
connected(MongoClient(uri_fmt % (CA_PEM,),
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
uri_fmt = ("mongodb://localhost/?ssl=true&tlsCRLFile=%s"
|
||||
"&tlsCAFile=%s&serverSelectionTimeoutMS=100")
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(MongoClient(uri_fmt % (CRL_PEM, CA_PEM),
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@client_context.require_server_resolvable
|
||||
@ -399,26 +401,26 @@ class TestSSL(IntegrationTest):
|
||||
connected(MongoClient('server',
|
||||
ssl=True,
|
||||
serverSelectionTimeoutMS=100,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
# Server cert is verified. Disable hostname matching.
|
||||
connected(MongoClient('server',
|
||||
ssl=True,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=100,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
connected(MongoClient('localhost',
|
||||
ssl=True,
|
||||
serverSelectionTimeoutMS=100,
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
connected(
|
||||
MongoClient(
|
||||
'mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=100',
|
||||
**self.credentials))
|
||||
**self.credentials)) # type: ignore
|
||||
|
||||
def test_system_certs_config_error(self):
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
@ -428,6 +430,7 @@ class TestSSL(IntegrationTest):
|
||||
raise SkipTest(
|
||||
"Can't test when system CA certificates are loadable.")
|
||||
|
||||
ssl_support: Any
|
||||
have_certifi = ssl_support.HAVE_CERTIFI
|
||||
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
|
||||
# Force the test regardless of environment.
|
||||
@ -446,6 +449,7 @@ class TestSSL(IntegrationTest):
|
||||
# with SSLContext and SSLContext provides no information
|
||||
# about ca_certs.
|
||||
raise SkipTest("Can't test when SSLContext available.")
|
||||
ssl_support: Any
|
||||
if not ssl_support.HAVE_CERTIFI:
|
||||
raise SkipTest("Need certifi to test certifi support.")
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
# 1-15 millisecond resolution. We need to delay the initial hello
|
||||
# to ensure that RTT is never zero.
|
||||
name = 'streamingRttTest'
|
||||
delay_hello = {
|
||||
delay_hello: dict = {
|
||||
'configureFailPoint': 'failCommand',
|
||||
'mode': {'times': 1000},
|
||||
'data': {
|
||||
|
||||
@ -90,19 +90,19 @@ class TestTransactions(TransactionsBase):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
max_commit_time_ms=10000)
|
||||
with self.assertRaisesRegex(TypeError, "read_concern must be "):
|
||||
TransactionOptions(read_concern={})
|
||||
TransactionOptions(read_concern={}) # type: ignore
|
||||
with self.assertRaisesRegex(TypeError, "write_concern must be "):
|
||||
TransactionOptions(write_concern={})
|
||||
TransactionOptions(write_concern={}) # type: ignore
|
||||
with self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"transactions do not support unacknowledged write concern"):
|
||||
TransactionOptions(write_concern=WriteConcern(w=0))
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "is not valid for read_preference"):
|
||||
TransactionOptions(read_preference={})
|
||||
TransactionOptions(read_preference={}) # type: ignore
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "max_commit_time_ms must be an integer or None"):
|
||||
TransactionOptions(max_commit_time_ms="10000")
|
||||
TransactionOptions(max_commit_time_ms="10000") # type: ignore
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_write_concern_override(self):
|
||||
@ -131,7 +131,7 @@ class TestTransactions(TransactionsBase):
|
||||
coll.find_one_and_replace({}, {}, session=s)
|
||||
coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s)
|
||||
|
||||
unsupported_txn_writes = [
|
||||
unsupported_txn_writes: list = [
|
||||
(client.drop_database, [db.name], {}),
|
||||
(db.drop_collection, ['collection'], {}),
|
||||
(coll.drop, [], {}),
|
||||
@ -284,7 +284,7 @@ class TestTransactions(TransactionsBase):
|
||||
InvalidOperation,
|
||||
'GridFS does not support multi-document transactions',
|
||||
):
|
||||
op(*args, session=s)
|
||||
op(*args, session=s) # type: ignore
|
||||
|
||||
# Require 4.2+ for large (16MB+) transactions.
|
||||
@client_context.require_version_min(4, 2)
|
||||
@ -360,11 +360,11 @@ class TestTransactionsConvenientAPI(TransactionsBase):
|
||||
|
||||
self.db.test.insert_one({})
|
||||
|
||||
def callback(session):
|
||||
def callback2(session):
|
||||
self.db.test.insert_one({}, session=session)
|
||||
return 'Foo'
|
||||
with self.client.start_session() as s:
|
||||
self.assertEqual(s.with_transaction(callback), 'Foo')
|
||||
self.assertEqual(s.with_transaction(callback2), 'Foo')
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_callback_not_retried_after_timeout(self):
|
||||
@ -375,7 +375,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
|
||||
|
||||
def callback(session):
|
||||
coll.insert_one({}, session=session)
|
||||
err = {
|
||||
err: dict = {
|
||||
'ok': 0,
|
||||
'errmsg': 'Transaction 7819 has been aborted.',
|
||||
'code': 251,
|
||||
|
||||
@ -186,7 +186,7 @@ class TestURI(unittest.TestCase):
|
||||
self.assertRaises(ValueError,
|
||||
parse_uri, "mongodb://::1", 27017)
|
||||
|
||||
orig = {
|
||||
orig: dict = {
|
||||
'nodelist': [("localhost", 27017)],
|
||||
'username': None,
|
||||
'password': None,
|
||||
@ -196,7 +196,7 @@ class TestURI(unittest.TestCase):
|
||||
'fqdn': None
|
||||
}
|
||||
|
||||
res = copy.deepcopy(orig)
|
||||
res: dict = copy.deepcopy(orig)
|
||||
self.assertEqual(res, parse_uri("mongodb://localhost"))
|
||||
|
||||
res.update({'username': 'fred', 'password': 'foobar'})
|
||||
|
||||
@ -66,7 +66,7 @@ class TestWriteConcern(unittest.TestCase):
|
||||
self.assertNotEqual(WriteConcern(wtimeout=42), _FakeWriteConcern(wtimeout=2000))
|
||||
|
||||
def test_equality_incompatible_type(self):
|
||||
_fake_type = collections.namedtuple('NotAWriteConcern', ['document'])
|
||||
_fake_type = collections.namedtuple('NotAWriteConcern', ['document']) # type: ignore
|
||||
self.assertNotEqual(WriteConcern(j=True), _fake_type({'j': True}))
|
||||
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ import time
|
||||
import types
|
||||
|
||||
from collections import abc
|
||||
from typing import Any
|
||||
|
||||
from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey
|
||||
from bson.binary import Binary
|
||||
@ -296,7 +297,7 @@ class EntityMapUtil(object):
|
||||
|
||||
entity_type, spec = next(iter(entity_spec.items()))
|
||||
if entity_type == 'client':
|
||||
kwargs = {}
|
||||
kwargs: dict = {}
|
||||
observe_events = spec.get('observeEvents', [])
|
||||
ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
|
||||
observe_sensitive_commands = spec.get(
|
||||
@ -691,6 +692,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
SCHEMA_VERSION = Version.from_string('1.5')
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
TEST_SPEC: Any
|
||||
|
||||
@staticmethod
|
||||
def should_run_on(run_on_spec):
|
||||
@ -1213,6 +1215,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
class UnifiedSpecTestMeta(type):
|
||||
"""Metaclass for generating test classes."""
|
||||
TEST_SPEC: Any
|
||||
EXPECTED_FAILURES: Any
|
||||
|
||||
def __init__(cls, *args, **kwargs):
|
||||
super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs)
|
||||
|
||||
@ -1258,7 +1263,7 @@ def generate_test_classes(test_path, module=__name__, class_name_prefix='',
|
||||
"""Utility that creates the base class to use for test generation.
|
||||
This is needed to ensure that cls.TEST_SPEC is appropriately set when
|
||||
the metaclass __init__ is invoked."""
|
||||
class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)):
|
||||
class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore
|
||||
TEST_SPEC = test_spec
|
||||
EXPECTED_FAILURES = expected_failures
|
||||
return SpecTestBase
|
||||
|
||||
@ -226,7 +226,7 @@ class ServerEventListener(_ServerEventListener,
|
||||
"""Listens to Server events."""
|
||||
|
||||
|
||||
class ServerAndTopologyEventListener(ServerEventListener,
|
||||
class ServerAndTopologyEventListener(ServerEventListener, # type: ignore
|
||||
monitoring.TopologyListener):
|
||||
"""Listens to Server and Topology events."""
|
||||
|
||||
@ -519,7 +519,7 @@ def _mongo_client(host, port, authenticate=True, directConnection=None,
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or client_context.host
|
||||
port = port or client_context.port
|
||||
client_options = client_context.default_client_options.copy()
|
||||
client_options: dict = client_context.default_client_options.copy()
|
||||
if client_context.replica_set_name and not directConnection:
|
||||
client_options['replicaSet'] = client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
@ -678,7 +678,7 @@ def server_started_with_auth(client):
|
||||
try:
|
||||
command_line = get_command_line(client)
|
||||
except OperationFailure as e:
|
||||
msg = e.details.get('errmsg', '')
|
||||
msg = e.details.get('errmsg', '') # type: ignore
|
||||
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
|
||||
# Unauthorized.
|
||||
return True
|
||||
@ -818,8 +818,8 @@ class DeprecationFilter(object):
|
||||
|
||||
def stop(self):
|
||||
"""Stop filtering deprecations."""
|
||||
self.warn_context.__exit__()
|
||||
self.warn_context = None
|
||||
self.warn_context.__exit__() # type: ignore
|
||||
self.warn_context = None # type: ignore
|
||||
|
||||
|
||||
def get_pool(client):
|
||||
@ -862,23 +862,13 @@ def run_threads(collection, target):
|
||||
@contextlib.contextmanager
|
||||
def frequent_thread_switches():
|
||||
"""Make concurrency bugs more likely to manifest."""
|
||||
interval = None
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'getswitchinterval'):
|
||||
interval = sys.getswitchinterval()
|
||||
sys.setswitchinterval(1e-6)
|
||||
else:
|
||||
interval = sys.getcheckinterval()
|
||||
sys.setcheckinterval(1)
|
||||
interval = sys.getswitchinterval()
|
||||
sys.setswitchinterval(1e-6)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'setswitchinterval'):
|
||||
sys.setswitchinterval(interval)
|
||||
else:
|
||||
sys.setcheckinterval(interval)
|
||||
sys.setswitchinterval(interval)
|
||||
|
||||
|
||||
def lazy_client_trial(reset, target, test, get_client):
|
||||
@ -994,6 +984,7 @@ def assertion_context(msg):
|
||||
except AssertionError as exc:
|
||||
msg = '%s (%s)' % (exc, msg)
|
||||
exc_type, exc_val, exc_tb = sys.exc_info()
|
||||
assert exc_type is not None
|
||||
raise exc_type(exc_val).with_traceback(exc_tb)
|
||||
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ import functools
|
||||
import threading
|
||||
|
||||
from collections import abc
|
||||
from typing import List
|
||||
|
||||
from bson import decode, encode
|
||||
from bson.binary import Binary
|
||||
@ -40,7 +41,7 @@ from pymongo.write_concern import WriteConcern
|
||||
from test import (client_context,
|
||||
client_knobs,
|
||||
IntegrationTest)
|
||||
from test.utils import (camel_to_snake,
|
||||
from test.utils import (EventListener, camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
CompareType,
|
||||
CMAPListener,
|
||||
@ -86,6 +87,9 @@ class SpecRunnerThread(threading.Thread):
|
||||
|
||||
|
||||
class SpecRunner(IntegrationTest):
|
||||
mongos_clients: List
|
||||
knobs: client_knobs
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -105,7 +109,7 @@ class SpecRunner(IntegrationTest):
|
||||
def setUp(self):
|
||||
super(SpecRunner, self).setUp()
|
||||
self.targets = {}
|
||||
self.listener = None
|
||||
self.listener = None # type: ignore
|
||||
self.pool_listener = None
|
||||
self.server_listener = None
|
||||
self.maxDiff = None
|
||||
@ -219,6 +223,7 @@ class SpecRunner(IntegrationTest):
|
||||
ids = expected_result[res]
|
||||
if isinstance(ids, dict):
|
||||
ids = [ids[str(i)] for i in range(len(ids))]
|
||||
|
||||
self.assertEqual(ids, result.inserted_ids, prop)
|
||||
elif prop == "upserted_ids":
|
||||
# Convert indexes from strings to integers.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user