diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 6d8dfaf89..1ec74aad0 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index d5eec0134..06c0eca2c 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index c9e3cadd6..9b00c13e1 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,13 @@ class AsyncClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any( + cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ + ): + raise TypeError( + f"AsyncMongoClient required but given {type(key_vault_client).__name__}" + ) self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -685,9 +691,9 @@ class AsyncClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, AsyncDatabase): - raise TypeError( - f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given" - ) + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a84fbf2e5..9dba97d12 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,9 @@ class _MongoClientErrorHandler: self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 93e24432e..7a41aef31 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): from pymongo.synchronous.database import Database if not isinstance(database, Database): - raise TypeError(f"Collection requires a Database but {type(database)} given") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 1cd8ee643..c57a59e09 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,9 @@ class Database(common.BaseObject, Generic[_DocumentType]): raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 3849cf3f2..efef6df9e 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,9 @@ class ClientEncryption(Generic[_DocumentType]): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - raise TypeError(f"MongoClient required but given {type(key_vault_client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): + raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -683,9 +685,9 @@ class ClientEncryption(Generic[_DocumentType]): """ if not isinstance(database, Database): - raise TypeError( - f"create_encrypted_collection() requires a Database but {type(database)} given" - ) + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index cec78463b..21fa57b5d 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,9 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/test/unified_format.py b/test/unified_format.py index 63cd23af8..78fc63878 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -580,7 +580,7 @@ class EntityMapUtil: return elif entity_type == "database": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -602,7 +602,7 @@ class EntityMapUtil: return elif entity_type == "session": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -667,7 +667,7 @@ class EntityMapUtil: def get_listener_for_client(self, client_name: str) -> EventListenerUtil: client = self[client_name] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( f"Expected entity {client_name} to be of type MongoClient, got {type(client)}" )