mongo-python-driver/test/test_discovery_and_monitoring.py
2017-12-01 17:23:39 -08:00

242 lines
7.4 KiB
Python

# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the topology module."""
import os
import sys
import threading
sys.path[0:0] = [""]
from bson import json_util, Timestamp
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.topology import Topology
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.ismaster import IsMaster
from pymongo.server_description import ServerDescription, SERVER_TYPE
from pymongo.settings import TopologySettings
from pymongo.uri_parser import parse_uri
from test import unittest
# Location of JSON test specifications.
_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'discovery_and_monitoring')
class MockSocketInfo(object):
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class MockPool(object):
def __init__(self, *args, **kwargs):
self.pool_id = 0
self._lock = threading.Lock()
def reset(self):
with self._lock:
self.pool_id += 1
class MockMonitor(object):
def __init__(self, server_description, topology, pool, topology_settings):
self._server_description = server_description
self._topology = topology
def open(self):
pass
def request_check(self):
pass
def close(self):
pass
def remove_stale_sockets(self):
pass
def create_mock_topology(uri, monitor_class=MockMonitor):
# Some tests in the spec include URIs like mongodb://A/?connect=direct,
# but PyMongo considers any single-seed URI with no setName to be "direct".
parsed_uri = parse_uri(uri.replace('connect=direct', ''))
replica_set_name = None
if 'replicaset' in parsed_uri['options']:
replica_set_name = parsed_uri['options']['replicaset']
topology_settings = TopologySettings(
parsed_uri['nodelist'],
replica_set_name=replica_set_name,
pool_class=MockPool,
monitor_class=monitor_class)
c = Topology(topology_settings)
c.open()
return c
def got_ismaster(topology, server_address, ismaster_response):
server_description = ServerDescription(
server_address, IsMaster(ismaster_response), 0)
topology.on_change(server_description)
def get_type(topology, hostname):
description = topology.get_server_by_address((hostname, 27017)).description
return description.server_type
class TestAllScenarios(unittest.TestCase):
pass
def topology_type_name(topology_type):
return TOPOLOGY_TYPE._fields[topology_type]
def server_type_name(server_type):
return SERVER_TYPE._fields[server_type]
def check_outcome(self, topology, outcome):
expected_servers = outcome['servers']
# Check weak equality before proceeding.
self.assertEqual(
len(topology.description.server_descriptions()),
len(expected_servers))
if outcome.get('compatible') is False:
with self.assertRaises(ConfigurationError):
topology.description.check_compatible()
else:
# No error.
topology.description.check_compatible()
# Since lengths are equal, every actual server must have a corresponding
# expected server.
for expected_server_address, expected_server in expected_servers.items():
node = common.partition_node(expected_server_address)
self.assertTrue(topology.has_server(node))
actual_server = topology.get_server_by_address(node)
actual_server_description = actual_server.description
if expected_server['type'] == 'PossiblePrimary':
# Special case, some tests in the spec include the PossiblePrimary
# type, but only single-threaded drivers need that type. We call
# possible primaries Unknown.
expected_server_type = SERVER_TYPE.Unknown
else:
expected_server_type = getattr(
SERVER_TYPE, expected_server['type'])
self.assertEqual(
server_type_name(expected_server_type),
server_type_name(actual_server_description.server_type))
self.assertEqual(
expected_server.get('setName'),
actual_server_description.replica_set_name)
self.assertEqual(
expected_server.get('setVersion'),
actual_server_description.set_version)
self.assertEqual(
expected_server.get('electionId'),
actual_server_description.election_id)
self.assertEqual(outcome['setName'], topology.description.replica_set_name)
self.assertEqual(outcome['logicalSessionTimeoutMinutes'],
topology.description.logical_session_timeout_minutes)
expected_topology_type = getattr(TOPOLOGY_TYPE, outcome['topologyType'])
self.assertEqual(topology_type_name(expected_topology_type),
topology_type_name(topology.description.topology_type))
def create_test(scenario_def):
def run_scenario(self):
c = create_mock_topology(scenario_def['uri'])
for phase in scenario_def['phases']:
for response in phase['responses']:
got_ismaster(c,
common.partition_node(response[0]),
response[1])
check_outcome(self, c, phase['outcome'])
return run_scenario
def create_tests():
for dirpath, _, filenames in os.walk(_TEST_PATH):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = json_util.loads(scenario_stream.read())
# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = 'test_%s_%s' % (
dirname, os.path.splitext(filename)[0])
new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)
create_tests()
class TestClusterTimeComparison(unittest.TestCase):
def test_cluster_time_comparison(self):
t = create_mock_topology('mongodb://host')
def send_cluster_time(time, inc, should_update):
old = t.max_cluster_time()
new = {'clusterTime': Timestamp(time, inc)}
got_ismaster(t,
('host', 27017),
{'ok': 1,
'minWireVersion': 0,
'maxWireVersion': 6,
'$clusterTime': new})
actual = t.max_cluster_time()
if should_update:
self.assertEqual(actual, new)
else:
self.assertEqual(actual, old)
send_cluster_time(0, 1, True)
send_cluster_time(2, 2, True)
send_cluster_time(2, 1, False)
send_cluster_time(1, 3, False)
send_cluster_time(2, 3, True)
if __name__ == "__main__":
unittest.main()