diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ee08ed3ef..5a4df213d 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -650,7 +650,10 @@ class MongoClient(common.BaseObject): opts = common._CaseInsensitiveDictionary() fqdn = None for entity in host: - if "://" in entity: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: # Determine connection timeout from kwargs. timeout = keyword_opts.get("connecttimeoutms") if timeout is not None: diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index 42be08a4d..c53ce378e 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -14,6 +14,8 @@ """Support for resolving hosts and options from mongodb+srv:// URIs.""" +import ipaddress + try: from dns import resolver _HAVE_DNSPYTHON = True @@ -40,6 +42,9 @@ def _resolve(*args, **kwargs): # dnspython 1.X return resolver.query(*args, **kwargs) +_INVALID_HOST_MSG = ( + "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " + "Did you mean to use 'mongodb://'?") class _SrvResolver(object): def __init__(self, fqdn, connect_timeout=None): @@ -47,13 +52,19 @@ class _SrvResolver(object): self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT # Validate the fully qualified domain name. + try: + ipaddress.ip_address(fqdn) + raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) + except ValueError: + pass + try: self.__plist = self.__fqdn.split(".")[1:] except Exception: - raise ConfigurationError("Invalid URI host: %s" % (fqdn,)) + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) self.__slen = len(self.__plist) if self.__slen < 2: - raise ConfigurationError("Invalid URI host: %s" % (fqdn,)) + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) def get_options(self): try: diff --git a/test/test_client.py b/test/test_client.py index f97691176..26a0af5ae 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -166,6 +166,20 @@ class ClientUnitTest(unittest.TestCase): with self.assertRaises(ValueError): MongoClient(maxPoolSize=0) + def test_uri_detection(self): + self.assertRaises( + ConfigurationError, + MongoClient, + "/foo") + self.assertRaises( + ConfigurationError, + MongoClient, + "://") + self.assertRaises( + ConfigurationError, + MongoClient, + "foo/") + def test_get_db(self): def make_db(base, name): return base[name] diff --git a/test/test_dns.py b/test/test_dns.py index 16814063c..2cca4d448 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -142,12 +142,20 @@ class TestParsingErrors(unittest.TestCase): def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, - "Invalid URI host: mongodb", + "Invalid URI host: mongodb is not", MongoClient, "mongodb+srv://mongodb") self.assertRaisesRegex( ConfigurationError, - "Invalid URI host: mongodb.com", + "Invalid URI host: mongodb.com is not", MongoClient, "mongodb+srv://mongodb.com") + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: an IP address is not", + MongoClient, "mongodb+srv://127.0.0.1") + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: an IP address is not", + MongoClient, "mongodb+srv://[::1]") if __name__ == '__main__':