PYTHON-2676 Add load balancer tests in EVG (#625)

Add load balancer spec tests
Ensure LB supports retryable reads/writes
Add assertNumberConnectionsCheckedOut, createFindCursor, ignoreResultAndError
Add PoolClearedEvent.service_id and fix isClientError unified test assertion
This commit is contained in:
Shane Harvey 2021-05-27 15:05:26 -07:00 committed by GitHub
parent 21c92b13cf
commit 93ac5e0277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1024 additions and 151 deletions

View File

@ -411,6 +411,11 @@ functions:
if [ -n "${SETDEFAULTENCODING}" ]; then
export SETDEFAULTENCODING="${SETDEFAULTENCODING}"
fi
if [ -n "${test_loadbalancer}" ]; then
export TEST_LOADBALANCER=1
export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}"
export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}"
fi
PYTHON_BINARY=${PYTHON_BINARY} \
GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \
@ -788,6 +793,22 @@ functions:
-v \
--fault revoked
"run load-balancer":
- command: shell.exec
params:
script: |
DRIVERS_TOOLS=${DRIVERS_TOOLS} MONGODB_URI=${MONGODB_URI} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh start
- command: expansions.update
params:
file: lb-expansion.yml
"stop load-balancer":
- command: shell.exec
params:
script: |
cd ${DRIVERS_TOOLS}/.evergreen
DRIVERS_TOOLS=${DRIVERS_TOOLS} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh stop
"teardown_docker":
- command: shell.exec
params:
@ -1537,6 +1558,13 @@ tasks:
- func: "run aws auth test with aws EC2 credentials"
- func: "run aws ECS auth test"
- name: load-balancer-test
commands:
- func: "bootstrap mongo-orchestration"
vars:
TOPOLOGY: "sharded_cluster"
- func: "run load-balancer"
- func: "run tests"
# }}}
- name: "coverage-report"
tags: ["coverage"]
@ -1941,6 +1969,16 @@ axes:
variables:
ORCHESTRATION_FILE: "versioned-api-testing.json"
# Run load balancer tests?
- id: loadbalancer
display_name: "Load Balancer"
values:
- id: "enabled"
display_name: "Load Balancer"
variables:
test_loadbalancer: true
batchtime: 10080 # 7 days
buildvariants:
- matrix_name: "tests-all"
matrix_spec:
@ -2463,6 +2501,17 @@ buildvariants:
- name: "aws-auth-test-4.4"
- name: "aws-auth-test-latest"
- matrix_name: "load-balancer"
matrix_spec:
platform: ubuntu-18.04
mongodb-version: ["latest"]
auth-ssl: "*"
python-version: ["3.6", "3.9"]
loadbalancer: "*"
display_name: "Load Balancer ${platform} ${python-version} ${mongodb-version} ${auth-ssl}"
tasks:
- name: "load-balancer-test"
- matrix_name: "Release"
matrix_spec:
platform: [ubuntu-20.04, windows-64-vsMulti-small, macos-1014]

View File

@ -51,6 +51,11 @@ fi
if [ "$SSL" != "nossl" ]; then
export CLIENT_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/client.pem"
export CA_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/ca.pem"
if [ -n "$TEST_LOADBALANCER" ]; then
export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}&tls=true"
export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}&tls=true"
fi
fi
# For createvirtualenv.
@ -191,7 +196,12 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
# causing this script to exit.
$PYTHON -c "from bson import _cbson; from pymongo import _cmessage"
fi
$PYTHON $COVERAGE_ARGS setup.py $C_EXTENSIONS test $TEST_ARGS $OUTPUT
if [ -n "$TEST_LOADBALANCER" ]; then
$PYTHON -m xmlrunner discover -s test/load_balancer -v --locals -o $XUNIT_DIR
else
$PYTHON $COVERAGE_ARGS setup.py $C_EXTENSIONS test $TEST_ARGS $OUTPUT
fi
else
# --no_ext has to come before "test" so there is no way to toggle extensions here.
$PYTHON green_framework_test.py $GREEN_FRAMEWORK $OUTPUT

View File

