diff --git a/doc/changelog.rst b/doc/changelog.rst index c4c4f35fe..8dfe4f750 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -134,6 +134,9 @@ Breaking Changes in 4.0 - The ``hint`` option is now required when using ``min`` or ``max`` queries with :meth:`~pymongo.collection.Collection.find`. - ``name`` is now a required argument for the :class:`pymongo.driver_info.DriverInfo` class. +- When providing a "mongodb+srv://" URI to + :class:`~pymongo.mongo_client.MongoClient` constructor you can now use the + ``srvServiceName`` URI option to specify your own SRV service name. - :meth:`~bson.son.SON.items` now returns a ``dict_items`` object rather than a list. - Removed :meth:`bson.son.SON.iteritems`. @@ -160,7 +163,6 @@ Breaking Changes in 4.0 - ``MongoClient()`` now raises a :exc:`~pymongo.errors.ConfigurationError` when more than one URI is passed into the ``hosts`` argument. - Notable improvements .................... diff --git a/pymongo/common.py b/pymongo/common.py index f1a5389b3..8ccdc21b6 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -113,6 +113,9 @@ UNAUTHORIZED_CODES = (13, 16547, 16548) # From the driver sessions spec. _MAX_END_SESSIONS = 10000 +# Default value for srvServiceName +SRV_SERVICE_NAME = "mongodb" + def partition_node(node): """Split a host:port string into (host, int(port)) pair.""" @@ -626,6 +629,7 @@ URI_OPTIONS_VALIDATOR_MAP = { 'w': validate_non_negative_int_or_basestring, 'wtimeoutms': validate_non_negative_integer, 'zlibcompressionlevel': validate_zlib_compression_level, + 'srvservicename': validate_string } # Dictionary where keys are the names of URI options specific to pymongo, diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index c0bcb9575..d1794938f 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -329,6 +329,11 @@ class MongoClient(common.BaseObject): a Unicode-related error occurs during BSON decoding that would otherwise raise :exc:`UnicodeDecodeError`. Valid options include 'strict', 'replace', and 'ignore'. Defaults to 'strict'. + - ``srvServiceName`: (string) The SRV service name to use for + "mongodb+srv://" URIs. Defaults to "mongodb". Use it like so:: + + MongoClient("mongodb+srv://example.com/?srvServiceName=customname") + | **Write Concern options:** | (Only set if passed. No default values.) @@ -499,6 +504,7 @@ class MongoClient(common.BaseObject): arguments. The default for `uuidRepresentation` was changed from ``pythonLegacy`` to ``unspecified``. + Added the ``srvServiceName`` URI and keyword argument. .. versionchanged:: 3.12 Added the ``server_api`` keyword argument. @@ -644,6 +650,8 @@ class MongoClient(common.BaseObject): dbase = None opts = common._CaseInsensitiveDictionary() fqdn = None + srv_service_name = keyword_opts.get("srvservicename", None) + if len([h for h in host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB " "URIs") @@ -659,7 +667,7 @@ class MongoClient(common.BaseObject): keyword_opts.cased_key("connecttimeoutms"), timeout) res = uri_parser.parse_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout) + connect_timeout=timeout, srv_service_name=srv_service_name) seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password @@ -689,6 +697,10 @@ class MongoClient(common.BaseObject): # Override connection string options with kwarg options. opts.update(keyword_opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + # Handle security-option conflicts in combined options. opts = _handle_security_options(opts) # Normalize combined options. @@ -728,6 +740,7 @@ class MongoClient(common.BaseObject): server_selector=options.server_selector, heartbeat_frequency=options.heartbeat_frequency, fqdn=fqdn, + srv_service_name=srv_service_name, direct_connection=options.direct_connection, load_balanced=options.load_balanced, ) diff --git a/pymongo/monitor.py b/pymongo/monitor.py index d13d337f8..5359ee054 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -299,6 +299,7 @@ class SrvMonitor(MonitorBase): self._settings = topology_settings self._seedlist = self._settings._seeds self._fqdn = self._settings.fqdn + self._srv_service_name = self._settings._srv_service_name def _run(self): seedlist = self._get_seedlist() @@ -316,7 +317,10 @@ class SrvMonitor(MonitorBase): Returns a list of ServerDescriptions. """ try: - seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl() + resolver = _SrvResolver(self._fqdn, + self._settings.pool_options.connect_timeout, + self._srv_service_name) + seedlist, ttl = resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception diff --git a/pymongo/settings.py b/pymongo/settings.py index ff9e84ee0..7e9d393ac 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -39,6 +39,7 @@ class TopologySettings(object): heartbeat_frequency=common.HEARTBEAT_FREQUENCY, server_selector=None, fqdn=None, + srv_service_name=common.SRV_SERVICE_NAME, direct_connection=False, load_balanced=None): """Represent MongoClient's configuration. @@ -60,6 +61,7 @@ class TopologySettings(object): self._server_selection_timeout = server_selection_timeout self._server_selector = server_selector self._fqdn = fqdn + self._srv_service_name = srv_service_name self._heartbeat_frequency = heartbeat_frequency self._direct = direct_connection diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index c53ce378e..e3f8f330f 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -47,8 +47,10 @@ _INVALID_HOST_MSG = ( "Did you mean to use 'mongodb://'?") class _SrvResolver(object): - def __init__(self, fqdn, connect_timeout=None): + def __init__(self, fqdn, + connect_timeout, srv_service_name): self.__fqdn = fqdn + self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT # Validate the fully qualified domain name. @@ -83,8 +85,8 @@ class _SrvResolver(object): def _resolve_uri(self, encapsulate_errors): try: - results = _resolve('_mongodb._tcp.' + self.__fqdn, 'SRV', - lifetime=self.__connect_timeout) + results = _resolve('_' + self.__srv + '._tcp.' + self.__fqdn, + 'SRV', lifetime=self.__connect_timeout) except Exception as exc: if not encapsulate_errors: # Raise the original error. diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index cf5cf64f5..79642c711 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -21,6 +21,7 @@ import sys from urllib.parse import unquote_plus from pymongo.common import ( + SRV_SERVICE_NAME, get_validated_options, INTERNAL_URI_OPTION_NAME_MAP, URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary) from pymongo.errors import ConfigurationError, InvalidURI @@ -373,7 +374,7 @@ def _check_options(nodes, options): def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, - normalize=True, connect_timeout=None): + normalize=True, connect_timeout=None, srv_service_name=None): """Parse and validate a MongoDB URI. Returns a dict of the form:: @@ -405,6 +406,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, to their internally-used names. Default: ``True``. - `connect_timeout` (optional): The maximum time in milliseconds to wait for a response from the DNS server. + - 'srv_service_name` (optional): A custom SRV service name .. versionchanged:: 3.9 Added the ``normalize`` parameter. @@ -468,6 +470,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, if opts: options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if '@' in host_part: userinfo, _, hosts = host_part.rpartition('@') user, passwd = parse_userinfo(userinfo) @@ -499,7 +504,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout=connect_timeout) + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name) nodes = dns_resolver.get_hosts() dns_options = dns_resolver.get_options() if dns_options: @@ -514,6 +519,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, options[opt] = val if "tls" not in options and "ssl" not in options: options["tls"] = True if validate else 'true' + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError("The srvServiceName option is only allowed " + "with 'mongodb+srv://' URIs") else: nodes = split_hosts(hosts, default_port=default_port) diff --git a/test/srv_seedlist/load-balanced/loadBalanced-no-results.json b/test/srv_seedlist/load-balanced/loadBalanced-no-results.json new file mode 100644 index 000000000..7f49416aa --- /dev/null +++ b/test/srv_seedlist/load-balanced/loadBalanced-no-results.json @@ -0,0 +1,7 @@ +{ + "uri": "mongodb+srv://test4.test.build.10gen.cc/?loadBalanced=true", + "seeds": [], + "hosts": [], + "error": true, + "comment": "Should fail because no SRV records are present for this URI." +} diff --git a/test/srv_seedlist/replica-set/srv-service-name.json b/test/srv_seedlist/replica-set/srv-service-name.json new file mode 100644 index 000000000..ec36cdbb0 --- /dev/null +++ b/test/srv_seedlist/replica-set/srv-service-name.json @@ -0,0 +1,16 @@ +{ + "uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname", + "seeds": [ + "localhost.test.build.10gen.cc:27017", + "localhost.test.build.10gen.cc:27018" + ], + "hosts": [ + "localhost:27017", + "localhost:27018", + "localhost:27019" + ], + "options": { + "ssl": true, + "srvServiceName": "customname" + } +} diff --git a/test/test_client.py b/test/test_client.py index 4398fa08b..4d8b26400 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1591,6 +1591,27 @@ class TestClient(IntegrationTest): with self.assertRaisesRegex(AutoReconnect, expected): client.pymongo_test.test.find_one({}) + @unittest.skipUnless( + _HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") + def test_service_name_from_kwargs(self): + client = MongoClient( + 'mongodb+srv://user:password@test22.test.build.10gen.cc', + srvServiceName='customname', connect=False) + self.assertEqual(client._topology_settings._srv_service_name, + 'customname') + client = MongoClient( + 'mongodb+srv://user:password@test22.test.build.10gen.cc' + '/?srvServiceName=shouldbeoverriden', + srvServiceName='customname', connect=False) + self.assertEqual(client._topology_settings._srv_service_name, + 'customname') + client = MongoClient( + 'mongodb+srv://user:password@test22.test.build.10gen.cc' + '/?srvServiceName=customname', + connect=False) + self.assertEqual(client._topology_settings._srv_service_name, + 'customname') + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" diff --git a/test/uri_options/srv-service-name-option.json b/test/uri_options/srv-service-name-option.json new file mode 100644 index 000000000..049a35bc7 --- /dev/null +++ b/test/uri_options/srv-service-name-option.json @@ -0,0 +1,24 @@ +{ + "tests": [ + { + "description": "SRV URI with custom srvServiceName", + "uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "srvServiceName": "customname" + } + }, + { + "description": "Non-SRV URI with custom srvServiceName", + "uri": "mongodb://example.com/?srvServiceName=customname", + "valid": false, + "warning": true, + "hosts": null, + "auth": null, + "options": {} + } + ] +}