PYTHON-2823 Allow custom service names with srvServiceName URI option (#749)

This commit is contained in:
Julius Park 2021-10-08 11:23:21 -07:00 committed by GitHub
parent 049daf9cf6
commit 6bb8a1f411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 111 additions and 8 deletions

View File

@ -134,6 +134,9 @@ Breaking Changes in 4.0
- The ``hint`` option is now required when using ``min`` or ``max`` queries
with :meth:`~pymongo.collection.Collection.find`.
- ``name`` is now a required argument for the :class:`pymongo.driver_info.DriverInfo` class.
- When providing a "mongodb+srv://" URI to
:class:`~pymongo.mongo_client.MongoClient` constructor you can now use the
``srvServiceName`` URI option to specify your own SRV service name.
- :meth:`~bson.son.SON.items` now returns a ``dict_items`` object rather
than a list.
- Removed :meth:`bson.son.SON.iteritems`.
@ -160,7 +163,6 @@ Breaking Changes in 4.0
- ``MongoClient()`` now raises a :exc:`~pymongo.errors.ConfigurationError`
when more than one URI is passed into the ``hosts`` argument.
Notable improvements
....................

View File

@ -113,6 +113,9 @@ UNAUTHORIZED_CODES = (13, 16547, 16548)
# From the driver sessions spec.
_MAX_END_SESSIONS = 10000
# Default value for srvServiceName
SRV_SERVICE_NAME = "mongodb"
def partition_node(node):
"""Split a host:port string into (host, int(port)) pair."""
@ -626,6 +629,7 @@ URI_OPTIONS_VALIDATOR_MAP = {
'w': validate_non_negative_int_or_basestring,
'wtimeoutms': validate_non_negative_integer,
'zlibcompressionlevel': validate_zlib_compression_level,
'srvservicename': validate_string
}
# Dictionary where keys are the names of URI options specific to pymongo,

View File

@ -329,6 +329,11 @@ class MongoClient(common.BaseObject):
a Unicode-related error occurs during BSON decoding that would
otherwise raise :exc:`UnicodeDecodeError`. Valid options include
'strict', 'replace', and 'ignore'. Defaults to 'strict'.
- ``srvServiceName`: (string) The SRV service name to use for
"mongodb+srv://" URIs. Defaults to "mongodb". Use it like so::
MongoClient("mongodb+srv://example.com/?srvServiceName=customname")
| **Write Concern options:**
| (Only set if passed. No default values.)
@ -499,6 +504,7 @@ class MongoClient(common.BaseObject):
arguments.
The default for `uuidRepresentation` was changed from
``pythonLegacy`` to ``unspecified``.
Added the ``srvServiceName`` URI and keyword argument.
.. versionchanged:: 3.12
Added the ``server_api`` keyword argument.
@ -644,6 +650,8 @@ class MongoClient(common.BaseObject):
dbase = None
opts = common._CaseInsensitiveDictionary()
fqdn = None
srv_service_name = keyword_opts.get("srvservicename", None)
if len([h for h in host if "/" in h]) > 1:
raise ConfigurationError("host must not contain multiple MongoDB "
"URIs")
@ -659,7 +667,7 @@ class MongoClient(common.BaseObject):
keyword_opts.cased_key("connecttimeoutms"), timeout)
res = uri_parser.parse_uri(
entity, port, validate=True, warn=True, normalize=False,
connect_timeout=timeout)
connect_timeout=timeout, srv_service_name=srv_service_name)
seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
@ -689,6 +697,10 @@ class MongoClient(common.BaseObject):
# Override connection string options with kwarg options.
opts.update(keyword_opts)
if srv_service_name is None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
# Handle security-option conflicts in combined options.
opts = _handle_security_options(opts)
# Normalize combined options.
@ -728,6 +740,7 @@ class MongoClient(common.BaseObject):
server_selector=options.server_selector,
heartbeat_frequency=options.heartbeat_frequency,
fqdn=fqdn,
srv_service_name=srv_service_name,
direct_connection=options.direct_connection,
load_balanced=options.load_balanced,
)

View File

@ -299,6 +299,7 @@ class SrvMonitor(MonitorBase):
self._settings = topology_settings
self._seedlist = self._settings._seeds
self._fqdn = self._settings.fqdn
self._srv_service_name = self._settings._srv_service_name
def _run(self):
seedlist = self._get_seedlist()
@ -316,7 +317,10 @@ class SrvMonitor(MonitorBase):
Returns a list of ServerDescriptions.
"""
try:
seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl()
resolver = _SrvResolver(self._fqdn,
self._settings.pool_options.connect_timeout,
self._srv_service_name)
seedlist, ttl = resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception

View File

@ -39,6 +39,7 @@ class TopologySettings(object):
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
server_selector=None,
fqdn=None,
srv_service_name=common.SRV_SERVICE_NAME,
direct_connection=False,
load_balanced=None):
"""Represent MongoClient's configuration.
@ -60,6 +61,7 @@ class TopologySettings(object):
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._fqdn = fqdn
self._srv_service_name = srv_service_name
self._heartbeat_frequency = heartbeat_frequency
self._direct = direct_connection

View File

@ -47,8 +47,10 @@ _INVALID_HOST_MSG = (
"Did you mean to use 'mongodb://'?")
class _SrvResolver(object):
def __init__(self, fqdn, connect_timeout=None):
def __init__(self, fqdn,
connect_timeout, srv_service_name):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
# Validate the fully qualified domain name.
@ -83,8 +85,8 @@ class _SrvResolver(object):
def _resolve_uri(self, encapsulate_errors):
try:
results = _resolve('_mongodb._tcp.' + self.__fqdn, 'SRV',
lifetime=self.__connect_timeout)
results = _resolve('_' + self.__srv + '._tcp.' + self.__fqdn,
'SRV', lifetime=self.__connect_timeout)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.

View File

@ -21,6 +21,7 @@ import sys
from urllib.parse import unquote_plus
from pymongo.common import (
SRV_SERVICE_NAME,
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
from pymongo.errors import ConfigurationError, InvalidURI
@ -373,7 +374,7 @@ def _check_options(nodes, options):
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
normalize=True, connect_timeout=None):
normalize=True, connect_timeout=None, srv_service_name=None):
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
@ -405,6 +406,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
to their internally-used names. Default: ``True``.
- `connect_timeout` (optional): The maximum time in milliseconds to
wait for a response from the DNS server.
- 'srv_service_name` (optional): A custom SRV service name
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
@ -468,6 +470,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if '@' in host_part:
userinfo, _, hosts = host_part.rpartition('@')
user, passwd = parse_userinfo(userinfo)
@ -499,7 +504,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout=connect_timeout)
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
@ -514,6 +519,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
options[opt] = val
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else 'true'
elif not is_srv and options.get("srvServiceName") is not None:
raise ConfigurationError("The srvServiceName option is only allowed "
"with 'mongodb+srv://' URIs")
else:
nodes = split_hosts(hosts, default_port=default_port)