@ -660,7 +660,7 @@ class ClientSession(object):
pass
finally:
self._transaction.state = _TxnState.ABORTED
self._unpin_mongos()
self._unpin()
def _finish_transaction_with_retry(self, command_name):
"""Run commit or abort with one retry after any retryable error.
@ -779,13 +779,13 @@ class ClientSession(object):
return self._transaction.pinned_address
return None
def _pin_mongos(self, server):
"""Pin this session to the given mongos Server."""
def _pin(self, server):
"""Pin this session to the given Server."""
self._transaction.sharded = True
self._transaction.pinned_address = server.description.address
def _unpin_mongos(self):
"""Unpin this session from any pinned mongos address."""
def _unpin(self):
"""Unpin this session from any pinned Server."""
self._transaction.pinned_address = None
def _txn_read_preference(self):
@ -906,9 +906,11 @@ class _ServerSessionPool(collections.deque):
return _ServerSession(self.generation)
def return_server_session(self, server_session, session_timeout_minutes):
self._clear_stale(session_timeout_minutes)
if not server_session.timed_out(session_timeout_minutes):
self.return_server_session_no_lock(server_session)
if session_timeout_minutes is not None:
self._clear_stale(session_timeout_minutes)
if server_session.timed_out(session_timeout_minutes):
return
self.return_server_session_no_lock(server_session)
def return_server_session_no_lock(self, server_session):
# Discard sessions from an old pool to avoid duplicate sessions in the

View File

@ -1197,15 +1197,16 @@ class MongoClient(common.BaseObject):
server = topology.select_server(server_selector)
# Pin this session to the selected server if it's performing a
# sharded transaction.
if server.description.mongos and (session and
session.in_transaction):
session._pin_mongos(server)
if (server.description.server_type in (
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)
and session and session.in_transaction):
session._pin(server)
return server
except PyMongoError as exc:
# Server selection errors in a transaction are transient.
if session and session.in_transaction:
exc._add_error_label("TransientTransactionError")
session._unpin_mongos()
session._unpin()
raise
def _socket_for_writes(self, session):
@ -1350,7 +1351,7 @@ class MongoClient(common.BaseObject):
_add_retryable_write_error(exc, max_wire_version)
retryable_error = exc.has_error_label("RetryableWriteError")
if retryable_error:
session._unpin_mongos()
session._unpin()
if is_retrying() or not retryable_error:
raise
if bulk:
@ -1965,7 +1966,7 @@ def _add_retryable_write_error(exc, max_wire_version):
class _MongoClientErrorHandler(object):
"""Handle errors raised when executing an operation."""
__slots__ = ('client', 'server_address', 'session', 'max_wire_version',
'sock_generation', 'completed_handshake')
'sock_generation', 'completed_handshake', 'service_id')
def __init__(self, client, server, session):
self.client = client
@ -1978,11 +1979,13 @@ class _MongoClientErrorHandler(object):
# of the pool at the time the connection attempt was started."
self.sock_generation = server.pool.generation
self.completed_handshake = False
self.service_id = None
def contribute_socket(self, sock_info):
"""Provide socket information to the error handler."""
self.max_wire_version = sock_info.max_wire_version
self.sock_generation = sock_info.generation
self.service_id = sock_info.service_id
self.completed_handshake = True
def __enter__(self):
@ -2001,9 +2004,9 @@ class _MongoClientErrorHandler(object):
if issubclass(exc_type, PyMongoError):
if (exc_val.has_error_label("TransientTransactionError") or
exc_val.has_error_label("RetryableWriteError")):
self.session._unpin_mongos()
self.session._unpin()
err_ctx = _ErrorContext(
exc_val, self.max_wire_version, self.sock_generation,
self.completed_handshake)
self.completed_handshake, self.service_id)
self.client._topology.handle_error(self.server_address, err_ctx)

View File

@ -512,13 +512,16 @@ _SENSITIVE_COMMANDS = set(
class _CommandEvent(object):
"""Base class for command events."""
__slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id")
__slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id",
"__service_id")
def __init__(self, command_name, request_id, connection_id, operation_id):
def __init__(self, command_name, request_id, connection_id, operation_id,
service_id=None):
self.__cmd_name = command_name
self.__rqst_id = request_id
self.__conn_id = connection_id
self.__op_id = operation_id
self.__service_id = service_id
@property
def command_name(self):
@ -535,6 +538,14 @@ class _CommandEvent(object):
"""The address (host, port) of the server this command was sent to."""
return self.__conn_id
@property
def service_id(self):
"""The service_id this command was sent to, or ``None``.
.. versionadded:: 3.12
"""
return self.__service_id
@property
def operation_id(self):
"""An id for this series of events or None."""
@ -551,15 +562,17 @@ class CommandStartedEvent(_CommandEvent):
- `connection_id`: The address (host, port) of the server this command
was sent to.
- `operation_id`: An optional identifier for a series of related events.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
__slots__ = ("__cmd", "__db")
def __init__(self, command, database_name, *args):
def __init__(self, command, database_name, *args, service_id=None):
if not command:
raise ValueError("%r is not a valid command" % (command,))
# Command name must be first key.
command_name = next(iter(command))
super(CommandStartedEvent, self).__init__(command_name, *args)
super(CommandStartedEvent, self).__init__(
command_name, *args, service_id=service_id)
if command_name.lower() in _SENSITIVE_COMMANDS:
self.__cmd = {}
else:
@ -577,9 +590,12 @@ class CommandStartedEvent(_CommandEvent):
return self.__db
def __repr__(self):
return "<%s %s db: %r, command: %r, operation_id: %s>" % (
self.__class__.__name__, self.connection_id, self.database_name,
self.command_name, self.operation_id)
return (
"<%s %s db: %r, command: %r, operation_id: %s, "
"service_id: %s>") % (
self.__class__.__name__, self.connection_id,
self.database_name, self.command_name, self.operation_id,
self.service_id)
class CommandSucceededEvent(_CommandEvent):
@ -593,13 +609,15 @@ class CommandSucceededEvent(_CommandEvent):
- `connection_id`: The address (host, port) of the server this command
was sent to.
- `operation_id`: An optional identifier for a series of related events.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
__slots__ = ("__duration_micros", "__reply")
def __init__(self, duration, reply, command_name,
request_id, connection_id, operation_id):
request_id, connection_id, operation_id, service_id=None):
super(CommandSucceededEvent, self).__init__(
command_name, request_id, connection_id, operation_id)
command_name, request_id, connection_id, operation_id,
service_id=service_id)
self.__duration_micros = _to_micros(duration)
if command_name.lower() in _SENSITIVE_COMMANDS:
self.__reply = {}
@ -617,9 +635,12 @@ class CommandSucceededEvent(_CommandEvent):
return self.__reply
def __repr__(self):
return "<%s %s command: %r, operation_id: %s, duration_micros: %s>" % (
self.__class__.__name__, self.connection_id,
self.command_name, self.operation_id, self.duration_micros)
return (
"<%s %s command: %r, operation_id: %s, duration_micros: %s, "
"service_id: %s>") % (
self.__class__.__name__, self.connection_id,
self.command_name, self.operation_id, self.duration_micros,
self.service_id)
class CommandFailedEvent(_CommandEvent):
@ -633,11 +654,12 @@ class CommandFailedEvent(_CommandEvent):
- `connection_id`: The address (host, port) of the server this command
was sent to.
- `operation_id`: An optional identifier for a series of related events.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
__slots__ = ("__duration_micros", "__failure")
def __init__(self, duration, failure, *args):
super(CommandFailedEvent, self).__init__(*args)
def __init__(self, duration, failure, *args, service_id=None):
super(CommandFailedEvent, self).__init__(*args, service_id=service_id)
self.__duration_micros = _to_micros(duration)
self.__failure = failure
@ -654,9 +676,10 @@ class CommandFailedEvent(_CommandEvent):
def __repr__(self):
return (
"<%s %s command: %r, operation_id: %s, duration_micros: %s, "
"failure: %r>" % (
"failure: %r, service_id: %s>") % (
self.__class__.__name__, self.connection_id, self.command_name,
self.operation_id, self.duration_micros, self.failure))
self.operation_id, self.duration_micros, self.failure,
self.service_id)
class _PoolEvent(object):
@ -721,10 +744,29 @@ class PoolClearedEvent(_PoolEvent):
:Parameters:
- `address`: The address (host, port) pair of the server this Pool is
attempting to connect to.
- `service_id`: The service_id this command was sent to, or ``None``.
.. versionadded:: 3.9
"""
__slots__ = ()
__slots__ = ("__service_id",)
def __init__(self, address, service_id=None):
super(PoolClearedEvent, self).__init__(address)
self.__service_id = service_id
@property
def service_id(self):
"""Connections with this service_id are cleared.
When service_id is ``None``, all connections in the pool are cleared.
.. versionadded:: 3.12
"""
return self.__service_id
def __repr__(self):
return '%s(%r, %r)' % (
self.__class__.__name__, self.address, self.__service_id)
class PoolClosedEvent(_PoolEvent):
@ -1508,10 +1550,10 @@ class _EventListeners(object):
except Exception:
_handle_exception()
def publish_pool_cleared(self, address):
def publish_pool_cleared(self, address, service_id):
"""Publish a :class:`PoolClearedEvent` to all pool listeners.
"""
event = PoolClearedEvent(address)
event = PoolClearedEvent(address, service_id)
for subscriber in self.__cmap_listeners:
try:
subscriber.pool_cleared(event)

