PYTHON-5306 - Fix use of public MongoClient attributes before connection (#2285)

This commit is contained in:
Noah Stapp 2025-04-21 15:53:21 -04:00 committed by GitHub
parent d51c70b401
commit 412d0005b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 178 additions and 23 deletions

View File

@ -10,10 +10,13 @@ Version 4.12.1 is a bug fix release.
- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL.
- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels.
- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``.
- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string
could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description,
nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type
errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()".
- Removed Eventlet testing against Python versions newer than 3.9 since
Eventlet is actively being sunset by its maintainers and has compatibility issues with PyMongo's dnspython dependency.
Issues Resolved
...............

View File

@ -109,6 +109,7 @@ from pymongo.operations import (
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@ -779,7 +780,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
@ -804,18 +805,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break
@ -838,7 +839,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
@ -857,7 +858,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
@ -873,8 +874,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
@ -975,6 +975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.asynchronous.encryption import _Encrypter
@ -1205,6 +1206,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description
@property
@ -1218,6 +1229,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)
@ -1576,6 +1589,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
"""
if self._topology is None:
await self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
@ -1598,6 +1613,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_primary() # type: ignore[return-value]
@property
@ -1611,6 +1628,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_secondaries()
@property
@ -1621,6 +1640,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_arbiters()
@property

View File

@ -51,6 +51,7 @@ class TopologySettings:
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.
@ -78,8 +79,10 @@ class TopologySettings:
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])

View File

@ -101,6 +101,7 @@ from pymongo.operations import (
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, uri_parser
@ -777,7 +778,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
@ -802,18 +803,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break
@ -836,7 +837,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
@ -855,7 +856,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
@ -871,8 +872,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
@ -973,6 +973,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.synchronous.encryption import _Encrypter
@ -1203,6 +1204,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description
@property
@ -1216,6 +1227,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)
@ -1570,6 +1583,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
"""
if self._topology is None:
self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
@ -1592,6 +1607,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_primary() # type: ignore[return-value]
@property
@ -1605,6 +1622,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_secondaries()
@property
@ -1615,6 +1634,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_arbiters()
@property

View File

@ -51,6 +51,7 @@ class TopologySettings:
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.
@ -78,8 +79,10 @@ class TopologySettings:
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])

View File

@ -850,6 +850,58 @@ class TestClient(AsyncIntegrationTest):
with self.assertRaises(ConnectionFailure):
await c.pymongo_test.test.find_one()
@async_client_context.require_no_standalone
@async_client_context.require_no_load_balancer
@async_client_context.require_tls
async def test_init_disconnected_with_srv(self):
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# nodes returns an empty set if not connected
self.assertEqual(c.nodes, frozenset())
# topology_description returns the initial seed description if not connected
topology_description = c.topology_description
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
self.assertEqual(
{
("test1.test.build.10gen.cc", None): ServerDescription(
("test1.test.build.10gen.cc", None)
)
},
topology_description.server_descriptions(),
)
# address causes client to block until connected
self.assertIsNotNone(await c.address)
# Initial seed topology and connected topology have the same ID
self.assertEqual(
c._topology._topology_id, topology_description._topology_settings._topology_id
)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# primary causes client to block until connected
await c.primary
self.assertIsNotNone(c._topology)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# secondaries causes client to block until connected
await c.secondaries
self.assertIsNotNone(c._topology)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# arbiters causes client to block until connected
await c.arbiters
self.assertIsNotNone(c._topology)
async def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = await self.async_rs_or_single_client(seed, connect=False)

View File

@ -825,6 +825,58 @@ class TestClient(IntegrationTest):
with self.assertRaises(ConnectionFailure):
c.pymongo_test.test.find_one()
@client_context.require_no_standalone
@client_context.require_no_load_balancer
@client_context.require_tls
def test_init_disconnected_with_srv(self):
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# nodes returns an empty set if not connected
self.assertEqual(c.nodes, frozenset())
# topology_description returns the initial seed description if not connected
topology_description = c.topology_description
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
self.assertEqual(
{
("test1.test.build.10gen.cc", None): ServerDescription(
("test1.test.build.10gen.cc", None)
)
},
topology_description.server_descriptions(),
)
# address causes client to block until connected
self.assertIsNotNone(c.address)
# Initial seed topology and connected topology have the same ID
self.assertEqual(
c._topology._topology_id, topology_description._topology_settings._topology_id
)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# primary causes client to block until connected
c.primary
self.assertIsNotNone(c._topology)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# secondaries causes client to block until connected
c.secondaries
self.assertIsNotNone(c._topology)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# arbiters causes client to block until connected
c.arbiters
self.assertIsNotNone(c._topology)
def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = self.rs_or_single_client(seed, connect=False)