diff --git a/pymongo/_azure_helpers.py b/pymongo/_azure_helpers.py index 8a7af0b40..b7a518185 100644 --- a/pymongo/_azure_helpers.py +++ b/pymongo/_azure_helpers.py @@ -17,6 +17,7 @@ from __future__ import annotations import json from typing import Any, Optional +from urllib.parse import quote def _get_azure_response( @@ -29,7 +30,7 @@ def _get_azure_response( url += "?api-version=2018-02-01" url += f"&resource={resource}" if client_id: - url += f"&client_id={client_id}" + url += f"&client_id={quote(client_id)}" headers = {"Metadata": "true", "Accept": "application/json"} request = Request(url, headers=headers) # noqa: S310 try: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 9023112b5..dea1161af 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2711,7 +2711,7 @@ class TestClientPool(AsyncMockClientTest): await async_wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) - # Assert that we create 1 pooled connection. + # Wait for the pooled connection to be registered await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) arbiter = c._topology.get_server_by_address(("c", 3)) diff --git a/test/test_azure_helpers.py b/test/test_azure_helpers.py index 6fe645187..25d1a779c 100644 --- a/test/test_azure_helpers.py +++ b/test/test_azure_helpers.py @@ -150,6 +150,20 @@ class TestGetAzureResponse(unittest.TestCase): _, kwargs = mock_open.call_args self.assertEqual(kwargs["timeout"], 42) + def test_client_id_is_url_encoded(self): + """Ensure special characters in client_id are percent-encoded.""" + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + self._call(client_id="id with spaces&special=chars") + + url = mock_open.call_args[0][0].full_url + # '&' and '=' must be percent-encoded so they don't inject extra query params + self.assertIn("client_id=id%20with%20spaces%26special%3Dchars", url) + # The encoded client_id should not introduce a raw '&' + # Count params: api-version, resource, client_id — exactly 3 + query_string = url.split("?", 1)[1] + self.assertEqual(query_string.count("&"), 2) + if __name__ == "__main__": unittest.main() diff --git a/test/test_client.py b/test/test_client.py index fb400b698..d2d93c6ba 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2666,7 +2666,7 @@ class TestClientPool(MockClientTest): wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(c.address, ("c", 3)) - # Assert that we create 1 pooled connection. + # Wait for the pooled connection to be registered listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) arbiter = c._topology.get_server_by_address(("c", 3))