View File

@ -1129,7 +1129,7 @@ class Pool:
def closed(self):
return self.state == PoolState.CLOSED
def _reset(self, close, pause=True):
def _reset(self, close, pause=True, service_id=None):
old_state = self.state
with self.size_cond:
if self.closed:
@ -1161,7 +1161,8 @@ class Pool:
listeners.publish_pool_closed(self.address)
else:
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
listeners.publish_pool_cleared(self.address)
listeners.publish_pool_cleared(self.address,
service_id=service_id)
for sock_info in sockets:
sock_info.close_socket(ConnectionClosedReason.STALE)
@ -1174,8 +1175,8 @@ class Pool:
for socket in self.sockets:
socket.update_is_writable(self.is_writable)
def reset(self):
self._reset(close=False)
def reset(self, service_id=None):
self._reset(close=False, service_id=service_id)
def reset_without_pause(self):
self._reset(close=False, pause=False)

View File

@ -49,9 +49,9 @@ class Server(object):
if not self._pool.opts.load_balanced:
self._monitor.open()
def reset(self):
def reset(self, service_id=None):
"""Clear the connection pool."""
self.pool.reset()
self.pool.reset(service_id)
def close(self):
"""Clear the connection pool and stop the monitor.

View File

@ -204,10 +204,10 @@ class ServerDescription(object):
@property
def retryable_writes_supported(self):
"""Checks if this server supports retryable writes."""
return (
return ((
self._ls_timeout_minutes is not None and
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary,
SERVER_TYPE.LoadBalancer))
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary))
or self._server_type == SERVER_TYPE.LoadBalancer)
@property
def retryable_reads_supported(self):

View File

@ -453,7 +453,7 @@ class Topology(object):
try:
server.pool.remove_stale_sockets(generation, all_credentials)
except PyMongoError as exc:
ctx = _ErrorContext(exc, 0, generation, False)
ctx = _ErrorContext(exc, 0, generation, False, None)
self.handle_error(server.description.address, ctx)
raise
@ -528,11 +528,9 @@ class Topology(object):
def return_server_session(self, server_session, lock):
if lock:
with self._lock:
session_timeout = \
self._description.logical_session_timeout_minutes
if session_timeout is not None:
self._session_pool.return_server_session(server_session,
session_timeout)
self._session_pool.return_server_session(
server_session,
self._description.logical_session_timeout_minutes)
else:
# Called from a __del__ method, can't use a lock.
self._session_pool.return_server_session_no_lock(server_session)
@ -566,7 +564,8 @@ class Topology(object):
# Emit initial SDAM events for load balancer mode.
self._process_change(ServerDescription(
self._seed_addresses[0],
IsMaster({'ok': 1, 'serviceId': self._topology_id})))
IsMaster({'ok': 1, 'serviceId': self._topology_id,
'maxWireVersion': 13})))
# Ensure that the monitors are open.
for server in self._servers.values():
@ -599,6 +598,7 @@ class Topology(object):
server = self._servers[address]
error = err_ctx.error
exc_type = type(error)
service_id = err_ctx.service_id
if (issubclass(exc_type, NetworkTimeout) and
err_ctx.completed_handshake):
# The socket has been closed. Don't reset the server.
@ -629,21 +629,21 @@ class Topology(object):
self._process_change(ServerDescription(address, error=error))
if is_shutting_down or (err_ctx.max_wire_version <= 7):
# Clear the pool.
server.reset()
server.reset(service_id)
server.request_check()
elif not err_ctx.completed_handshake:
# Unknown command error during the connection handshake.
if not self._settings.load_balanced:
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset()
server.reset(service_id)
elif issubclass(exc_type, ConnectionFailure):
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."
if not self._settings.load_balanced:
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset()
server.reset(service_id)
# "When a client marks a server Unknown from `Network error when
# reading or writing`_, clients MUST cancel the isMaster check on
# that server and close the current monitoring connection."
@ -795,11 +795,12 @@ class Topology(object):
class _ErrorContext(object):
"""An error with context for SDAM error handling."""
def __init__(self, error, max_wire_version, sock_generation,
completed_handshake):
completed_handshake, service_id):
self.error = error
self.max_wire_version = max_wire_version
self.sock_generation = sock_generation
self.completed_handshake = completed_handshake
self.service_id = service_id
def _is_stale_error_topology_version(current_tv, error_tv):

View File

@ -50,6 +50,7 @@ from pymongo import common, message
from pymongo.common import partition_node
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, validate_cert_reqs
from pymongo.uri_parser import parse_uri
from test.version import Version
if HAVE_SSL:
@ -92,6 +93,14 @@ if CA_PEM:
COMPRESSORS = os.environ.get("COMPRESSORS")
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
if TEST_LOADBALANCER:
res = parse_uri(SINGLE_MONGOS_LB_URI)
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
def is_server_resolvable():
@ -190,6 +199,7 @@ def _all_users(db):
class ClientContext(object):
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
def __init__(self):
"""Create a client and grab essential information from the server."""
@ -216,7 +226,9 @@ class ClientContext(object):
self.client = None
self.conn_lock = threading.Lock()
self.is_data_lake = False
self.load_balancer = False
self.load_balancer = TEST_LOADBALANCER
if self.load_balancer:
self.default_client_options["loadBalanced"] = True
if COMPRESSORS:
self.default_client_options["compressors"] = COMPRESSORS
if MONGODB_API_VERSION:
@ -632,8 +644,10 @@ class ClientContext(object):
func=func)
def is_topology_type(self, topologies):
if 'load-balanced' in topologies and self.load_balancer:
return True
if self.load_balancer:
if 'load-balanced' in topologies:
return True
return False
if 'single' in topologies and not (self.is_mongos or self.is_rs):
return True
if 'replicaset' in topologies and self.is_rs:

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_crud_unified import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_dns import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,34 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Load Balancer unified spec tests."""
import os
import sys
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
# Location of JSON test specifications.
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'unified')
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_change_stream import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_retryable_reads import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_retryable_writes import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_transactions_unified import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_uri_spec import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,23 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
sys.path[0:0] = [""]
from test.test_versioned_api import *
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,184 @@
{
"description": "monitoring events include correct fields",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"topologies": [
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": true,
"uriOptions": {
"retryReads": false
},
"observeEvents": [
"commandStartedEvent",
"commandSucceededEvent",
"commandFailedEvent",
"poolClearedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "database0"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "coll0"
}
}
],
"initialData": [
{
"databaseName": "database0",
"collectionName": "coll0",
"documents": []
}
],
"tests": [
{
"description": "command started and succeeded events include serviceId",
"operations": [
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"x": 1
}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "insert",
"hasServiceId": true
}
},
{
"commandSucceededEvent": {
"commandName": "insert",
"hasServiceId": true
}
}
]
}
]
},
{
"description": "command failed events include serviceId",
"operations": [
{
"name": "find",
"object": "collection0",
"arguments": {
"filter": {
"$or": true
}
},
"expectError": {
"isError": true
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find",
"hasServiceId": true
}
},
{
"commandFailedEvent": {
"commandName": "find",
"hasServiceId": true
}
}
]
}
]
},
{
"description": "poolClearedEvent events include serviceId",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"closeConnection": true
}
}
}
},
{
"name": "find",
"object": "collection0",
"arguments": {
"filter": {}
},
"expectError": {
"isClientError": true
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find",
"hasServiceId": true
}
},
{
"commandFailedEvent": {
"commandName": "find",
"hasServiceId": true
}
}
]
},
{
"client": "client0",
"eventType": "cmap",
"events": [
{
"poolClearedEvent": {
"hasServiceId": true
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,58 @@
{
"description": "connection establishment for load-balanced clusters",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"topologies": [
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"uriOptions": {
"loadBalanced": false
},
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "database0"
}
}
],
"tests": [
{
"description": "operations against load balancers fail if URI contains loadBalanced=false",
"skipReason": "servers have not implemented LB support yet so they will not fail the connection handshake in this case",
"operations": [
{
"name": "runCommand",
"object": "database0",
"arguments": {
"commandName": "ping",
"command": {
"ping": 1
}
},
"expectError": {
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,92 @@
{
"description": "connection establishment if loadBalanced is specified for non-load balanced clusters",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"topologies": [
"single",
"sharded"
]
}
],
"createEntities": [
{
"client": {
"id": "lbTrueClient",
"useMultipleMongoses": false,
"uriOptions": {
"loadBalanced": true
}
}
},
{
"database": {
"id": "lbTrueDatabase",
"client": "lbTrueClient",
"databaseName": "lbTrueDb"
}
},
{
"client": {
"id": "lbFalseClient",
"uriOptions": {
"loadBalanced": false
}
}
},
{
"database": {
"id": "lbFalseDatabase",
"client": "lbFalseClient",
"databaseName": "lbFalseDb"
}
}
],
"_yamlAnchors": {
"runCommandArguments": [
{
"arguments": {
"commandName": "ping",
"command": {
"ping": 1
}
}
}
]
},
"tests": [
{
"description": "operations against non-load balanced clusters fail if URI contains loadBalanced=true",
"operations": [
{
"name": "runCommand",
"object": "lbTrueDatabase",
"arguments": {
"commandName": "ping",
"command": {
"ping": 1
}
},
"expectError": {
"errorContains": "Driver attempted to initialize in load balancing mode, but the server does not support this mode"
}
}
]
},
{
"description": "operations against non-load balanced clusters succeed if URI contains loadBalanced=false",
"operations": [
{
"name": "runCommand",
"object": "lbFalseDatabase",
"arguments": {
"commandName": "ping",
"command": {
"ping": 1
}
}
}
]
}
]
}

View File

@ -0,0 +1,82 @@
{
"description": "server selection for load-balanced clusters",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"topologies": [
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": true,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "database0Name"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "coll0",
"collectionOptions": {
"readPreference": {
"mode": "secondaryPreferred"
}
}
}
}
],
"initialData": [
{
"collectionName": "coll0",
"databaseName": "database0Name",
"documents": []
}
],
"tests": [
{
"description": "$readPreference is sent for load-balanced clusters",
"operations": [
{
"name": "find",
"object": "collection0",
"arguments": {
"filter": {}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "coll0",
"filter": {},
"$readPreference": {
"mode": "secondaryPreferred"
}
},
"commandName": "find",
"databaseName": "database0Name"
}
}
]
}
]
}
]
}

View File

@ -1483,7 +1483,7 @@ class TestClient(IntegrationTest):
def run(self):
while self.running:
exc = AutoReconnect('mock pool error')
ctx = _ErrorContext(exc, 0, pool.generation, False)
ctx = _ErrorContext(exc, 0, pool.generation, False, None)
client._topology.handle_error(pool.address, ctx)
time.sleep(0.001)

View File

@ -22,6 +22,7 @@ import time
sys.path[0:0] = [""]
from bson.son import SON
from bson.objectid import ObjectId
from pymongo.errors import (ConnectionFailure,
OperationFailure,
@ -422,6 +423,7 @@ class TestCMAP(IntegrationTest):
self.assertRepr(ConnectionCheckOutStartedEvent(host))
self.assertRepr(PoolCreatedEvent(host, {}))
self.assertRepr(PoolClearedEvent(host))
self.assertRepr(PoolClearedEvent(host, service_id=ObjectId()))
self.assertRepr(PoolClosedEvent(host))
def test_close_leaves_pool_unpaused(self):

View File

@ -113,7 +113,7 @@ def got_app_error(topology, app_error):
topology.handle_error(
server_address, _ErrorContext(e, max_wire_version, generation,
completed_handshake))
completed_handshake, None))
def get_type(topology, hostname):

