diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 156da7b77..3cd7485a9 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -394,6 +394,29 @@ def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: metadata["platform"] = plat else: metadata.pop("platform", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 5. Truncate driver info. + overflow = encoded_size - _MAX_METADATA_SIZE + driver = metadata.get("driver", {}) + if driver: + # Truncate driver version. + driver_version = driver.get("version")[:-overflow] + if len(driver_version) >= len(_METADATA["driver"]["version"]): + metadata["driver"]["version"] = driver_version + else: + metadata["driver"]["version"] = _METADATA["driver"]["version"] + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # Truncate driver name. + overflow = encoded_size - _MAX_METADATA_SIZE + driver_name = driver.get("name")[:-overflow] + if len(driver_name) >= len(_METADATA["driver"]["name"]): + metadata["driver"]["name"] = driver_name + else: + metadata["driver"]["name"] = _METADATA["driver"]["name"] # If the first getaddrinfo call of this interpreter's life is on a thread, diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 3918116ba..b71cec220 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -394,6 +394,29 @@ def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: metadata["platform"] = plat else: metadata.pop("platform", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 5. Truncate driver info. + overflow = encoded_size - _MAX_METADATA_SIZE + driver = metadata.get("driver", {}) + if driver: + # Truncate driver version. + driver_version = driver.get("version")[:-overflow] + if len(driver_version) >= len(_METADATA["driver"]["version"]): + metadata["driver"]["version"] = driver_version + else: + metadata["driver"]["version"] = _METADATA["driver"]["version"] + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # Truncate driver name. + overflow = encoded_size - _MAX_METADATA_SIZE + driver_name = driver.get("name")[:-overflow] + if len(driver_name) >= len(_METADATA["driver"]["name"]): + metadata["driver"]["name"] = driver_name + else: + metadata["driver"]["name"] = _METADATA["driver"]["name"] # If the first getaddrinfo call of this interpreter's life is on a thread, diff --git a/test/test_client.py b/test/test_client.py index 64b1addbb..503c2e6e3 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -72,6 +72,7 @@ from test.utils import ( wait_until, ) +import bson import pymongo from bson import encode from bson.codec_options import ( @@ -106,6 +107,7 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent from pymongo.synchronous.pool import ( + _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, Connection, @@ -361,6 +363,25 @@ class ClientUnitTest(unittest.TestCase): ) options = client.options self.assertEqual(options.pool_options.metadata, metadata) + # Test truncating driver info metadata. + client = MongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + client = MongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) def test_container_metadata(self):