mongo-python-driver/test/test_uri_spec.py
Shane Harvey 0092b0af79
PYTHON-2504 Run pyupgrade 3.4.0 and ruff 0.0.265 (#1196)
pyupgrade --py37-plus bson/*.py pymongo/*.py gridfs/*.py test/*.py tools/*.py test/*/*.py
ruff --fix-only --select ALL --fixable ALL --target-version py37 --line-length=100 --unfixable COM812,D400,D415,ERA001,RUF100,SIM108,D211,D212,SIM105,SIM,PT,ANN204,EM bson/*.py pymongo/*.py gridfs/*.py test/*.py test/*/*.py
2023-05-11 15:27:17 -07:00

230 lines
9.7 KiB
Python

# Copyright 2011-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that the pymongo.uri_parser module is compliant with the connection
string and uri options specifications.
"""
import json
import os
import sys
import warnings
sys.path[0:0] = [""]
from test import clear_warning_registry, unittest
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
from pymongo.compression_support import _HAVE_SNAPPY
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.uri_parser import SRV_SCHEME, parse_uri
# Both spec directories are resolved relative to the real (symlink-free)
# location of this module so they do not depend on the current directory.
CONN_STRING_TEST_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "connection_string", "test"
)
URI_OPTIONS_TEST_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "uri_options"
)
# Descriptions of spec test cases that must not be turned into test methods.
# ``create_tests`` compares each scenario's "description" field against this
# list by exact string equality, prints a "Skipping test" line for a match and
# moves on, so every entry has to match the JSON text character for character.
# NOTE(review): the entries appear to cover options this driver does not
# implement (single-threaded-driver options and
# tlsDisableCertificateRevocationCheck) -- confirm before adding or removing.
TEST_DESC_SKIP_LIST = [
    "Valid options specific to single-threaded drivers are parsed correctly",
    "Invalid serverSelectionTryOnce causes a warning",
    "tlsDisableCertificateRevocationCheck can be set to true",
    "tlsDisableCertificateRevocationCheck can be set to false",
    "tlsAllowInvalidCertificates and tlsDisableCertificateRevocationCheck both present (and true) raises an error",
    "tlsAllowInvalidCertificates=true and tlsDisableCertificateRevocationCheck=false raises an error",
    "tlsAllowInvalidCertificates=false and tlsDisableCertificateRevocationCheck=true raises an error",
    "tlsAllowInvalidCertificates and tlsDisableCertificateRevocationCheck both present (and false) raises an error",
    "tlsDisableCertificateRevocationCheck and tlsAllowInvalidCertificates both present (and true) raises an error",
    "tlsDisableCertificateRevocationCheck=true and tlsAllowInvalidCertificates=false raises an error",
    "tlsDisableCertificateRevocationCheck=false and tlsAllowInvalidCertificates=true raises an error",
    "tlsDisableCertificateRevocationCheck and tlsAllowInvalidCertificates both present (and false) raises an error",
    "tlsInsecure and tlsDisableCertificateRevocationCheck both present (and true) raises an error",
    "tlsInsecure=true and tlsDisableCertificateRevocationCheck=false raises an error",
    "tlsInsecure=false and tlsDisableCertificateRevocationCheck=true raises an error",
    "tlsInsecure and tlsDisableCertificateRevocationCheck both present (and false) raises an error",
    "tlsDisableCertificateRevocationCheck and tlsInsecure both present (and true) raises an error",
    "tlsDisableCertificateRevocationCheck=true and tlsInsecure=false raises an error",
    "tlsDisableCertificateRevocationCheck=false and tlsInsecure=true raises an error",
    "tlsDisableCertificateRevocationCheck and tlsInsecure both present (and false) raises an error",
    "tlsDisableCertificateRevocationCheck and tlsDisableOCSPEndpointCheck both present (and true) raises an error",
    "tlsDisableCertificateRevocationCheck=true and tlsDisableOCSPEndpointCheck=false raises an error",
    "tlsDisableCertificateRevocationCheck=false and tlsDisableOCSPEndpointCheck=true raises an error",
    "tlsDisableCertificateRevocationCheck and tlsDisableOCSPEndpointCheck both present (and false) raises an error",
    "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and true) raises an error",
    "tlsDisableOCSPEndpointCheck=true and tlsDisableCertificateRevocationCheck=false raises an error",
    "tlsDisableOCSPEndpointCheck=false and tlsDisableCertificateRevocationCheck=true raises an error",
    "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and false) raises an error",
]
class TestAllScenarios(unittest.TestCase):
    """Holder for the spec-driven URI parsing tests.

    No test methods are written out in the class body; ``create_tests``
    attaches one generated method per scenario with ``setattr``.
    """

    def setUp(self):
        """Call ``clear_warning_registry`` before every generated test.

        NOTE(review): presumably this makes warnings that were already
        emitted once show up again in the recording context each scenario
        uses -- confirm against the helper's implementation.
        """
        clear_warning_registry()
def get_error_message_template(expected, artefact):
    """Return an assertion-message template for a mismatch on ``artefact``.

    The result keeps one literal ``%s`` placeholder that the caller fills in
    with the test description using ``%`` formatting.

    :param expected: truthy selects the "Expected" prefix, falsy selects
        "Unexpected".
    :param artefact: name of the thing being reported, e.g. ``"error"`` or
        ``"warning"``.
    :return: a string such as ``"Expected error for test '%s'"``.
    """
    if expected:
        prefix = "Expected"
    else:
        prefix = "Unexpected"
    # The last argument re-inserts a "%s" so the template can be completed later.
    return "{} {} for test '{}'".format(prefix, artefact, "%s")
def run_scenario_in_dir(target_workdir):
    """Return a decorator that runs a callable inside ``target_workdir``.

    The wrapper switches the process working directory to ``target_workdir``,
    invokes the wrapped callable with the arguments it was given, and then
    switches back to the directory that was current before the call. The
    switch back happens even when the callable raises (a failed assertion, or
    ``skipTest``, which works by raising), so one scenario cannot leave the
    rest of the run in the wrong directory.

    :param target_workdir: directory to ``os.chdir`` into for each call.
    :return: a decorator; applied to ``func`` it yields a function that
        forwards ``*args``/``**kwargs`` and returns whatever ``func`` returns.
    """

    def workdir_context_decorator(func):
        def modified_test_scenario(*args, **kwargs):
            original_workdir = os.getcwd()
            os.chdir(target_workdir)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore unconditionally: the working directory is
                # process-wide state shared with every later test.
                os.chdir(original_workdir)

        return modified_test_scenario

    return workdir_context_decorator
def create_test(test, test_workdir):
    """Build the test method for one spec test case.

    :param test: one entry of a scenario file's ``"tests"`` array. ``uri`` and
        ``description`` are always read; ``valid`` and ``warning`` are
        optional and default to ``True`` and ``False``; ``hosts``, ``auth``
        and ``options`` are read when the URI is expected to be valid.
    :param test_workdir: directory the scenario is executed in (via
        ``run_scenario_in_dir``). NOTE(review): presumably needed because
        some URIs reference files by relative path -- confirm.
    :return: a function taking the ``TestCase`` instance as ``self``.
    """

    def run_scenario(self):
        compressors = (test.get("options") or {}).get("compressors", [])
        if "snappy" in compressors and not _HAVE_SNAPPY:
            self.skipTest("This test needs the snappy module.")
        if test["uri"].startswith(SRV_SCHEME) and not _HAVE_DNSPYTHON:
            self.skipTest("This test needs dnspython package.")

        valid = True
        warning = False
        with warnings.catch_warnings(record=True) as ctx:
            # Record every warning, including ones already emitted earlier.
            warnings.simplefilter("always")
            try:
                options = parse_uri(test["uri"], warn=True)
            except Exception:
                # Any parsing failure counts as "the URI is invalid".
                valid = False
            else:
                warning = len(ctx) > 0

        expected_valid = test.get("valid", True)
        self.assertEqual(
            valid,
            expected_valid,
            get_error_message_template(not expected_valid, "error") % test["description"],
        )

        if not expected_valid:
            # The URI was rejected as required, so ``options`` was never bound
            # and there is no parse result to compare against.
            return

        expected_warning = test.get("warning", False)
        self.assertEqual(
            warning,
            expected_warning,
            get_error_message_template(expected_warning, "warning") % test["description"],
        )

        # Compare hosts and port.
        if test["hosts"] is not None:
            self.assertEqual(
                len(test["hosts"]),
                len(options["nodelist"]),
                "Incorrect number of hosts parsed from URI",
            )

            for exp, actual in zip(test["hosts"], options["nodelist"]):
                self.assertEqual(
                    exp["host"],
                    actual[0],
                    "Expected host {} but got {}".format(exp["host"], actual[0]),
                )
                if exp["port"] is not None:
                    self.assertEqual(
                        exp["port"],
                        actual[1],
                        "Expected port {} but got {}".format(exp["port"], actual[1]),
                    )

        # Compare auth options.
        auth = test["auth"]
        if auth is not None:
            # Work on a copy: ``test`` is captured by this closure, so renaming
            # the key in place would break a second run of the same method.
            auth = dict(auth)
            auth["database"] = auth.pop("db")  # db == database
            # Special case for PyMongo's collection parsing.
            if options.get("collection") is not None:
                options["database"] += "." + options["collection"]
            for elm in auth:
                if auth[elm] is not None:
                    # We have to do this because while the spec requires
                    # "+"->"+", unquote_plus does "+"->" "
                    options[elm] = options[elm].replace(" ", "+")
                    self.assertEqual(
                        auth[elm],
                        options[elm],
                        f"Expected {auth[elm]} but got {options[elm]}",
                    )

        # Compare URI options.
        err_msg = "For option %s expected %s but got %s"
        if test["options"]:
            opts = options["options"]
            for opt in test["options"]:
                lopt = opt.lower()
                # Parsed options may be stored under an internal name.
                optname = INTERNAL_URI_OPTION_NAME_MAP.get(lopt, lopt)
                if opts.get(optname) is not None:
                    if opts[optname] == test["options"][opt]:
                        expected_value = test["options"][opt]
                    else:
                        # Otherwise compare against the validated form of the
                        # spec value (second item of ``validate``'s result).
                        expected_value = validate(lopt, test["options"][opt])[1]
                    self.assertEqual(
                        opts[optname],
                        expected_value,
                        err_msg
                        % (
                            opt,
                            expected_value,
                            opts[optname],
                        ),
                    )
                else:
                    self.fail(f"Missing expected option {opt}")

    return run_scenario_in_dir(test_workdir)(run_scenario)
def create_tests(test_path):
    """Attach one generated test method per scenario found under ``test_path``.

    The directory tree is walked recursively. Every ``*.json`` file is loaded
    and each entry of its ``"tests"`` array is turned into a method by
    ``create_test`` and set on ``TestAllScenarios``. Entries whose description
    is in ``TEST_DESC_SKIP_LIST`` are reported on stdout and not registered.

    :param test_path: root directory to scan for scenario files.
    """
    for dirpath, _, filenames in os.walk(test_path):
        # Method names start with the last two components of the directory.
        parent_dir, leaf_dir = os.path.split(dirpath)
        suite_prefix = os.path.split(parent_dir)[-1] + "_" + leaf_dir

        # Anything that is not a .json file is not a test specification.
        for spec_file in (name for name in filenames if name.endswith(".json")):
            with open(os.path.join(dirpath, spec_file), encoding="utf-8") as spec_stream:
                spec = json.load(spec_stream)
            file_stem = os.path.splitext(spec_file)[0]

            for case in spec["tests"]:
                description = case["description"]
                if description in TEST_DESC_SKIP_LIST:
                    print("Skipping test '%s'" % description)
                    continue

                scenario_method = create_test(case, dirpath)
                scenario_method.__name__ = "test_{}_{}_{}".format(
                    suite_prefix,
                    file_stem,
                    str(description).replace(" ", "_"),
                )
                setattr(TestAllScenarios, scenario_method.__name__, scenario_method)
# Generate and register the test methods for both spec directories as soon
# as the module is imported, so unittest discovery can see them.
for test_path in (CONN_STRING_TEST_PATH, URI_OPTIONS_TEST_PATH):
    create_tests(test_path)


if __name__ == "__main__":
    unittest.main()