View File

@ -1183,7 +1183,7 @@ class TestEventClasses(PyMongoTestCase):
self.assertEqual(
repr(event),
"<CommandStartedEvent ('localhost', 27017) db: 'admin', "
"command: 'isMaster', operation_id: 2>")
"command: 'isMaster', operation_id: 2, service_id: None>")
delta = datetime.timedelta(milliseconds=100)
event = monitoring.CommandSucceededEvent(
delta, {'ok': 1}, 'isMaster', request_id, connection_id,
@ -1191,7 +1191,8 @@ class TestEventClasses(PyMongoTestCase):
self.assertEqual(
repr(event),
"<CommandSucceededEvent ('localhost', 27017) "
"command: 'isMaster', operation_id: 2, duration_micros: 100000>")
"command: 'isMaster', operation_id: 2, duration_micros: 100000, "
"service_id: None>")
event = monitoring.CommandFailedEvent(
delta, {'ok': 0}, 'isMaster', request_id, connection_id,
operation_id)
@ -1199,7 +1200,7 @@ class TestEventClasses(PyMongoTestCase):
repr(event),
"<CommandFailedEvent ('localhost', 27017) "
"command: 'isMaster', operation_id: 2, duration_micros: 100000, "
"failure: {'ok': 0}>")
"failure: {'ok': 0}, service_id: None>")
def test_server_heartbeat_event_repr(self):
connection_id = ('localhost', 27017)

