Merge branch 'master' into PYTHON-3923

This commit is contained in:
Noah Stapp 2026-05-18 10:21:12 -04:00
commit 2dcf4aeacd
4 changed files with 18 additions and 3 deletions

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import json
from typing import Any, Optional
from urllib.parse import quote
def _get_azure_response(
@ -29,7 +30,7 @@ def _get_azure_response(
url += "?api-version=2018-02-01"
url += f"&resource={resource}"
if client_id:
url += f"&client_id={client_id}"
url += f"&client_id={quote(client_id)}"
headers = {"Metadata": "true", "Accept": "application/json"}
request = Request(url, headers=headers) # noqa: S310
try:

View File

@ -2711,7 +2711,7 @@ class TestClientPool(AsyncMockClientTest):
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection.
# Wait for the pooled connection to be registered
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))

View File

@ -150,6 +150,20 @@ class TestGetAzureResponse(unittest.TestCase):
_, kwargs = mock_open.call_args
self.assertEqual(kwargs["timeout"], 42)
def test_client_id_is_url_encoded(self):
"""Ensure special characters in client_id are percent-encoded."""
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(client_id="id with spaces&special=chars")
url = mock_open.call_args[0][0].full_url
# '&' and '=' must be percent-encoded so they don't inject extra query params
self.assertIn("client_id=id%20with%20spaces%26special%3Dchars", url)
# The encoded client_id should not introduce a raw '&'
# Count params: api-version, resource, client_id — exactly 3
query_string = url.split("?", 1)[1]
self.assertEqual(query_string.count("&"), 2)
if __name__ == "__main__":
unittest.main()

View File

@ -2666,7 +2666,7 @@ class TestClientPool(MockClientTest):
wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(c.address, ("c", 3))
# Assert that we create 1 pooled connection.
# Wait for the pooled connection to be registered
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))