PYTHON-5306 - Fix use of public MongoClient attributes before connection (#2285)
This commit is contained in:
parent
d51c70b401
commit
412d0005b8
@ -10,10 +10,13 @@ Version 4.12.1 is a bug fix release.
|
||||
- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL.
|
||||
- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels.
|
||||
- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``.
|
||||
- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string
|
||||
could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description,
|
||||
nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type
|
||||
errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()".
|
||||
- Removed Eventlet testing against Python versions newer than 3.9 since
|
||||
Eventlet is actively being sunset by its maintainers and has compatibility issues with PyMongo's dnspython dependency.
|
||||
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
|
||||
@ -109,6 +109,7 @@ from pymongo.operations import (
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
@ -779,7 +780,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
self._seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
@ -804,18 +805,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
self._seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
dbase = res["database"] or dbase
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
self._seeds.update(split_hosts(entity, self._port))
|
||||
if not self._seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
for hostname in [node[0] for node in self._seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
@ -838,7 +839,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
opts = self._normalize_and_validate_options(opts, self._seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
@ -857,7 +858,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"seeds": self._seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
@ -873,8 +874,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
@ -975,6 +975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.asynchronous.encryption import _Encrypter
|
||||
@ -1205,6 +1206,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
|
||||
return TopologyDescription(
|
||||
TOPOLOGY_TYPE.Unknown,
|
||||
servers,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self._topology_settings,
|
||||
)
|
||||
return self._topology.description
|
||||
|
||||
@property
|
||||
@ -1218,6 +1229,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
to any servers, or a network partition causes it to lose connection
|
||||
to all servers.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return frozenset()
|
||||
description = self._topology.description
|
||||
return frozenset(s.address for s in description.known_servers)
|
||||
|
||||
@ -1576,6 +1589,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
topology_type = self._topology._description.topology_type
|
||||
if (
|
||||
topology_type == TOPOLOGY_TYPE.Sharded
|
||||
@ -1598,6 +1613,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
AsyncMongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_primary() # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
@ -1611,6 +1628,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
AsyncMongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
@ -1621,6 +1640,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
connected to a replica set, there are no arbiters, or this client was
|
||||
created without the `replicaSet` option.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_arbiters()
|
||||
|
||||
@property
|
||||
|
||||
@ -51,6 +51,7 @@ class TopologySettings:
|
||||
srv_service_name: str = common.SRV_SERVICE_NAME,
|
||||
srv_max_hosts: int = 0,
|
||||
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
@ -78,8 +79,10 @@ class TopologySettings:
|
||||
self._srv_service_name = srv_service_name
|
||||
self._srv_max_hosts = srv_max_hosts or 0
|
||||
self._server_monitoring_mode = server_monitoring_mode
|
||||
|
||||
self._topology_id = ObjectId()
|
||||
if topology_id is not None:
|
||||
self._topology_id = topology_id
|
||||
else:
|
||||
self._topology_id = ObjectId()
|
||||
# Store the allocation traceback to catch unclosed clients in the
|
||||
# test suite.
|
||||
self._stack = "".join(traceback.format_stack()[:-2])
|
||||
|
||||
@ -101,6 +101,7 @@ from pymongo.operations import (
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database, uri_parser
|
||||
@ -777,7 +778,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
self._seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
@ -802,18 +803,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
self._seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
dbase = res["database"] or dbase
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
self._seeds.update(split_hosts(entity, self._port))
|
||||
if not self._seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
for hostname in [node[0] for node in self._seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
@ -836,7 +837,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
opts = self._normalize_and_validate_options(opts, self._seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
@ -855,7 +856,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"seeds": self._seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
@ -871,8 +872,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
@ -973,6 +973,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.synchronous.encryption import _Encrypter
|
||||
@ -1203,6 +1204,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
|
||||
return TopologyDescription(
|
||||
TOPOLOGY_TYPE.Unknown,
|
||||
servers,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self._topology_settings,
|
||||
)
|
||||
return self._topology.description
|
||||
|
||||
@property
|
||||
@ -1216,6 +1227,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
to any servers, or a network partition causes it to lose connection
|
||||
to all servers.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return frozenset()
|
||||
description = self._topology.description
|
||||
return frozenset(s.address for s in description.known_servers)
|
||||
|
||||
@ -1570,6 +1583,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
topology_type = self._topology._description.topology_type
|
||||
if (
|
||||
topology_type == TOPOLOGY_TYPE.Sharded
|
||||
@ -1592,6 +1607,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
MongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_primary() # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
@ -1605,6 +1622,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
MongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
@ -1615,6 +1634,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
connected to a replica set, there are no arbiters, or this client was
|
||||
created without the `replicaSet` option.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_arbiters()
|
||||
|
||||
@property
|
||||
|
||||
@ -51,6 +51,7 @@ class TopologySettings:
|
||||
srv_service_name: str = common.SRV_SERVICE_NAME,
|
||||
srv_max_hosts: int = 0,
|
||||
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
@ -78,8 +79,10 @@ class TopologySettings:
|
||||
self._srv_service_name = srv_service_name
|
||||
self._srv_max_hosts = srv_max_hosts or 0
|
||||
self._server_monitoring_mode = server_monitoring_mode
|
||||
|
||||
self._topology_id = ObjectId()
|
||||
if topology_id is not None:
|
||||
self._topology_id = topology_id
|
||||
else:
|
||||
self._topology_id = ObjectId()
|
||||
# Store the allocation traceback to catch unclosed clients in the
|
||||
# test suite.
|
||||
self._stack = "".join(traceback.format_stack()[:-2])
|
||||
|
||||
@ -850,6 +850,58 @@ class TestClient(AsyncIntegrationTest):
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await c.pymongo_test.test.find_one()
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
@async_client_context.require_no_load_balancer
|
||||
@async_client_context.require_tls
|
||||
async def test_init_disconnected_with_srv(self):
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# nodes returns an empty set if not connected
|
||||
self.assertEqual(c.nodes, frozenset())
|
||||
# topology_description returns the initial seed description if not connected
|
||||
topology_description = c.topology_description
|
||||
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
|
||||
self.assertEqual(
|
||||
{
|
||||
("test1.test.build.10gen.cc", None): ServerDescription(
|
||||
("test1.test.build.10gen.cc", None)
|
||||
)
|
||||
},
|
||||
topology_description.server_descriptions(),
|
||||
)
|
||||
|
||||
# address causes client to block until connected
|
||||
self.assertIsNotNone(await c.address)
|
||||
# Initial seed topology and connected topology have the same ID
|
||||
self.assertEqual(
|
||||
c._topology._topology_id, topology_description._topology_settings._topology_id
|
||||
)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# primary causes client to block until connected
|
||||
await c.primary
|
||||
self.assertIsNotNone(c._topology)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# secondaries causes client to block until connected
|
||||
await c.secondaries
|
||||
self.assertIsNotNone(c._topology)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# arbiters causes client to block until connected
|
||||
await c.arbiters
|
||||
self.assertIsNotNone(c._topology)
|
||||
|
||||
async def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = await self.async_rs_or_single_client(seed, connect=False)
|
||||
|
||||
@ -825,6 +825,58 @@ class TestClient(IntegrationTest):
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
c.pymongo_test.test.find_one()
|
||||
|
||||
@client_context.require_no_standalone
|
||||
@client_context.require_no_load_balancer
|
||||
@client_context.require_tls
|
||||
def test_init_disconnected_with_srv(self):
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# nodes returns an empty set if not connected
|
||||
self.assertEqual(c.nodes, frozenset())
|
||||
# topology_description returns the initial seed description if not connected
|
||||
topology_description = c.topology_description
|
||||
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
|
||||
self.assertEqual(
|
||||
{
|
||||
("test1.test.build.10gen.cc", None): ServerDescription(
|
||||
("test1.test.build.10gen.cc", None)
|
||||
)
|
||||
},
|
||||
topology_description.server_descriptions(),
|
||||
)
|
||||
|
||||
# address causes client to block until connected
|
||||
self.assertIsNotNone(c.address)
|
||||
# Initial seed topology and connected topology have the same ID
|
||||
self.assertEqual(
|
||||
c._topology._topology_id, topology_description._topology_settings._topology_id
|
||||
)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# primary causes client to block until connected
|
||||
c.primary
|
||||
self.assertIsNotNone(c._topology)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# secondaries causes client to block until connected
|
||||
c.secondaries
|
||||
self.assertIsNotNone(c._topology)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# arbiters causes client to block until connected
|
||||
c.arbiters
|
||||
self.assertIsNotNone(c._topology)
|
||||
|
||||
def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = self.rs_or_single_client(seed, connect=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user