View File

@ -401,7 +401,7 @@ class TestMultiServerTopology(TopologyTest):
'setName': 'rs',
'hosts': ['a', 'b']})
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True)
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None)
t.handle_error(('a', 27017), errctx)
self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a'))
self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b'))
@ -430,7 +430,7 @@ class TestMultiServerTopology(TopologyTest):
t = create_mock_topology(replica_set_name='rs')
# No error resetting a server not in the TopologyDescription.
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True)
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None)
t.handle_error(('b', 27017), errctx)
# Server was *not* added as type Unknown.

View File

@ -39,10 +39,16 @@ from pymongo.client_session import ClientSession, TransactionOptions, _TxnState
from pymongo.change_stream import ChangeStream
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import BulkWriteError, InvalidOperation, PyMongoError
from pymongo.errors import (
BulkWriteError, ConnectionFailure, InvalidOperation, NotMasterError,
PyMongoError)
from pymongo.monitoring import (
CommandFailedEvent, CommandListener, CommandStartedEvent,
CommandSucceededEvent, _SENSITIVE_COMMANDS)
CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent,
PoolReadyEvent, PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent,
ConnectionReadyEvent, ConnectionClosedEvent,
ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent,
ConnectionCheckedOutEvent, ConnectionCheckedInEvent)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult
@ -51,7 +57,8 @@ from pymongo.write_concern import WriteConcern
from test import client_context, unittest, IntegrationTest
from test.utils import (
camel_to_snake, rs_or_single_client, single_client, snake_to_camel)
camel_to_snake, get_pool, rs_or_single_client, single_client,
snake_to_camel, CMAPListener)
from test.version import Version
from test.utils import (
@ -142,28 +149,52 @@ def parse_bulk_write_error_result(error):
return parse_bulk_write_result(write_result)
class EventListenerUtil(CommandListener):
class NonLazyCursor(object):
"""A find cursor proxy that creates the remote cursor when initialized."""
def __init__(self, find_cursor):
self.find_cursor = find_cursor
# Create the server side cursor.
self.first_result = next(find_cursor, None)
def __next__(self):
if self.first_result is not None:
first = self.first_result
self.first_result = None
return first
return next(self.find_cursor)
def close(self):
self.find_cursor.close()
class EventListenerUtil(CMAPListener, CommandListener):
def __init__(self, observe_events, ignore_commands):
self._event_types = set(observe_events)
self._event_types = set(name.lower() for name in observe_events)
self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
self._ignore_commands.add('configurefailpoint')
self.results = []
super(EventListenerUtil, self).__init__()
def _observe_event(self, event):
def get_events(self, event_type):
if event_type == 'command':
return [e for e in self.events if 'Command' in type(e).__name__]
return [e for e in self.events if 'Command' not in type(e).__name__]
def add_event(self, event):
if type(event).__name__.lower() in self._event_types:
super(EventListenerUtil, self).add_event(event)
def _command_event(self, event):
if event.command_name.lower() not in self._ignore_commands:
self.results.append(event)
self.add_event(event)
def started(self, event):
if 'commandStartedEvent' in self._event_types:
self._observe_event(event)
self._command_event(event)
def succeeded(self, event):
if 'commandSucceededEvent' in self._event_types:
self._observe_event(event)
self._command_event(event)
def failed(self, event):
if 'commandFailedEvent' in self._event_types:
self._observe_event(event)
self._command_event(event)
class EntityMapUtil(object):
@ -173,28 +204,28 @@ class EntityMapUtil(object):
self._entities = {}
self._listeners = {}
self._session_lsids = {}
self._test_class = test_class
self.test = test_class
def __getitem__(self, item):
try:
return self._entities[item]
except KeyError:
self._test_class.fail('Could not find entity named %s in map' % (
self.test.fail('Could not find entity named %s in map' % (
item,))
def __setitem__(self, key, value):
if not isinstance(key, str):
self._test_class.fail(
self.test.fail(
'Expected entity name of type str, got %s' % (type(key)))
if key in self._entities:
self._test_class.fail('Entity named %s already in map' % (key,))
self.test.fail('Entity named %s already in map' % (key,))
self._entities[key] = value
def _create_entity(self, entity_spec):
if len(entity_spec) != 1:
self._test_class.fail(
self.test.fail(
"Entity spec %s did not contain exactly one top-level key" % (
entity_spec,))
@ -203,13 +234,17 @@ class EntityMapUtil(object):
kwargs = {}
observe_events = spec.get('observeEvents', [])
ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
# TODO: SUPPORT storeEventsAsEntities
if len(observe_events) or len(ignore_commands):
ignore_commands = [cmd.lower() for cmd in ignore_commands]
listener = EventListenerUtil(observe_events, ignore_commands)
self._listeners[spec['id']] = listener
kwargs['event_listeners'] = [listener]
if client_context.is_mongos and spec.get('useMultipleMongoses'):
kwargs['h'] = client_context.mongos_seeds()
if spec.get('useMultipleMongoses'):
if client_context.load_balancer:
kwargs['h'] = client_context.MULTI_MONGOS_LB_URI
elif client_context.is_mongos:
kwargs['h'] = client_context.mongos_seeds()
kwargs.update(spec.get('uriOptions', {}))
server_api = spec.get('serverApi')
if server_api:
@ -218,12 +253,12 @@ class EntityMapUtil(object):
deprecation_errors=server_api.get('deprecationErrors'))
client = rs_or_single_client(**kwargs)
self[spec['id']] = client
self._test_class.addCleanup(client.close)
self.test.addCleanup(client.close)
return
elif entity_type == 'database':
client = self[spec['client']]
if not isinstance(client, MongoClient):
self._test_class.fail(
self.test.fail(
'Expected entity %s to be of type MongoClient, got %s' % (
spec['client'], type(client)))
options = parse_collection_or_database_options(
@ -234,7 +269,7 @@ class EntityMapUtil(object):
elif entity_type == 'collection':
database = self[spec['database']]
if not isinstance(database, Database):
self._test_class.fail(
self.test.fail(
'Expected entity %s to be of type Database, got %s' % (
spec['database'], type(database)))
options = parse_collection_or_database_options(
@ -245,7 +280,7 @@ class EntityMapUtil(object):
elif entity_type == 'session':
client = self[spec['client']]
if not isinstance(client, MongoClient):
self._test_class.fail(
self.test.fail(
'Expected entity %s to be of type MongoClient, got %s' % (
spec['client'], type(client)))
opts = camel_to_snake_args(spec.get('sessionOptions', {}))
@ -258,13 +293,13 @@ class EntityMapUtil(object):
session = client.start_session(**dict(opts))
self[spec['id']] = session
self._session_lsids[spec['id']] = copy.deepcopy(session.session_id)
self._test_class.addCleanup(session.end_session)
self.test.addCleanup(session.end_session)
return
elif entity_type == 'bucket':
# TODO: implement the 'bucket' entity type
self._test_class.skipTest(
self.test.skipTest(
'GridFS is not currently supported (PYTHON-2459)')
self._test_class.fail(
self.test.fail(
'Unable to create entity of unknown type %s' % (entity_type,))
def create_entities_from_spec(self, entity_spec):
@ -274,13 +309,13 @@ class EntityMapUtil(object):
def get_listener_for_client(self, client_name):
client = self[client_name]
if not isinstance(client, MongoClient):
self._test_class.fail(
self.test.fail(
'Expected entity %s to be of type MongoClient, got %s' % (
client_name, type(client)))
listener = self._listeners.get(client_name)
if not listener:
self._test_class.fail(
self.test.fail(
'No listeners configured for client %s' % (client_name,))
return listener
@ -288,7 +323,7 @@ class EntityMapUtil(object):
def get_lsid_for_session(self, session_name):
session = self[session_name]
if not isinstance(session, ClientSession):
self._test_class.fail(
self.test.fail(
'Expected entity %s to be of type ClientSession, got %s' % (
session_name, type(session)))
@ -334,21 +369,21 @@ class MatchEvaluatorUtil(object):
"""Utility class that implements methods for evaluating matches as per
the unified test format specification."""
def __init__(self, test_class):
self._test_class = test_class
self.test = test_class
def _operation_exists(self, spec, actual, key_to_compare):
if spec is True:
self._test_class.assertIn(key_to_compare, actual)
self.test.assertIn(key_to_compare, actual)
elif spec is False:
self._test_class.assertNotIn(key_to_compare, actual)
self.test.assertNotIn(key_to_compare, actual)
else:
self._test_class.fail(
self.test.fail(
'Expected boolean value for $$exists operator, got %s' % (
spec,))
def __type_alias_to_type(self, alias):
if alias not in BSON_TYPE_ALIAS_MAP:
self._test_class.fail('Unrecognized BSON type alias %s' % (alias,))
self.test.fail('Unrecognized BSON type alias %s' % (alias,))
return BSON_TYPE_ALIAS_MAP[alias]
def _operation_type(self, spec, actual, key_to_compare):
@ -357,13 +392,13 @@ class MatchEvaluatorUtil(object):
t for alias in spec for t in self.__type_alias_to_type(alias)])
else:
permissible_types = self.__type_alias_to_type(spec)
self._test_class.assertIsInstance(
self.test.assertIsInstance(
actual[key_to_compare], permissible_types)
def _operation_matchesEntity(self, spec, actual, key_to_compare):
expected_entity = self._test_class.entity_map[spec]
self._test_class.assertIsInstance(expected_entity, abc.Mapping)
self._test_class.assertEqual(expected_entity, actual[key_to_compare])
expected_entity = self.test.entity_map[spec]
self.test.assertIsInstance(expected_entity, abc.Mapping)
self.test.assertEqual(expected_entity, actual[key_to_compare])
def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
raise NotImplementedError
@ -380,8 +415,8 @@ class MatchEvaluatorUtil(object):
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
def _operation_sessionLsid(self, spec, actual, key_to_compare):
expected_lsid = self._test_class.entity_map.get_lsid_for_session(spec)
self._test_class.assertEqual(expected_lsid, actual[key_to_compare])
expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
self.test.assertEqual(expected_lsid, actual[key_to_compare])
def _evaluate_special_operation(self, opname, spec, actual,
key_to_compare):
@ -389,7 +424,7 @@ class MatchEvaluatorUtil(object):
try:
method = getattr(self, method_name)
except AttributeError:
self._test_class.fail(
self.test.fail(
'Unsupported special matching operator %s' % (opname,))
else:
method(spec, actual, key_to_compare)
@ -440,16 +475,16 @@ class MatchEvaluatorUtil(object):
if self._evaluate_if_special_operation(expectation, actual):
return
self._test_class.assertIsInstance(actual, abc.Mapping)
self.test.assertIsInstance(actual, abc.Mapping)
for key, value in expectation.items():
if self._evaluate_if_special_operation(expectation, actual, key):
continue
self._test_class.assertIn(key, actual)
self.test.assertIn(key, actual)
self.match_result(value, actual[key], in_recursive_call=True)
if not is_root:
self._test_class.assertEqual(
self.test.assertEqual(
set(expectation.keys()), set(actual.keys()))
def match_result(self, expectation, actual,
@ -459,7 +494,7 @@ class MatchEvaluatorUtil(object):
expectation, actual, is_root=not in_recursive_call)
if isinstance(expectation, abc.MutableSequence):
self._test_class.assertIsInstance(actual, abc.MutableSequence)
self.test.assertIsInstance(actual, abc.MutableSequence)
for e, a in zip(expectation, actual):
if isinstance(e, abc.Mapping):
self._match_document(
@ -471,21 +506,22 @@ class MatchEvaluatorUtil(object):
# account for flexible numerics in element-wise comparison
if (isinstance(expectation, int) or
isinstance(expectation, float)):
self._test_class.assertEqual(expectation, actual)
self.test.assertEqual(expectation, actual)
else:
self._test_class.assertIsInstance(actual, type(expectation))
self._test_class.assertEqual(expectation, actual)
self.test.assertIsInstance(actual, type(expectation))
self.test.assertEqual(expectation, actual)
def match_event(self, expectation, actual):
event_type, spec = next(iter(expectation.items()))
def match_event(self, event_type, expectation, actual):
name, spec = next(iter(expectation.items()))
# every event type has the commandName field
command_name = spec.get('commandName')
if command_name:
self._test_class.assertEqual(command_name, actual.command_name)
# every command event has the commandName field
if event_type == 'command':
command_name = spec.get('commandName')
if command_name:
self.test.assertEqual(command_name, actual.command_name)
if event_type == 'commandStartedEvent':
self._test_class.assertIsInstance(actual, CommandStartedEvent)
if name == 'commandStartedEvent':
self.test.assertIsInstance(actual, CommandStartedEvent)
command = spec.get('command')
database_name = spec.get('databaseName')
if command:
@ -497,18 +533,47 @@ class MatchEvaluatorUtil(object):
update.setdefault('multi', False)
self.match_result(command, actual.command)
if database_name:
self._test_class.assertEqual(
self.test.assertEqual(
database_name, actual.database_name)
elif event_type == 'commandSucceededEvent':
self._test_class.assertIsInstance(actual, CommandSucceededEvent)
elif name == 'commandSucceededEvent':
self.test.assertIsInstance(actual, CommandSucceededEvent)
reply = spec.get('reply')
if reply:
self.match_result(reply, actual.reply)
elif event_type == 'commandFailedEvent':
self._test_class.assertIsInstance(actual, CommandFailedEvent)
elif name == 'commandFailedEvent':
self.test.assertIsInstance(actual, CommandFailedEvent)
elif name == 'poolCreatedEvent':
self.test.assertIsInstance(actual, PoolCreatedEvent)
elif name == 'poolReadyEvent':
self.test.assertIsInstance(actual, PoolReadyEvent)
elif name == 'poolClearedEvent':
self.test.assertIsInstance(actual, PoolClearedEvent)
if spec.get('hasServiceId'):
self.test.assertIsNotNone(actual.service_id)
self.test.assertIsInstance(actual.service_id, ObjectId)
else:
self.test.assertIsNone(actual.service_id)
elif name == 'poolClosedEvent':
self.test.assertIsInstance(actual, PoolClosedEvent)
elif name == 'connectionCreatedEvent':
self.test.assertIsInstance(actual, ConnectionCreatedEvent)
elif name == 'connectionReadyEvent':
self.test.assertIsInstance(actual, ConnectionReadyEvent)
elif name == 'connectionClosedEvent':
self.test.assertIsInstance(actual, ConnectionClosedEvent)
self.test.assertEqual(actual.reason, spec['reason'])
elif name == 'connectionCheckOutStartedEvent':
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
elif name == 'connectionCheckOutFailedEvent':
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
self.test.assertEqual(actual.reason, spec['reason'])
elif name == 'connectionCheckedOutEvent':
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
elif name == 'connectionCheckedInEvent':
self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
else:
self._test_class.fail(
'Unsupported event type %s' % (event_type,))
self.test.fail(
'Unsupported event type %s' % (name,))
def coerce_result(opname, result):
@ -623,7 +688,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
pass
if is_client_error:
self.assertNotIsInstance(exception, PyMongoError)
# Connection errors are considered client errors.
if isinstance(exception, ConnectionFailure):
self.assertNotIsInstance(exception, NotMasterError)
else:
self.assertNotIsInstance(exception, PyMongoError)
if error_contains:
if isinstance(exception, BulkWriteError):
@ -692,6 +761,12 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
kwargs['command'] = ordered_command
return target.command(**kwargs)
def _databaseOperation_listCollections(self, target, *args, **kwargs):
if 'batch_size' in kwargs:
kwargs['cursor'] = {'batchSize': kwargs.pop('batch_size')}
cursor = target.list_collections(*args, **kwargs)
return list(cursor)
def __entityOperation_aggregate(self, target, *args, **kwargs):
self.__raise_if_unsupported('aggregate', target, Database, Collection)
return list(target.aggregate(*args, **kwargs))
@ -707,6 +782,16 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
find_cursor = target.find(*args, **kwargs)
return list(find_cursor)
def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
self.__raise_if_unsupported('find', target, Collection)
return NonLazyCursor(target.find(*args, **kwargs))
def _collectionOperation_listIndexes(self, target, *args, **kwargs):
if 'batch_size' in kwargs:
self.skipTest('PyMongo does not support batch_size for '
'list_indexes')
return target.list_indexes(*args, **kwargs)
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
if client_context.storage_engine == 'mmapv1':
self.skipTest('MMAPv1 does not support document-level locking')
@ -725,11 +810,27 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
'iterateUntilDocumentOrError', target, ChangeStream)
return next(target)
def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs):
self.__raise_if_unsupported(
'iterateUntilDocumentOrError', target, NonLazyCursor)
return next(target)
def _cursor_close(self, target, *args, **kwargs):
self.__raise_if_unsupported('close', target, NonLazyCursor)
return target.close()
def run_entity_operation(self, spec):
target = self.entity_map[spec['object']]
opname = spec['name']
opargs = spec.get('arguments')
expect_error = spec.get('expectError')
save_as_entity = spec.get('saveResultAsEntity')
expect_result = spec.get('expectResult')
ignore = spec.get('ignoreResultAndError')
if ignore and (expect_error or save_as_entity or expect_result):
raise ValueError(
'ignoreResultAndError is incompatible with saveResultAsEntity'
', expectError, and expectResult')
if opargs:
arguments = parse_spec_options(copy.deepcopy(opargs))
prepare_spec_arguments(spec, arguments, camel_to_snake(opname),
@ -745,6 +846,8 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
method_name = '_collectionOperation_%s' % (opname,)
elif isinstance(target, ChangeStream):
method_name = '_changeStreamOperation_%s' % (opname,)
elif isinstance(target, NonLazyCursor):
method_name = '_cursor_%s' % (opname,)
elif isinstance(target, ClientSession):
method_name = '_sessionOperation_%s' % (opname,)
elif isinstance(target, GridFSBucket):
@ -766,15 +869,16 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
try:
result = cmd(**dict(arguments))
except Exception as exc:
if ignore:
return
if expect_error:
return self.process_error(exc, expect_error)
raise
if 'expectResult' in spec:
if expect_result:
actual = coerce_result(opname, result)
self.match_evaluator.match_result(spec['expectResult'], actual)
self.match_evaluator.match_result(expect_result, actual)
save_as_entity = spec.get('saveResultAsEntity')
if save_as_entity:
self.entity_map[save_as_entity] = result
@ -821,7 +925,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def __get_last_two_command_lsids(self, listener):
cmd_started_events = []
for event in reversed(listener.results):
for event in reversed(listener.events):
if isinstance(event, CommandStartedEvent):
cmd_started_events.append(event)
if len(cmd_started_events) < 2:
@ -869,6 +973,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
for index in collection.list_indexes():
self.assertNotEqual(spec['indexName'], index['name'])
def _testOperation_assertNumberConnectionsCheckedOut(self, spec):
client = self.entity_map[spec['client']]
pool = get_pool(client)
self.assertEqual(spec['connections'], pool.active_sockets)
def run_special_operation(self, spec):
opname = spec['name']
method_name = '_testOperation_%s' % (opname,)
@ -891,19 +1000,23 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
for event_spec in spec:
client_name = event_spec['client']
events = event_spec['events']
listener = self.entity_map.get_listener_for_client(client_name)
# Valid types: 'command', 'cmap'
event_type = event_spec.get('eventType', 'command')
assert event_type in ('command', 'cmap')
listener = self.entity_map.get_listener_for_client(client_name)
actual_events = listener.get_events(event_type)
if len(events) == 0:
self.assertEqual(listener.results, [])
self.assertEqual(actual_events, [])
continue
if len(events) > len(listener.results):
if len(events) > len(actual_events):
self.fail('Expected to see %s events, got %s' % (
len(events), len(listener.results)))
len(events), len(actual_events)))
for idx, expected_event in enumerate(events):
self.match_evaluator.match_event(
expected_event, listener.results[idx])
event_type, expected_event, actual_events[idx])
def verify_outcome(self, spec):
for collection_data in spec:

View File

@ -277,7 +277,7 @@ class MockPool(object):
def ready(self):
pass
def reset(self):
def reset(self, service_id=None):
self._reset()
def reset_without_pause(self):

View File

@ -504,15 +504,16 @@ class SpecRunner(IntegrationTest):
client_context.storage_engine == 'mmapv1'):
self.skipTest("MMAPv1 does not support retryWrites=True")
use_multi_mongos = test['useMultipleMongoses']
if client_context.is_mongos and use_multi_mongos:
client = rs_client(
client_context.mongos_seeds(),
event_listeners=[listener, pool_listener, server_listener],
**client_options)
else:
client = rs_client(
event_listeners=[listener, pool_listener, server_listener],
**client_options)
host = None
if use_multi_mongos:
if client_context.load_balancer:
host = client_context.MULTI_MONGOS_LB_URI
elif client_context.is_mongos:
host = client_context.mongos_seeds()
client = rs_client(
h=host,
event_listeners=[listener, pool_listener, server_listener],
**client_options)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener