From 039db2f20a8f8133fce65c75b819aa03c6854dd8 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 08:46:44 -0400 Subject: [PATCH 1/2] PYTHON-4590 - Make type guards more compatible (#1850) --- pymongo/asynchronous/collection.py | 3 ++- pymongo/asynchronous/database.py | 3 ++- pymongo/asynchronous/encryption.py | 10 ++++++---- pymongo/asynchronous/mongo_client.py | 3 ++- pymongo/synchronous/collection.py | 3 ++- pymongo/synchronous/database.py | 3 ++- pymongo/synchronous/encryption.py | 8 ++++---- pymongo/synchronous/mongo_client.py | 3 ++- test/unified_format.py | 6 +++--- 9 files changed, 25 insertions(+), 17 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 6d8dfaf89..a0b727dc7 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given") + if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index d5eec0134..fb042972b 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index c9e3cadd6..b03af1b8a 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,10 @@ class AsyncClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}") + if not any(cls.__name__ == "AsyncMongoClient" for cls in key_vault_client.__mro__): + raise TypeError( + f"AsyncMongoClient required but given {type(key_vault_client).__name__}" + ) self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -685,9 +688,8 @@ class AsyncClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, AsyncDatabase): - raise TypeError( - f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given" - ) + if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a84fbf2e5..20cc65d9d 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,8 @@ class _MongoClientErrorHandler: self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 93e24432e..ff02c65af 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]): from pymongo.synchronous.database import Database if not isinstance(database, Database): - raise TypeError(f"Collection requires a Database but {type(database)} given") + if not any(cls.__name__ == "Database" for cls in database.__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 1cd8ee643..5f499fff6 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,8 @@ class Database(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 3849cf3f2..8c6411feb 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,8 @@ class ClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - raise TypeError(f"MongoClient required but given {type(key_vault_client)}") + if not any(cls.__name__ == "MongoClient" for cls in key_vault_client.__mro__): + raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -683,9 +684,8 @@ class ClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, Database): - raise TypeError( - f"create_encrypted_collection() requires a Database but {type(database)} given" - ) + if not any(cls.__name__ == "Database" for cls in database.__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index cec78463b..ac697405d 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,8 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/test/unified_format.py b/test/unified_format.py index 63cd23af8..78fc63878 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -580,7 +580,7 @@ class EntityMapUtil: return elif entity_type == "database": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -602,7 +602,7 @@ class EntityMapUtil: return elif entity_type == "session": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -667,7 +667,7 @@ class EntityMapUtil: def get_listener_for_client(self, client_name: str) -> EventListenerUtil: client = self[client_name] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( f"Expected entity {client_name} to be of type MongoClient, got {type(client)}" ) From 63d957c2137cec66821d4d1669ff7a24f4c4f6f0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 11:22:22 -0400 Subject: [PATCH 2/2] PYTHON-4590 - Fix MRO type guards (#1852) --- pymongo/asynchronous/collection.py | 3 ++- pymongo/asynchronous/database.py | 3 ++- pymongo/asynchronous/encryption.py | 8 ++++++-- pymongo/asynchronous/mongo_client.py | 3 ++- pymongo/synchronous/collection.py | 3 ++- pymongo/synchronous/database.py | 3 ++- pymongo/synchronous/encryption.py | 6 ++++-- pymongo/synchronous/mongo_client.py | 3 ++- 8 files changed, 22 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index a0b727dc7..1ec74aad0 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index fb042972b..06c0eca2c 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index b03af1b8a..9b00c13e1 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,10 @@ class AsyncClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in key_vault_client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any( + cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ + ): raise TypeError( f"AsyncMongoClient required but given {type(key_vault_client).__name__}" ) @@ -688,7 +691,8 @@ class AsyncClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 20cc65d9d..9dba97d12 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,8 @@ class _MongoClientErrorHandler: self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index ff02c65af..7a41aef31 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]): from pymongo.synchronous.database import Database if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 5f499fff6..c57a59e09 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,8 @@ class Database(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 8c6411feb..efef6df9e 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,8 @@ class ClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in key_vault_client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers @@ -684,7 +685,8 @@ class ClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ac697405d..21fa57b5d 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,8 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client