View File

@ -0,0 +1,7 @@
{
"uri": "mongodb+srv://test4.test.build.10gen.cc/?loadBalanced=true",
"seeds": [],
"hosts": [],
"error": true,
"comment": "Should fail because no SRV records are present for this URI."
}

View File

@ -0,0 +1,16 @@
{
"uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
"seeds": [
"localhost.test.build.10gen.cc:27017",
"localhost.test.build.10gen.cc:27018"
],
"hosts": [
"localhost:27017",
"localhost:27018",
"localhost:27019"
],
"options": {
"ssl": true,
"srvServiceName": "customname"
}
}

View File

@ -1591,6 +1591,27 @@ class TestClient(IntegrationTest):
with self.assertRaisesRegex(AutoReconnect, expected):
client.pymongo_test.test.find_one({})
@unittest.skipUnless(
_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed")
def test_service_name_from_kwargs(self):
client = MongoClient(
'mongodb+srv://user:password@test22.test.build.10gen.cc',
srvServiceName='customname', connect=False)
self.assertEqual(client._topology_settings._srv_service_name,
'customname')
client = MongoClient(
'mongodb+srv://user:password@test22.test.build.10gen.cc'
'/?srvServiceName=shouldbeoverriden',
srvServiceName='customname', connect=False)
self.assertEqual(client._topology_settings._srv_service_name,
'customname')
client = MongoClient(
'mongodb+srv://user:password@test22.test.build.10gen.cc'
'/?srvServiceName=customname',
connect=False)
self.assertEqual(client._topology_settings._srv_service_name,
'customname')
class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""

View File

@ -0,0 +1,24 @@
{
"tests": [
{
"description": "SRV URI with custom srvServiceName",
"uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"srvServiceName": "customname"
}
},
{
"description": "Non-SRV URI with custom srvServiceName",
"uri": "mongodb://example.com/?srvServiceName=customname",
"valid": false,
"warning": true,
"hosts": null,
"auth": null,
"options": {}
}
]
}