376 lines
12 KiB
Python
376 lines
12 KiB
Python
# Copyright 2012-2015 MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Discover environment and server configuration, initialize PyMongo client."""
|
|
|
|
import os
|
|
import socket
|
|
import sys
|
|
from functools import wraps
|
|
from test.utils import create_user
|
|
from test.version import Version
|
|
from unittest import SkipTest
|
|
|
|
import pymongo.errors
|
|
|
|
HAVE_SSL = True
|
|
try:
|
|
import ssl
|
|
except ImportError:
|
|
HAVE_SSL = False
|
|
ssl = None
|
|
|
|
HAVE_TORNADO = True
|
|
try:
|
|
import tornado
|
|
except ImportError:
|
|
HAVE_TORNADO = False
|
|
tornado = None
|
|
|
|
HAVE_ASYNCIO = True
|
|
try:
|
|
import asyncio
|
|
except ImportError:
|
|
HAVE_ASYNCIO = False
|
|
asyncio = None
|
|
|
|
HAVE_AIOHTTP = True
|
|
try:
|
|
import aiohttp
|
|
except ImportError:
|
|
HAVE_AIOHTTP = False
|
|
aiohttp = None
|
|
|
|
HAVE_PYMONGOCRYPT = True
|
|
try:
|
|
import pymongocrypt # noqa: F401
|
|
except ImportError:
|
|
HAVE_PYMONGOCRYPT = False
|
|
|
|
|
|
# Copied from PyMongo.
|
|
def partition_node(node):
|
|
"""Split a host:port string into (host, int(port)) pair."""
|
|
host = node
|
|
port = 27017
|
|
idx = node.rfind(":")
|
|
if idx != -1:
|
|
host, port = node[:idx], int(node[idx + 1 :])
|
|
if host.startswith("["):
|
|
host = host[1:-1]
|
|
return host, port
|
|
|
|
|
|
def connected(client):
|
|
"""Convenience, wait for a new PyMongo MongoClient to connect."""
|
|
client.admin.command("ping") # Force connection.
|
|
return client
|
|
|
|
|
|
# If these are set to the empty string, substitute None.
|
|
db_user = os.environ.get("DB_USER") or None
|
|
db_password = os.environ.get("DB_PASSWORD") or None
|
|
|
|
CERT_PATH = os.environ.get(
|
|
"CERT_DIR", os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
|
|
)
|
|
CLIENT_PEM = os.path.join(CERT_PATH, "client.pem")
|
|
CA_PEM = os.path.join(CERT_PATH, "ca.pem")
|
|
MONGODB_X509_USERNAME = "CN=client,OU=kerneluser,O=10Gen,L=New York City,ST=New York,C=US"
|
|
|
|
|
|
def is_server_resolvable():
|
|
"""Returns True if 'server' is resolvable."""
|
|
socket_timeout = socket.getdefaulttimeout()
|
|
socket.setdefaulttimeout(1)
|
|
try:
|
|
socket.gethostbyname("server")
|
|
return True
|
|
except socket.error:
|
|
return False
|
|
finally:
|
|
socket.setdefaulttimeout(socket_timeout)
|
|
|
|
|
|
class TestEnvironment(object):
|
|
__test__ = False
|
|
|
|
def __init__(self):
|
|
self.initialized = False
|
|
self.host = None
|
|
self.port = None
|
|
self.mongod_started_with_ssl = False
|
|
self.mongod_validates_client_cert = False
|
|
self.server_is_resolvable = is_server_resolvable()
|
|
self.sync_cx = None
|
|
self.is_standalone = False
|
|
self.is_mongos = False
|
|
self.is_replica_set = False
|
|
self.rs_name = None
|
|
self.w = 1
|
|
self.hosts = None
|
|
self.arbiters = None
|
|
self.primary = None
|
|
self.secondaries = None
|
|
self.v8 = False
|
|
self.auth = False
|
|
self.uri = None
|
|
self.rs_uri = None
|
|
self.version = None
|
|
self.sessions_enabled = False
|
|
self.fake_hostname_uri = None
|
|
self.server_status = None
|
|
|
|
def setup(self):
|
|
assert not self.initialized
|
|
self.setup_sync_cx()
|
|
self.setup_auth_and_uri()
|
|
self.setup_version()
|
|
self.setup_v8()
|
|
self.server_status = self.sync_cx.admin.command("serverStatus")
|
|
self.initialized = True
|
|
|
|
def setup_sync_cx(self):
|
|
"""Get a synchronous PyMongo MongoClient and determine SSL config."""
|
|
host = os.environ.get("DB_IP", "localhost")
|
|
port = int(os.environ.get("DB_PORT", 27017))
|
|
connectTimeoutMS = 100
|
|
serverSelectionTimeoutMS = 100
|
|
socketTimeoutMS = 10000
|
|
try:
|
|
client = connected(
|
|
pymongo.MongoClient(
|
|
host,
|
|
port,
|
|
username=db_user,
|
|
password=db_password,
|
|
directConnection=True,
|
|
connectTimeoutMS=connectTimeoutMS,
|
|
socketTimeoutMS=socketTimeoutMS,
|
|
serverSelectionTimeoutMS=serverSelectionTimeoutMS,
|
|
tlsCAFile=CA_PEM,
|
|
ssl=True,
|
|
)
|
|
)
|
|
|
|
self.mongod_started_with_ssl = True
|
|
except pymongo.errors.ServerSelectionTimeoutError:
|
|
try:
|
|
client = connected(
|
|
pymongo.MongoClient(
|
|
host,
|
|
port,
|
|
username=db_user,
|
|
password=db_password,
|
|
directConnection=True,
|
|
connectTimeoutMS=connectTimeoutMS,
|
|
socketTimeoutMS=socketTimeoutMS,
|
|
serverSelectionTimeoutMS=serverSelectionTimeoutMS,
|
|
tlsCAFile=CA_PEM,
|
|
tlsCertificateKeyFile=CLIENT_PEM,
|
|
)
|
|
)
|
|
|
|
self.mongod_started_with_ssl = True
|
|
self.mongod_validates_client_cert = True
|
|
except pymongo.errors.ServerSelectionTimeoutError:
|
|
client = connected(
|
|
pymongo.MongoClient(
|
|
host,
|
|
port,
|
|
username=db_user,
|
|
password=db_password,
|
|
directConnection=True,
|
|
connectTimeoutMS=connectTimeoutMS,
|
|
socketTimeoutMS=socketTimeoutMS,
|
|
serverSelectionTimeoutMS=serverSelectionTimeoutMS,
|
|
)
|
|
)
|
|
|
|
response = client.admin.command("ismaster")
|
|
self.sessions_enabled = "logicalSessionTimeoutMinutes" in response
|
|
self.is_mongos = response.get("msg") == "isdbgrid"
|
|
if "setName" in response:
|
|
self.is_replica_set = True
|
|
self.rs_name = str(response["setName"])
|
|
self.w = len(response["hosts"])
|
|
self.hosts = set([partition_node(h) for h in response["hosts"]])
|
|
host, port = self.primary = partition_node(response["primary"])
|
|
self.arbiters = set([partition_node(h) for h in response.get("arbiters", [])])
|
|
|
|
self.secondaries = [
|
|
partition_node(m)
|
|
for m in response["hosts"]
|
|
if m != self.primary and m not in self.arbiters
|
|
]
|
|
elif not self.is_mongos:
|
|
self.is_standalone = True
|
|
|
|
# Reconnect to found primary, without short timeouts.
|
|
if self.mongod_started_with_ssl:
|
|
client = connected(
|
|
pymongo.MongoClient(
|
|
host,
|
|
port,
|
|
username=db_user,
|
|
password=db_password,
|
|
directConnection=True,
|
|
tlsCAFile=CA_PEM,
|
|
tlsCertificateKeyFile=CLIENT_PEM,
|
|
)
|
|
)
|
|
else:
|
|
client = connected(
|
|
pymongo.MongoClient(
|
|
host,
|
|
port,
|
|
username=db_user,
|
|
password=db_password,
|
|
directConnection=True,
|
|
ssl=False,
|
|
)
|
|
)
|
|
|
|
self.sync_cx = client
|
|
self.host = host
|
|
self.port = port
|
|
|
|
def setup_auth_and_uri(self):
|
|
"""Set self.auth and self.uri."""
|
|
if db_user or db_password:
|
|
if not (db_user and db_password):
|
|
sys.stderr.write("You must set both DB_USER and DB_PASSWORD, or neither\n")
|
|
sys.exit(1)
|
|
|
|
self.auth = True
|
|
uri_template = "mongodb://%s:%s@%s:%s/admin"
|
|
self.uri = uri_template % (db_user, db_password, self.host, self.port)
|
|
|
|
# If the hostname 'server' is resolvable, this URI lets us use it
|
|
# to test SSL hostname validation with auth.
|
|
self.fake_hostname_uri = uri_template % (db_user, db_password, "server", self.port)
|
|
|
|
else:
|
|
self.uri = "mongodb://%s:%s/admin" % (self.host, self.port)
|
|
|
|
self.fake_hostname_uri = "mongodb://%s:%s/admin" % ("server", self.port)
|
|
|
|
if self.rs_name:
|
|
self.rs_uri = self.uri + "?replicaSet=" + self.rs_name
|
|
|
|
def setup_version(self):
|
|
"""Set self.version to the server's version."""
|
|
self.version = Version.from_client(self.sync_cx)
|
|
|
|
def setup_v8(self):
|
|
"""Determine if server is running SpiderMonkey or V8."""
|
|
if self.sync_cx.server_info().get("javascriptEngine") == "V8":
|
|
self.v8 = True
|
|
|
|
@property
|
|
def storage_engine(self):
|
|
try:
|
|
return self.server_status.get("storageEngine", {}).get("name")
|
|
except AttributeError:
|
|
# Raised if self.server_status is None.
|
|
return None
|
|
|
|
def supports_transactions(self):
|
|
if self.storage_engine == "mmapv1":
|
|
return False
|
|
|
|
if self.version.at_least(4, 1, 8):
|
|
return self.is_mongos or self.is_replica_set
|
|
|
|
if self.version.at_least(4, 0):
|
|
return self.is_replica_set
|
|
|
|
return False
|
|
|
|
def require(self, condition, msg, func=None):
|
|
def make_wrapper(f):
|
|
@wraps(f)
|
|
def wrap(*args, **kwargs):
|
|
assert self.initialized
|
|
if condition():
|
|
return f(*args, **kwargs)
|
|
raise SkipTest(msg)
|
|
|
|
return wrap
|
|
|
|
if func is None:
|
|
|
|
def decorate(f):
|
|
return make_wrapper(f)
|
|
|
|
return decorate
|
|
return make_wrapper(func)
|
|
|
|
def require_auth(self, func):
|
|
"""Run a test only if the server is started with auth."""
|
|
return self.require(lambda: self.auth, "Server must be start with auth", func=func)
|
|
|
|
def require_version_min(self, *ver):
|
|
"""Run a test only if the server version is at least ``version``."""
|
|
other_version = Version(*ver)
|
|
return self.require(
|
|
lambda: self.version >= other_version,
|
|
"Server version must be at least %s" % str(other_version),
|
|
)
|
|
|
|
def require_version_max(self, *ver):
|
|
"""Run a test only if the server version is at most ``version``."""
|
|
other_version = Version(*ver)
|
|
return self.require(
|
|
lambda: self.version <= other_version,
|
|
"Server version must be at most %s" % str(other_version),
|
|
)
|
|
|
|
def require_no_standalone(self, func):
|
|
"""Run a test only if the client is not connected to a standalone."""
|
|
return self.require(lambda: not self.is_standalone, "Connected to a standalone", func=func)
|
|
|
|
def require_replica_set(self, func):
|
|
"""Run a test only if the client is connected to a replica set."""
|
|
return self.require(
|
|
lambda: self.is_replica_set, "Not connected to a replica set", func=func
|
|
)
|
|
|
|
def require_transactions(self, func):
|
|
"""Run a test only if the deployment might support transactions.
|
|
|
|
*Might* because this does not test the FCV.
|
|
"""
|
|
return self.require(self.supports_transactions, "Transactions are not supported", func=func)
|
|
|
|
def require_csfle(self, func):
|
|
"""Run a test only if the deployment supports CSFLE."""
|
|
return self.require(
|
|
lambda: HAVE_PYMONGOCRYPT and self.version >= Version(4, 2),
|
|
"CSFLE requires pymongocrypt and MongoDB >=4.2",
|
|
func=func,
|
|
)
|
|
|
|
def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
|
kwargs["writeConcern"] = {"w": self.w}
|
|
return create_user(self.sync_cx[dbname], user, pwd, roles, **kwargs)
|
|
|
|
def drop_user(self, dbname, user):
|
|
self.sync_cx[dbname].command("dropUser", user, writeConcern={"w": self.w})
|
|
|
|
|
|
env = TestEnvironment()
|
|
if "SKIP_ENV_SETUP" not in os.environ:
|
|
env.setup()
|