From 992d1507e7a49292b555cc0201d95d49234c6228 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 19 Oct 2023 11:56:22 -0500 Subject: [PATCH] PYTHON-4005 Replace flake8 and isort with ruff (#1399) --- .flake8 | 30 -------- .pre-commit-config.yaml | 22 ++---- bson/__init__.py | 58 ++++++++-------- bson/binary.py | 2 +- bson/codec_options.py | 2 +- bson/datetime_ms.py | 8 +-- bson/dbref.py | 2 +- bson/decimal128.py | 2 +- bson/errors.py | 1 + bson/json_util.py | 7 +- bson/objectid.py | 8 +-- bson/raw_bson.py | 11 +-- bson/son.py | 6 +- doc/conf.py | 14 ++-- green_framework_test.py | 8 +-- gridfs/__init__.py | 4 +- gridfs/errors.py | 1 + gridfs/grid_file.py | 20 +++--- pymongo/_version.py | 2 + pymongo/auth.py | 6 +- pymongo/auth_aws.py | 6 +- pymongo/auth_oidc.py | 5 +- pymongo/bulk.py | 2 +- pymongo/change_stream.py | 6 +- pymongo/client_session.py | 25 ++++--- pymongo/collection.py | 16 ++--- pymongo/command_cursor.py | 6 +- pymongo/common.py | 64 ++++++++--------- pymongo/compression_support.py | 13 ++-- pymongo/cursor.py | 50 +++++++------- pymongo/daemon.py | 10 ++- pymongo/database.py | 18 ++--- pymongo/driver_info.py | 5 +- pymongo/encryption.py | 9 ++- pymongo/helpers.py | 4 +- pymongo/max_staleness_selectors.py | 4 +- pymongo/message.py | 24 +++---- pymongo/mongo_client.py | 13 ++-- pymongo/monitoring.py | 62 ++++++++--------- pymongo/network.py | 12 ++-- pymongo/ocsp_support.py | 4 +- pymongo/operations.py | 2 +- pymongo/pool.py | 10 +-- pymongo/pyopenssl_context.py | 22 +++--- pymongo/read_preferences.py | 4 +- pymongo/results.py | 4 +- pymongo/saslprep.py | 4 +- pymongo/server_api.py | 4 +- pymongo/server_description.py | 2 +- pymongo/srv_resolver.py | 8 +-- pymongo/ssl_support.py | 8 ++- pymongo/topology.py | 3 +- pymongo/topology_description.py | 8 +-- pymongo/uri_parser.py | 18 ++--- pyproject.toml | 68 +++++++++++++++++++ setup.py | 10 ++- test/__init__.py | 14 ++-- test/atlas/test_connection.py | 1 + test/auth_aws/test_auth_aws.py | 1 + test/auth_oidc/test_auth_oidc.py | 1 + test/conftest.py | 2 + test/crud_v2_format.py | 1 + test/lambda/mongodb/app.py | 2 + test/mockupdb/operations.py | 1 + test/mockupdb/test_auth_recovering_member.py | 1 + test/mockupdb/test_cluster_time.py | 1 + test/mockupdb/test_cursor_namespace.py | 1 + test/mockupdb/test_getmore_sharded.py | 2 + test/mockupdb/test_handshake.py | 17 +++-- test/mockupdb/test_initial_ismaster.py | 1 + test/mockupdb/test_list_indexes.py | 1 + test/mockupdb/test_max_staleness.py | 1 + test/mockupdb/test_mixed_version_sharded.py | 1 + .../mockupdb/test_mongos_command_read_mode.py | 1 + .../test_network_disconnect_primary.py | 1 + test/mockupdb/test_op_msg.py | 1 + test/mockupdb/test_op_msg_read_preference.py | 1 + test/mockupdb/test_query_read_pref_sharded.py | 1 + test/mockupdb/test_reset_and_request_check.py | 1 + test/mockupdb/test_rsghost.py | 1 + test/mockupdb/test_slave_okay_rs.py | 3 +- test/mockupdb/test_slave_okay_sharded.py | 7 +- test/mockupdb/test_slave_okay_single.py | 1 + test/mod_wsgi_test/mod_wsgi_test.py | 5 +- test/mod_wsgi_test/test_client.py | 1 + test/mypy_fails/insert_many_dict.py | 2 + test/mypy_fails/insert_one_list.py | 2 + test/mypy_fails/raw_bson_document.py | 2 + test/mypy_fails/typedict_client.py | 2 + test/ocsp/test_ocsp.py | 1 + test/performance/perf_test.py | 1 + test/pymongo_mocks.py | 1 + test/qcheck.py | 1 + test/sigstop_sigcont.py | 1 + test/test_auth.py | 5 +- test/test_auth_spec.py | 1 + test/test_binary.py | 1 + test/test_bson.py | 1 + test/test_bson_corpus.py | 1 + test/test_bulk.py | 9 +-- test/test_change_stream.py | 1 + test/test_client.py | 13 ++-- test/test_client_context.py | 5 +- test/test_cmap.py | 1 + test/test_code.py | 1 + test/test_collation.py | 1 + test/test_collection.py | 17 ++--- test/test_collection_management.py | 1 + test/test_command_monitoring.py | 1 + test/test_common.py | 1 + ...nnections_survive_primary_stepdown_spec.py | 1 + test/test_create_entities.py | 2 + test/test_crud_unified.py | 1 + test/test_crud_v1.py | 1 + test/test_csot.py | 1 + test/test_cursor.py | 6 +- test/test_custom_types.py | 1 + test/test_data_lake.py | 1 + test/test_database.py | 1 + test/test_dbref.py | 1 + test/test_decimal128.py | 1 + test/test_default_exports.py | 2 + test/test_discovery_and_monitoring.py | 1 + test/test_dns.py | 1 + test/test_encryption.py | 3 +- test/test_errors.py | 1 + test/test_examples.py | 1 + test/test_fork.py | 1 + test/test_grid_file.py | 1 + test/test_gridfs.py | 1 + test/test_gridfs_bucket.py | 2 + test/test_gridfs_spec.py | 1 + test/test_heartbeat_monitoring.py | 1 + test/test_json_util.py | 1 + test/test_load_balancer.py | 1 + test/test_max_staleness.py | 1 + test/test_mongos_load_balancing.py | 3 +- test/test_monitor.py | 1 + test/test_monitoring.py | 1 + test/test_objectid.py | 1 + test/test_ocsp_cache.py | 1 + test/test_on_demand_csfle.py | 2 + test/test_pooling.py | 1 + test/test_pymongo.py | 1 + test/test_raw_bson.py | 1 + test/test_read_concern.py | 1 + test/test_read_preferences.py | 7 +- test/test_read_write_concern_spec.py | 1 + test/test_replica_set_reconfig.py | 3 +- test/test_retryable_reads.py | 1 + test/test_retryable_reads_unified.py | 1 + test/test_retryable_writes.py | 1 + test/test_retryable_writes_unified.py | 1 + test/test_run_command.py | 2 + test/test_saslprep.py | 3 +- test/test_sdam_monitoring_spec.py | 13 ++-- test/test_server.py | 1 + test/test_server_description.py | 1 + test/test_server_selection.py | 1 + test/test_server_selection_in_window.py | 1 + test/test_server_selection_rtt.py | 1 + test/test_session.py | 1 + test/test_sessions_unified.py | 1 + test/test_son.py | 5 +- test/test_srv_polling.py | 1 + test/test_ssl.py | 28 ++++---- test/test_streaming_protocol.py | 1 + test/test_threads.py | 1 + test/test_timestamp.py | 1 + test/test_topology.py | 5 +- test/test_transactions.py | 1 + test/test_transactions_unified.py | 1 + test/test_typing.py | 4 +- test/test_typing_strict.py | 2 + test/test_unified_format.py | 1 + test/test_uri_parser.py | 1 + test/test_uri_spec.py | 1 + test/test_versioned_api.py | 1 + test/test_write_concern.py | 3 +- test/unicode/test_utf8.py | 2 + test/unified_format.py | 2 + test/utils.py | 3 +- test/utils_selection_tests.py | 1 + test/utils_spec_runner.py | 1 + test/version.py | 1 + tools/clean.py | 15 ++-- tools/ensure_future_annotations_import.py | 12 ++-- tools/fail_if_no_c.py | 11 +-- tools/ocsptest.py | 1 + 189 files changed, 634 insertions(+), 454 deletions(-) delete mode 100644 .flake8 mode change 100755 => 100644 setup.py mode change 100755 => 100644 test/mockupdb/test_auth_recovering_member.py mode change 100755 => 100644 test/mockupdb/test_network_disconnect_primary.py mode change 100755 => 100644 test/mockupdb/test_op_msg.py mode change 100755 => 100644 test/mockupdb/test_reset_and_request_check.py diff --git a/.flake8 b/.flake8 deleted file mode 100644 index e5bc58921..000000000 --- a/.flake8 +++ /dev/null @@ -1,30 +0,0 @@ -[flake8] -max-line-length = 100 -enable-extensions = G -extend-ignore = - G200, G202, - # black adds spaces around ':' - E203, - # E501 line too long (let black handle line length) - E501 - # B305 `.next()` is not a thing on Python 3 - B305 -per-file-ignores = - # E402 module level import not at top of file - pymongo/__init__.py: E402 - - # G004 Logging statement uses f-string - pymongo/event_loggers.py: G004 - - # E402 module level import not at top of file - # B011 Do not call assert False since python -O removes these calls - # F405 'Foo' may be undefined, or defined from star imports - # E741 ambiguous variable name - # B007 Loop control variable 'foo' not used within the loop body - # F403 'from foo import *' used; unable to detect undefined names - # B001 Do not use bare `except:` - # E722 do not use bare 'except' - # E731 do not assign a lambda expression, use a def - # F811 redefinition of unused 'foo' from line XXX - # F841 local variable 'foo' is assigned to but never used - test/*: E402, B011, F405, E741, B007, F403, B001, E722, E731, F811, F841 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 44d9c78ba..3a90118fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,12 +24,12 @@ repos: files: \.py$ args: [--line-length=100] -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.0 hooks: - - id: isort - files: \.py$ - args: [--profile=black] + - id: ruff + args: ["--fix", "--show-fixes"] - repo: https://github.com/adamchainz/blacken-docs rev: "1.13.0" @@ -38,18 +38,6 @@ repos: additional_dependencies: - black==22.3.0 -- repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 - hooks: - - id: flake8 - files: \.py$ - additional_dependencies: [ - 'flake8-bugbear==20.1.4', - 'flake8-logging-format==0.6.0', - 'flake8-implicit-str-concat==0.2.0', - ] - stages: [manual] - # We use the Python version instead of the original version which seems to require Docker # https://github.com/koalaman/shellcheck-precommit - repo: https://github.com/shellcheck-py/shellcheck-py diff --git a/bson/__init__.py b/bson/__init__.py index fca88da3b..2c4bd3a8b 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -244,7 +244,7 @@ def _raise_unknown_type(element_type: int, element_name: str) -> NoReturn: def _get_int( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[int, int]: """Decode a BSON int32 to python int.""" return _UNPACK_INT_FROM(data, position)[0], position + 4 @@ -257,7 +257,7 @@ def _get_c_string(data: Any, view: Any, position: int, opts: CodecOptions[Any]) def _get_float( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[float, int]: """Decode a BSON double to python float.""" return _UNPACK_FLOAT_FROM(data, position)[0], position + 8 @@ -282,7 +282,7 @@ def _get_object_size(data: Any, position: int, obj_end: int) -> Tuple[int, int]: try: obj_size = _UNPACK_INT_FROM(data, position)[0] except struct.error as exc: - raise InvalidBSON(str(exc)) + raise InvalidBSON(str(exc)) from None end = position + obj_size - 1 if data[end] != 0: raise InvalidBSON("bad eoo") @@ -358,7 +358,7 @@ def _get_array( def _get_binary( - data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy1: Any + data: Any, _view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy1: Any ) -> Tuple[Union[Binary, uuid.UUID], int]: """Decode a BSON binary to bson.binary.Binary or python UUID.""" length, subtype = _UNPACK_LENGTH_SUBTYPE_FROM(data, position) @@ -395,7 +395,7 @@ def _get_binary( def _get_oid( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[ObjectId, int]: """Decode a BSON ObjectId to bson.objectid.ObjectId.""" end = position + 12 @@ -403,7 +403,7 @@ def _get_oid( def _get_boolean( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[bool, int]: """Decode a BSON true/false to python True/False.""" end = position + 1 @@ -416,7 +416,7 @@ def _get_boolean( def _get_date( - data: Any, view: Any, position: int, dummy0: int, opts: CodecOptions[Any], dummy1: Any + data: Any, _view: Any, position: int, dummy0: int, opts: CodecOptions[Any], dummy1: Any ) -> Tuple[Union[datetime.datetime, DatetimeMS], int]: """Decode a BSON datetime to python datetime.datetime.""" return _millis_to_datetime(_UNPACK_LONG_FROM(data, position)[0], opts), position + 8 @@ -431,7 +431,7 @@ def _get_code( def _get_code_w_scope( - data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], element_name: str + data: Any, view: Any, position: int, _obj_end: int, opts: CodecOptions[Any], element_name: str ) -> Tuple[Code, int]: """Decode a BSON code_w_scope to bson.code.Code.""" code_end = position + _UNPACK_INT_FROM(data, position)[0] @@ -462,7 +462,7 @@ def _get_ref( def _get_timestamp( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[Timestamp, int]: """Decode a BSON timestamp to bson.timestamp.Timestamp.""" inc, timestamp = _UNPACK_TIMESTAMP_FROM(data, position) @@ -470,14 +470,14 @@ def _get_timestamp( def _get_int64( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[Int64, int]: """Decode a BSON int64 to bson.int64.Int64.""" return Int64(_UNPACK_LONG_FROM(data, position)[0]), position + 8 def _get_decimal128( - data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any ) -> Tuple[Decimal128, int]: """Decode a BSON decimal128 to bson.decimal128.Decimal128.""" end = position + 16 @@ -496,11 +496,11 @@ _ELEMENT_GETTER: dict[int, Callable[..., Tuple[Any, int]]] = { ord(BSONOBJ): _get_object, ord(BSONARR): _get_array, ord(BSONBIN): _get_binary, - ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # Deprecated undefined + ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005 # Deprecated undefined ord(BSONOID): _get_oid, ord(BSONBOO): _get_boolean, ord(BSONDAT): _get_date, - ord(BSONNUL): lambda u, v, w, x, y, z: (None, w), + ord(BSONNUL): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005 ord(BSONRGX): _get_regex, ord(BSONREF): _get_ref, # Deprecated DBPointer ord(BSONCOD): _get_code, @@ -510,8 +510,8 @@ _ELEMENT_GETTER: dict[int, Callable[..., Tuple[Any, int]]] = { ord(BSONTIM): _get_timestamp, ord(BSONLON): _get_int64, ord(BSONDEC): _get_decimal128, - ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w), - ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w), + ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w), # noqa: ARG005 + ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w), # noqa: ARG005 } @@ -519,7 +519,7 @@ if _USE_C: def _element_to_dict( data: Any, - view: Any, + view: Any, # noqa: ARG001 position: int, obj_end: int, opts: CodecOptions[Any], @@ -615,11 +615,11 @@ def _bson_to_dict(data: Any, opts: CodecOptions[_DocumentType]) -> _DocumentType except Exception: # Change exception type to InvalidBSON but preserve traceback. _, exc_value, exc_tb = sys.exc_info() - raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) + raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None if _USE_C: - _bson_to_dict = _cbson._bson_to_dict # noqa: F811 + _bson_to_dict = _cbson._bson_to_dict _PACK_FLOAT = struct.Struct(" bytes: _utf_8_decode(string, None, True) return string + b"\x00" except UnicodeError: - raise InvalidStringData("strings in documents must be valid UTF-8: %r" % string) + raise InvalidStringData( + "strings in documents must be valid UTF-8: %r" % string + ) from None else: if "\x00" in string: raise InvalidDocument("BSON keys / regex patterns must not contain a NUL character") @@ -667,7 +669,9 @@ def _make_c_string(string: Union[str, bytes]) -> bytes: _utf_8_decode(string, None, True) return string + b"\x00" except UnicodeError: - raise InvalidStringData("strings in documents must be valid UTF-8: %r" % string) + raise InvalidStringData( + "strings in documents must be valid UTF-8: %r" % string + ) from None else: return _utf_8_encode(string)[0] + b"\x00" @@ -817,7 +821,7 @@ def _encode_int(name: bytes, value: int, dummy0: Any, dummy1: Any) -> bytes: try: return b"\x12" + name + _PACK_LONG(value) except struct.error: - raise OverflowError("BSON can only handle up to 8-byte ints") + raise OverflowError("BSON can only handle up to 8-byte ints") from None def _encode_timestamp(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes: @@ -830,7 +834,7 @@ def _encode_long(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes: try: return b"\x12" + name + _PACK_LONG(value) except struct.error: - raise OverflowError("BSON can only handle up to 8-byte ints") + raise OverflowError("BSON can only handle up to 8-byte ints") from None def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any) -> bytes: @@ -995,14 +999,14 @@ def _dict_to_bson( if not top_level or key != "_id": elements.append(_element_to_bson(key, value, check_keys, opts)) except AttributeError: - raise TypeError(f"encoder expected a mapping type but got: {doc!r}") + raise TypeError(f"encoder expected a mapping type but got: {doc!r}") from None encoded = b"".join(elements) return _PACK_INT(len(encoded) + 5) + encoded + b"\x00" if _USE_C: - _dict_to_bson = _cbson._dict_to_bson # noqa: F811 + _dict_to_bson = _cbson._dict_to_bson _CODEC_OPTIONS_TYPE_ERROR = TypeError("codec_options must be an instance of CodecOptions") @@ -1110,11 +1114,11 @@ def _decode_all(data: _ReadableBuffer, opts: CodecOptions[_DocumentType]) -> lis except Exception: # Change exception type to InvalidBSON but preserve traceback. _, exc_value, exc_tb = sys.exc_info() - raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) + raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None if _USE_C: - _decode_all = _cbson._decode_all # noqa: F811 + _decode_all = _cbson._decode_all @overload @@ -1207,7 +1211,7 @@ def _array_of_documents_to_buffer(view: memoryview) -> bytes: if _USE_C: - _array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer # noqa: F811 + _array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer def _convert_raw_document_lists_to_streams(document: Any) -> None: diff --git a/bson/binary.py b/bson/binary.py index c5a06d5a9..a4cd44e93 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -242,7 +242,7 @@ class Binary(bytes): @classmethod def from_uuid( cls: Type[Binary], uuid: UUID, uuid_representation: int = UuidRepresentation.STANDARD - ) -> "Binary": + ) -> Binary: """Create a BSON Binary object from a Python UUID. Creates a :class:`~bson.binary.Binary` object from a diff --git a/bson/codec_options.py b/bson/codec_options.py index 15da9e4db..2c64c6460 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -168,7 +168,7 @@ class TypeRegistry: if issubclass(cast(TypeCodec, codec).python_type, pytype): err_msg = ( "TypeEncoders cannot change how built-in types are " - "encoded (encoder {} transforms type {})".format(codec, pytype) + f"encoded (encoder {codec} transforms type {pytype})" ) raise TypeError(err_msg) diff --git a/bson/datetime_ms.py b/bson/datetime_ms.py index b4f419175..b6aebd05d 100644 --- a/bson/datetime_ms.py +++ b/bson/datetime_ms.py @@ -75,10 +75,10 @@ class DatetimeMS: def __repr__(self) -> str: return type(self).__name__ + "(" + str(self._value) + ")" - def __lt__(self, other: Union["DatetimeMS", int]) -> bool: + def __lt__(self, other: Union[DatetimeMS, int]) -> bool: return self._value < other - def __le__(self, other: Union["DatetimeMS", int]) -> bool: + def __le__(self, other: Union[DatetimeMS, int]) -> bool: return self._value <= other def __eq__(self, other: Any) -> bool: @@ -91,10 +91,10 @@ class DatetimeMS: return self._value != other._value return True - def __gt__(self, other: Union["DatetimeMS", int]) -> bool: + def __gt__(self, other: Union[DatetimeMS, int]) -> bool: return self._value > other - def __ge__(self, other: Union["DatetimeMS", int]) -> bool: + def __ge__(self, other: Union[DatetimeMS, int]) -> bool: return self._value >= other _type_marker = 9 diff --git a/bson/dbref.py b/bson/dbref.py index fc1ca7e17..50fcf6c02 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -89,7 +89,7 @@ class DBRef: try: return self.__kwargs[key] except KeyError: - raise AttributeError(key) + raise AttributeError(key) from None def as_doc(self) -> SON[str, Any]: """Get the SON document representation of this DBRef. diff --git a/bson/decimal128.py b/bson/decimal128.py index 35f9f46e0..f807452a6 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -298,7 +298,7 @@ class Decimal128: return str(dec) def __repr__(self) -> str: - return f"Decimal128('{str(self)}')" + return f"Decimal128('{self!s}')" def __setstate__(self, value: Tuple[int, int]) -> None: self.__high, self.__low = value diff --git a/bson/errors.py b/bson/errors.py index 7333b27b5..a3699e704 100644 --- a/bson/errors.py +++ b/bson/errors.py @@ -13,6 +13,7 @@ # limitations under the License. """Exceptions raised by the BSON package.""" +from __future__ import annotations class BSONError(Exception): diff --git a/bson/json_util.py b/bson/json_util.py index 2c34b2192..1a74a8136 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -737,8 +737,7 @@ def _parse_canonical_regex(doc: Any) -> Regex[str]: raise TypeError(f"Bad $regularExpression, extra field(s): {doc}") if len(regex) != 2: raise TypeError( - 'Bad $regularExpression must include only "pattern"' - 'and "options" components: {}'.format(doc) + f'Bad $regularExpression must include only "pattern and "options" components: {doc}' ) opts = regex["options"] if not isinstance(opts, str): @@ -812,7 +811,7 @@ def _parse_canonical_decimal128(doc: Any) -> Decimal128: def _parse_canonical_minkey(doc: Any) -> MinKey: """Decode a JSON MinKey to bson.min_key.MinKey.""" - if type(doc["$minKey"]) is not int or doc["$minKey"] != 1: + if type(doc["$minKey"]) is not int or doc["$minKey"] != 1: # noqa: E721 raise TypeError(f"$minKey value must be 1: {doc}") if len(doc) != 1: raise TypeError(f"Bad $minKey, extra field(s): {doc}") @@ -821,7 +820,7 @@ def _parse_canonical_minkey(doc: Any) -> MinKey: def _parse_canonical_maxkey(doc: Any) -> MaxKey: """Decode a JSON MaxKey to bson.max_key.MaxKey.""" - if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1: + if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1: # noqa: E721 raise TypeError("$maxKey value must be 1: %s", (doc,)) if len(doc) != 1: raise TypeError(f"Bad $minKey, extra field(s): {doc}") diff --git a/bson/objectid.py b/bson/objectid.py index c114a709b..2a3d9ebf5 100644 --- a/bson/objectid.py +++ b/bson/objectid.py @@ -57,7 +57,7 @@ class ObjectId: _type_marker = 7 - def __init__(self, oid: Optional[Union[str, "ObjectId", bytes]] = None) -> None: + def __init__(self, oid: Optional[Union[str, ObjectId, bytes]] = None) -> None: """Initialize a new ObjectId. An ObjectId is a 12-byte unique identifier consisting of: @@ -103,7 +103,7 @@ class ObjectId: self.__validate(oid) @classmethod - def from_datetime(cls: Type["ObjectId"], generation_time: datetime.datetime) -> "ObjectId": + def from_datetime(cls: Type[ObjectId], generation_time: datetime.datetime) -> ObjectId: """Create a dummy ObjectId instance with a specific generation time. This method is useful for doing range queries on a field @@ -138,7 +138,7 @@ class ObjectId: return cls(oid) @classmethod - def is_valid(cls: Type["ObjectId"], oid: Any) -> bool: + def is_valid(cls: Type[ObjectId], oid: Any) -> bool: """Checks if a `oid` string is valid or not. :Parameters: @@ -245,7 +245,7 @@ class ObjectId: return binascii.hexlify(self.__id).decode() def __repr__(self) -> str: - return f"ObjectId('{str(self)}')" + return f"ObjectId('{self!s}')" def __eq__(self, other: Any) -> bool: if isinstance(other, ObjectId): diff --git a/bson/raw_bson.py b/bson/raw_bson.py index d57b59cb7..50362398a 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -55,9 +55,8 @@ from __future__ import annotations from typing import Any, ItemsView, Iterator, Mapping, MutableMapping, Optional from bson import _get_object_size, _raw_to_dict -from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER +from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER, CodecOptions from bson.codec_options import DEFAULT_CODEC_OPTIONS as DEFAULT -from bson.codec_options import CodecOptions from bson.son import SON @@ -135,7 +134,7 @@ class RawBSONDocument(Mapping[str, Any]): elif not issubclass(codec_options.document_class, RawBSONDocument): raise TypeError( "RawBSONDocument cannot use CodecOptions with document " - "class {}".format(codec_options.document_class) + f"class {codec_options.document_class}" ) self.__codec_options = codec_options # Validate the bson object size. @@ -180,11 +179,7 @@ class RawBSONDocument(Mapping[str, Any]): return NotImplemented def __repr__(self) -> str: - return "{}({!r}, codec_options={!r})".format( - self.__class__.__name__, - self.raw, - self.__codec_options, - ) + return f"{self.__class__.__name__}({self.raw!r}, codec_options={self.__codec_options!r})" class _RawArrayBSONDocument(RawBSONDocument): diff --git a/bson/son.py b/bson/son.py index 8efe40ef0..c5df4e597 100644 --- a/bson/son.py +++ b/bson/son.py @@ -139,7 +139,7 @@ class SON(Dict[_Key, _Value]): try: k, v = next(iter(self.items())) except StopIteration: - raise KeyError("container is empty") + raise KeyError("container is empty") from None del self[k] return (k, v) @@ -151,7 +151,7 @@ class SON(Dict[_Key, _Value]): for k, v in other.items(): self[k] = v elif hasattr(other, "keys"): - for k in other.keys(): + for k in other: self[k] = other[k] else: for k, v in other: @@ -204,6 +204,6 @@ class SON(Dict[_Key, _Value]): memo[val_id] = out for k, v in self.items(): if not isinstance(v, RE_TYPE): - v = copy.deepcopy(v, memo) + v = copy.deepcopy(v, memo) # noqa: PLW2901 out[k] = v return out diff --git a/doc/conf.py b/doc/conf.py index cbb525b41..1ea51add8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- # # PyMongo documentation build configuration file # # This file is execfile()d with the current directory set to its containing dir. +from __future__ import annotations -import os import sys +from pathlib import Path -sys.path[0:0] = [os.path.abspath("..")] +sys.path[0:0] = [Path("..").resolve()] -import pymongo # noqa +import pymongo # noqa: E402 # -- General configuration ----------------------------------------------------- @@ -26,7 +26,7 @@ extensions = [ # Add optional extensions try: - import sphinxcontrib.shellcheck # noqa + import sphinxcontrib.shellcheck # noqa: F401 extensions += ["sphinxcontrib.shellcheck"] except ImportError: @@ -94,7 +94,7 @@ linkcheck_ignore = [ # -- Options for extensions ---------------------------------------------------- autoclass_content = "init" -doctest_path = [os.path.abspath("..")] +doctest_path = [Path("..").resolve()] doctest_test_doctest_blocks = "" @@ -108,7 +108,7 @@ db = client.doctest_test # -- Options for HTML output --------------------------------------------------- try: - import furo # noqa + import furo # noqa: F401 html_theme = "furo" except ImportError: diff --git a/green_framework_test.py b/green_framework_test.py index c1615b548..01f72b245 100644 --- a/green_framework_test.py +++ b/green_framework_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Test PyMongo with a variety of greenlet-based monkey-patching frameworks.""" +from __future__ import annotations import getopt import sys @@ -65,13 +66,10 @@ def run(framework_name, *args): def main(): """Parse options and run tests.""" - usage = """python %s FRAMEWORK_NAME + usage = f"""python {sys.argv[0]} FRAMEWORK_NAME Test PyMongo with a variety of greenlet-based monkey-patching frameworks. See -python %s --help-frameworks.""" % ( - sys.argv[0], - sys.argv[0], - ) +python {sys.argv[0]} --help-frameworks.""" try: opts, args = getopt.getopt(sys.argv[1:], "h", ["help", "help-frameworks"]) diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 549334899..63aa40623 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -224,7 +224,7 @@ class GridFS: doc = next(cursor) return GridOut(self.__collection, file_document=doc, session=session) except StopIteration: - raise NoFile("no version %d for filename %r" % (version, filename)) + raise NoFile("no version %d for filename %r" % (version, filename)) from None def get_last_version( self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any @@ -932,7 +932,7 @@ class GridFSBucket: grid_file = next(cursor) return GridOut(self._collection, file_document=grid_file, session=session) except StopIteration: - raise NoFile("no version %d for filename %r" % (revision, filename)) + raise NoFile("no version %d for filename %r" % (revision, filename)) from None @_csot.apply def download_to_stream_by_name( diff --git a/gridfs/errors.py b/gridfs/errors.py index 39736d55b..e8c02cef4 100644 --- a/gridfs/errors.py +++ b/gridfs/errors.py @@ -13,6 +13,7 @@ # limitations under the License. """Exceptions raised by the :mod:`gridfs` package""" +from __future__ import annotations from pymongo.errors import PyMongoError diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 682ceaba4..685d09749 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -357,12 +357,14 @@ class GridIn: except AttributeError: # string if not isinstance(data, (str, bytes)): - raise TypeError("can only write strings or file-like objects") + raise TypeError("can only write strings or file-like objects") from None if isinstance(data, str): try: data = data.encode(self.encoding) except AttributeError: - raise TypeError("must specify an encoding for file in order to write str") + raise TypeError( + "must specify an encoding for file in order to write str" + ) from None read = io.BytesIO(data).read if self._buffer.tell() > 0: @@ -395,7 +397,7 @@ class GridIn: def writeable(self) -> bool: return True - def __enter__(self) -> "GridIn": + def __enter__(self) -> GridIn: """Support for the context manager protocol.""" return self @@ -671,7 +673,7 @@ class GridOut(io.IOBase): def seekable(self) -> bool: return True - def __iter__(self) -> "GridOut": + def __iter__(self) -> GridOut: """Return an iterator over all of this file's data. The iterator will return lines (delimited by ``b'\\n'``) of @@ -708,7 +710,7 @@ class GridOut(io.IOBase): def writable(self) -> bool: return False - def __enter__(self) -> "GridOut": + def __enter__(self) -> GridOut: """Makes it possible to use :class:`GridOut` files with the context manager protocol. """ @@ -773,7 +775,7 @@ class _GridOutChunkIterator: return self._chunk_size return self._length - (self._chunk_size * (self._num_chunks - 1)) - def __iter__(self) -> "_GridOutChunkIterator": + def __iter__(self) -> _GridOutChunkIterator: return self def _create_cursor(self) -> None: @@ -806,7 +808,7 @@ class _GridOutChunkIterator: except StopIteration: if self._next_chunk >= self._num_chunks: raise - raise CorruptGridFile("no chunk #%d" % self._next_chunk) + raise CorruptGridFile("no chunk #%d" % self._next_chunk) from None if chunk["n"] != self._next_chunk: self.close() @@ -847,7 +849,7 @@ class GridOutIterator: def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession): self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0) - def __iter__(self) -> "GridOutIterator": + def __iter__(self) -> GridOutIterator: return self def next(self) -> bytes: @@ -914,6 +916,6 @@ class GridOutCursor(Cursor): def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError("Method does not exist for GridOutCursor") - def _clone_base(self, session: Optional[ClientSession]) -> "GridOutCursor": + def _clone_base(self, session: Optional[ClientSession]) -> GridOutCursor: """Creates an empty GridOutCursor for information to be copied into.""" return GridOutCursor(self.__root_collection, session=session) diff --git a/pymongo/_version.py b/pymongo/_version.py index 9022b2c15..1d296afbb 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -13,6 +13,8 @@ # limitations under the License. """Current version of PyMongo.""" +from __future__ import annotations + from typing import Tuple, Union version_tuple: Tuple[Union[int, str], ...] = (4, 6, 0, ".dev0") diff --git a/pymongo/auth.py b/pymongo/auth.py index d7237b633..58fc36d05 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -325,7 +325,7 @@ def _password_digest(username: str, password: str) -> str: if not isinstance(username, str): raise TypeError("username must be an instance of str") - md5hash = hashlib.md5() + md5hash = hashlib.md5() # noqa: S324 data = f"{username}:mongo:{password}" md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() @@ -334,7 +334,7 @@ def _password_digest(username: str, password: str) -> str: def _auth_key(nonce: str, username: str, password: str) -> str: """Get an auth key to use for authentication.""" digest = _password_digest(username, password) - md5hash = hashlib.md5() + md5hash = hashlib.md5() # noqa: S324 data = f"{nonce}{username}{digest}" md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() @@ -469,7 +469,7 @@ def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None kerberos.authGSSClientClean(ctx) except kerberos.KrbError as exc: - raise OperationFailure(str(exc)) + raise OperationFailure(str(exc)) from None def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None: diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index a327016d7..81f30c7ae 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -35,7 +35,7 @@ try: set_use_cached_credentials(True) except ImportError: - def set_cached_credentials(creds: Optional[AwsCredential]) -> None: + def set_cached_credentials(_creds: Optional[AwsCredential]) -> None: pass @@ -110,7 +110,9 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: # Clear the cached credentials if we hit a failure in auth. set_cached_credentials(None) # Convert to OperationFailure and include pymongo-auth-aws version. - raise OperationFailure(f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})") + raise OperationFailure( + f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})" + ) from None except Exception: # Clear the cached credentials if we hit a failure in auth. set_cached_credentials(None) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 220f91e67..ad9223809 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -144,7 +144,7 @@ class _OIDCAuthenticator: if principal_name: payload["n"] = principal_name - cmd = SON( + return SON( [ ("saslStart", 1), ("mechanism", "MONGODB-OIDC"), @@ -152,7 +152,6 @@ class _OIDCAuthenticator: ("autoAuthorize", 1), ] ) - return cmd def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]: # TODO: DRIVERS-2672, check for provider_name in self.properties here. @@ -207,7 +206,7 @@ class _OIDCAuthenticator: self.idp_info = server_resp # Handle the case of changed idp info. - if not self.idp_info == prev_idp_info: + if self.idp_info != prev_idp_info: self.access_token = None self.refresh_token = None diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 1ed7a117e..10e77d8b1 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -221,7 +221,7 @@ class _Bulk: ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) - cmd: dict[str, Any] = dict( + cmd: dict[str, Any] = dict( # noqa: C406 [("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)] ) if collation is not None: diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 8e47da680..75cd16979 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -225,7 +225,7 @@ class ChangeStream(Generic[_DocumentType]): if self._start_at_operation_time is None: raise OperationFailure( "Expected field 'operationTime' missing from command " - "response : {!r}".format(result) + f"response : {result!r}" ) def _run_aggregation_cmd( @@ -264,7 +264,7 @@ class ChangeStream(Generic[_DocumentType]): self._closed = True self._cursor.close() - def __iter__(self) -> "ChangeStream[_DocumentType]": + def __iter__(self) -> ChangeStream[_DocumentType]: return self @property @@ -406,7 +406,7 @@ class ChangeStream(Generic[_DocumentType]): self.close() raise InvalidOperation( "Cannot provide resume functionality when the resume token is missing." - ) + ) from None # If this is the last change document from the current batch, cache the # postBatchResumeToken. diff --git a/pymongo/client_session.py b/pymongo/client_session.py index fa9c62b59..0aac77011 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -200,7 +200,7 @@ class SessionOptions: def __init__( self, causal_consistency: Optional[bool] = None, - default_transaction_options: Optional["TransactionOptions"] = None, + default_transaction_options: Optional[TransactionOptions] = None, snapshot: Optional[bool] = False, ) -> None: if snapshot: @@ -227,7 +227,7 @@ class SessionOptions: return self._causal_consistency @property - def default_transaction_options(self) -> Optional["TransactionOptions"]: + def default_transaction_options(self) -> Optional[TransactionOptions]: """The default TransactionOptions to use for transactions started on this session. @@ -287,25 +287,25 @@ class TransactionOptions: if not isinstance(read_concern, ReadConcern): raise TypeError( "read_concern must be an instance of " - "pymongo.read_concern.ReadConcern, not: {!r}".format(read_concern) + f"pymongo.read_concern.ReadConcern, not: {read_concern!r}" ) if write_concern is not None: if not isinstance(write_concern, WriteConcern): raise TypeError( "write_concern must be an instance of " - "pymongo.write_concern.WriteConcern, not: {!r}".format(write_concern) + f"pymongo.write_concern.WriteConcern, not: {write_concern!r}" ) if not write_concern.acknowledged: raise ConfigurationError( "transactions do not support unacknowledged write concern" - ": {!r}".format(write_concern) + f": {write_concern!r}" ) if read_preference is not None: if not isinstance(read_preference, _ServerMode): raise TypeError( - "{!r} is not valid for read_preference. See " + f"{read_preference!r} is not valid for read_preference. See " "pymongo.read_preferences for valid " - "options.".format(read_preference) + "options." ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): @@ -354,7 +354,7 @@ def _validate_session_write_concern( else: raise ConfigurationError( "Explicit sessions are incompatible with " - "unacknowledged write concern: {!r}".format(write_concern) + f"unacknowledged write concern: {write_concern!r}" ) return session @@ -535,7 +535,7 @@ class ClientSession: if self._server_session is None: raise InvalidOperation("Cannot use ended session") - def __enter__(self) -> "ClientSession": + def __enter__(self) -> ClientSession: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -585,7 +585,7 @@ class ClientSession: def with_transaction( self, - callback: Callable[["ClientSession"], _T], + callback: Callable[[ClientSession], _T], read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, @@ -838,7 +838,7 @@ class ClientSession: """ def func( - session: Optional[ClientSession], conn: Connection, retryable: bool + _session: Optional[ClientSession], conn: Connection, _retryable: bool ) -> dict[str, Any]: return self._finish_transaction(conn, command_name) @@ -1002,8 +1002,7 @@ class ClientSession: if self.in_transaction: if read_preference != ReadPreference.PRIMARY: raise InvalidOperation( - "read preference in a transaction must be primary, not: " - "{!r}".format(read_preference) + f"read preference in a transaction must be primary, not: {read_preference!r}" ) if self._transaction.state == _TxnState.STARTING: diff --git a/pymongo/collection.py b/pymongo/collection.py index 7768d5f52..7f4354e7d 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -368,8 +368,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]): if name.startswith("_"): full_name = f"{self.__name}.{name}" raise AttributeError( - "Collection has no attribute {!r}. To access the {}" - " collection, use database['{}'].".format(name, full_name, full_name) + f"Collection has no attribute {name!r}. To access the {full_name}" + f" collection, use database['{full_name}']." ) return self.__getitem__(name) @@ -563,7 +563,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): try: request._add_to_bulk(blk) except AttributeError: - raise TypeError(f"{request!r} is not a valid request") + raise TypeError(f"{request!r} is not a valid request") from None write_concern = self._write_concern_for(session) bulk_api_result = blk.execute(write_concern, session) @@ -1812,7 +1812,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: Optional[_ServerMode], ) -> int: @@ -1901,7 +1901,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: Optional[_ServerMode], ) -> int: @@ -2277,7 +2277,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: _ServerMode, ) -> CommandCursor[MutableMapping[str, Any]]: @@ -2348,7 +2348,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): info = {} for index in cursor: index["key"] = list(index["key"].items()) - index = dict(index) + index = dict(index) # noqa: PLW2901 info[index.pop("name")] = index return info @@ -3038,7 +3038,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: Optional[_ServerMode], ) -> list: diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index ea58e46f7..e0258f90f 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -75,7 +75,7 @@ class CommandCursor(Generic[_DocumentType]): if self.__killed: self.__end_session(True) - if "ns" in cursor_info: + if "ns" in cursor_info: # noqa: SIM401 self.__ns = cursor_info["ns"] else: self.__ns = collection.full_name @@ -121,7 +121,7 @@ class CommandCursor(Generic[_DocumentType]): """Explicitly close / kill this cursor.""" self.__die(True) - def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]": + def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]: """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. @@ -340,7 +340,7 @@ class CommandCursor(Generic[_DocumentType]): """ return self._try_next(get_more_allowed=True) - def __enter__(self) -> "CommandCursor[_DocumentType]": + def __enter__(self) -> CommandCursor[_DocumentType]: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: diff --git a/pymongo/common.py b/pymongo/common.py index 794b7e31a..648571a6c 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -195,7 +195,7 @@ def validate_integer(option: str, value: Any) -> int: try: return int(value) except ValueError: - raise ValueError(f"The value of {option} must be an integer") + raise ValueError(f"The value of {option} must be an integer") from None raise TypeError(f"Wrong type for {option}, value must be an integer") @@ -287,9 +287,9 @@ def validate_positive_float(option: str, value: Any) -> float: try: value = float(value) except ValueError: - raise ValueError(errmsg) + raise ValueError(errmsg) from None except TypeError: - raise TypeError(errmsg) + raise TypeError(errmsg) from None # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at # one billion - this is a reasonable approximation for infinity @@ -388,10 +388,10 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int: return _UUID_REPRESENTATIONS[value] except KeyError: raise ValueError( - "{} is an invalid UUID representation. " + f"{value} is an invalid UUID representation. " "Must be one of " - "{}".format(value, tuple(_UUID_REPRESENTATIONS)) - ) + f"{tuple(_UUID_REPRESENTATIONS)}" + ) from None def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]: @@ -411,7 +411,7 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]] tags[unquote_plus(key)] = unquote_plus(val) tag_sets.append(tags) except Exception: - raise ValueError(f"{tag_set!r} not a valid value for {name}") + raise ValueError(f"{tag_set!r} not a valid value for {name}") from None return tag_sets @@ -432,7 +432,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni if not isinstance(value, str): if not isinstance(value, dict): raise ValueError("Auth mechanism properties must be given as a string or a dictionary") - for key, value in value.items(): + for key, value in value.items(): # noqa: B020 if isinstance(value, str): props[key] = value elif isinstance(value, bool): @@ -462,20 +462,20 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni except ValueError: # Try not to leak the token. if "AWS_SESSION_TOKEN" in opt: - opt = ( + opt = ( # noqa: PLW2901 "AWS_SESSION_TOKEN:, did you forget " "to percent-escape the token with quote_plus?" ) raise ValueError( "auth mechanism properties must be " "key:value pairs like SERVICE_NAME:" - "mongodb, not {}.".format(opt) - ) + f"mongodb, not {opt}." + ) from None if key not in _MECHANISM_PROPS: raise ValueError( - "{} is not a supported auth " + f"{key} is not a supported auth " "mechanism property. Must be one of " - "{}.".format(key, tuple(_MECHANISM_PROPS)) + f"{tuple(_MECHANISM_PROPS)}." ) if key == "CANONICALIZE_HOST_NAME": props[key] = validate_boolean_or_string(key, val) @@ -499,9 +499,9 @@ def validate_document_class( is_mapping = issubclass(value.__origin__, abc.MutableMapping) if not is_mapping and not issubclass(value, RawBSONDocument): raise TypeError( - "{} must be dict, bson.son.SON, " + f"{option} must be dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or a " - "subclass of collections.MutableMapping".format(option) + "subclass of collections.MutableMapping" ) return value @@ -531,9 +531,9 @@ def validate_list_or_mapping(option: Any, value: Any) -> None: """Validates that 'value' is a list or a document.""" if not isinstance(value, (abc.Mapping, list)): raise TypeError( - "{} must either be a list or an instance of dict, " + f"{option} must either be a list or an instance of dict, " "bson.son.SON, or any other type that inherits from " - "collections.Mapping".format(option) + "collections.Mapping" ) @@ -541,9 +541,9 @@ def validate_is_mapping(option: str, value: Any) -> None: """Validate the type of method arguments that expect a document.""" if not isinstance(value, abc.Mapping): raise TypeError( - "{} must be an instance of dict, bson.son.SON, or " + f"{option} must be an instance of dict, bson.son.SON, or " "any other type that inherits from " - "collections.Mapping".format(option) + "collections.Mapping" ) @@ -551,10 +551,10 @@ def validate_is_document_type(option: str, value: Any) -> None: """Validate the type of method arguments that expect a MongoDB document.""" if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError( - "{} must be an instance of dict, bson.son.SON, " + f"{option} must be an instance of dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or " "a type that inherits from " - "collections.MutableMapping".format(option) + "collections.MutableMapping" ) @@ -626,9 +626,9 @@ def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: """Validate the Unicode decode error handler option of CodecOptions.""" if value not in _UNICODE_DECODE_ERROR_HANDLERS: raise ValueError( - "{} is an invalid Unicode decode error handler. " + f"{value} is an invalid Unicode decode error handler. " "Must be one of " - "{}".format(value, tuple(_UNICODE_DECODE_ERROR_HANDLERS)) + f"{tuple(_UNICODE_DECODE_ERROR_HANDLERS)}" ) return value @@ -841,28 +841,28 @@ def get_validated_options( validated_options = _CaseInsensitiveDictionary() def get_normed_key(x: str) -> str: - return x # noqa: E731 + return x def get_setter_key(x: str) -> str: - return options.cased_key(x) # type: ignore[attr-defined] # noqa: E731 + return options.cased_key(x) # type: ignore[attr-defined] else: validated_options = {} def get_normed_key(x: str) -> str: - return x.lower() # noqa: E731 + return x.lower() def get_setter_key(x: str) -> str: - return x # noqa: E731 + return x for opt, value in options.items(): normed_key = get_normed_key(opt) try: validator = URI_OPTIONS_VALIDATOR_MAP.get(normed_key, raise_config_error) - value = validator(opt, value) + value = validator(opt, value) # noqa: PLW2901 except (ValueError, TypeError, ConfigurationError) as exc: if warn: - warnings.warn(str(exc)) + warnings.warn(str(exc), stacklevel=2) else: raise else: @@ -902,9 +902,9 @@ class BaseObject: if not isinstance(read_preference, _ServerMode): raise TypeError( - "{!r} is not valid for read_preference. See " + f"{read_preference!r} is not valid for read_preference. See " "pymongo.read_preferences for valid " - "options.".format(read_preference) + "options." ) self.__read_preference = read_preference @@ -1004,7 +1004,7 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]): return NotImplemented if len(self) != len(other): return False - for key in other: + for key in other: # noqa: SIM110 if self[key] != other[key]: return False diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index a29bf826d..ad54d628b 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -58,24 +58,27 @@ def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[s for compressor in compressors[:]: if compressor not in _SUPPORTED_COMPRESSORS: compressors.remove(compressor) - warnings.warn(f"Unsupported compressor: {compressor}") + warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) elif compressor == "snappy" and not _HAVE_SNAPPY: compressors.remove(compressor) warnings.warn( "Wire protocol compression with snappy is not available. " - "You must install the python-snappy module for snappy support." + "You must install the python-snappy module for snappy support.", + stacklevel=2, ) elif compressor == "zlib" and not _HAVE_ZLIB: compressors.remove(compressor) warnings.warn( "Wire protocol compression with zlib is not available. " - "The zlib module is not available." + "The zlib module is not available.", + stacklevel=2, ) elif compressor == "zstd" and not _HAVE_ZSTD: compressors.remove(compressor) warnings.warn( "Wire protocol compression with zstandard is not available. " - "You must install the zstandard module for zstandard support." + "You must install the zstandard module for zstandard support.", + stacklevel=2, ) return compressors @@ -84,7 +87,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int: try: level = int(value) except Exception: - raise TypeError(f"{option} must be an integer, not {value!r}.") + raise TypeError(f"{option} must be an integer, not {value!r}.") from None if level < -1 or level > 9: raise ValueError("%s must be between -1 and 9, not %d." % (option, level)) return level diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 4f9234e26..df154067a 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -335,7 +335,7 @@ class Cursor(Generic[_DocumentType]): def __del__(self) -> None: self.__die() - def rewind(self) -> "Cursor[_DocumentType]": + def rewind(self) -> Cursor[_DocumentType]: """Rewind this cursor to its unevaluated state. Reset this cursor if it has been partially or completely evaluated. @@ -353,7 +353,7 @@ class Cursor(Generic[_DocumentType]): return self - def clone(self) -> "Cursor[_DocumentType]": + def clone(self) -> Cursor[_DocumentType]: """Get a clone of this cursor. Returns a new Cursor instance with options matching those that have @@ -505,7 +505,7 @@ class Cursor(Generic[_DocumentType]): if self.__retrieved or self.__id is not None: raise InvalidOperation("cannot set options after executing query") - def add_option(self, mask: int) -> "Cursor[_DocumentType]": + def add_option(self, mask: int) -> Cursor[_DocumentType]: """Set arbitrary query flags using a bitmask. To set the tailable flag: @@ -525,7 +525,7 @@ class Cursor(Generic[_DocumentType]): self.__query_flags |= mask return self - def remove_option(self, mask: int) -> "Cursor[_DocumentType]": + def remove_option(self, mask: int) -> Cursor[_DocumentType]: """Unset arbitrary query flags using a bitmask. To unset the tailable flag: @@ -541,7 +541,7 @@ class Cursor(Generic[_DocumentType]): self.__query_flags &= ~mask return self - def allow_disk_use(self, allow_disk_use: bool) -> "Cursor[_DocumentType]": + def allow_disk_use(self, allow_disk_use: bool) -> Cursor[_DocumentType]: """Specifies whether MongoDB can use temporary disk files while processing a blocking sort operation. @@ -563,7 +563,7 @@ class Cursor(Generic[_DocumentType]): self.__allow_disk_use = allow_disk_use return self - def limit(self, limit: int) -> "Cursor[_DocumentType]": + def limit(self, limit: int) -> Cursor[_DocumentType]: """Limits the number of results to be returned by this cursor. Raises :exc:`TypeError` if `limit` is not an integer. Raises @@ -586,7 +586,7 @@ class Cursor(Generic[_DocumentType]): self.__limit = limit return self - def batch_size(self, batch_size: int) -> "Cursor[_DocumentType]": + def batch_size(self, batch_size: int) -> Cursor[_DocumentType]: """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. @@ -614,7 +614,7 @@ class Cursor(Generic[_DocumentType]): self.__batch_size = batch_size return self - def skip(self, skip: int) -> "Cursor[_DocumentType]": + def skip(self, skip: int) -> Cursor[_DocumentType]: """Skips the first `skip` results of this cursor. Raises :exc:`TypeError` if `skip` is not an integer. Raises @@ -635,7 +635,7 @@ class Cursor(Generic[_DocumentType]): self.__skip = skip return self - def max_time_ms(self, max_time_ms: Optional[int]) -> "Cursor[_DocumentType]": + def max_time_ms(self, max_time_ms: Optional[int]) -> Cursor[_DocumentType]: """Specifies a time limit for a query operation. If the specified time is exceeded, the operation will be aborted and :exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms` @@ -655,7 +655,7 @@ class Cursor(Generic[_DocumentType]): self.__max_time_ms = max_time_ms return self - def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> "Cursor[_DocumentType]": + def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> Cursor[_DocumentType]: """Specifies a time limit for a getMore operation on a :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other types of cursor max_await_time_ms is ignored. @@ -687,7 +687,7 @@ class Cursor(Generic[_DocumentType]): ... @overload - def __getitem__(self, index: slice) -> "Cursor[_DocumentType]": + def __getitem__(self, index: slice) -> Cursor[_DocumentType]: ... def __getitem__(self, index: Union[int, slice]) -> Union[_DocumentType, Cursor[_DocumentType]]: @@ -770,7 +770,7 @@ class Cursor(Generic[_DocumentType]): raise IndexError("no such item for Cursor instance") raise TypeError("index %r cannot be applied to Cursor instances" % index) - def max_scan(self, max_scan: Optional[int]) -> "Cursor[_DocumentType]": + def max_scan(self, max_scan: Optional[int]) -> Cursor[_DocumentType]: """**DEPRECATED** - Limit the number of documents to scan when performing the query. @@ -790,7 +790,7 @@ class Cursor(Generic[_DocumentType]): self.__max_scan = max_scan return self - def max(self, spec: _Sort) -> "Cursor[_DocumentType]": + def max(self, spec: _Sort) -> Cursor[_DocumentType]: """Adds ``max`` operator that specifies upper bound for specific index. When using ``max``, :meth:`~hint` should also be configured to ensure @@ -813,7 +813,7 @@ class Cursor(Generic[_DocumentType]): self.__max = SON(spec) return self - def min(self, spec: _Sort) -> "Cursor[_DocumentType]": + def min(self, spec: _Sort) -> Cursor[_DocumentType]: """Adds ``min`` operator that specifies lower bound for specific index. When using ``min``, :meth:`~hint` should also be configured to ensure @@ -838,7 +838,7 @@ class Cursor(Generic[_DocumentType]): def sort( self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None - ) -> "Cursor[_DocumentType]": + ) -> Cursor[_DocumentType]: """Sorts this cursor's results. Pass a field name and a direction, either @@ -944,7 +944,7 @@ class Cursor(Generic[_DocumentType]): else: self.__hint = helpers._index_document(index) - def hint(self, index: Optional[_Hint]) -> "Cursor[_DocumentType]": + def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]: """Adds a 'hint', telling Mongo the proper index to use for the query. Judicious use of hints can greatly improve query @@ -969,7 +969,7 @@ class Cursor(Generic[_DocumentType]): self.__set_hint(index) return self - def comment(self, comment: Any) -> "Cursor[_DocumentType]": + def comment(self, comment: Any) -> Cursor[_DocumentType]: """Adds a 'comment' to the cursor. http://mongodb.com/docs/manual/reference/operator/comment/ @@ -984,7 +984,7 @@ class Cursor(Generic[_DocumentType]): self.__comment = comment return self - def where(self, code: Union[str, Code]) -> "Cursor[_DocumentType]": + def where(self, code: Union[str, Code]) -> Cursor[_DocumentType]: """Adds a `$where`_ clause to this query. The `code` argument must be an instance of :class:`str` or @@ -1027,7 +1027,7 @@ class Cursor(Generic[_DocumentType]): self.__spec = spec return self - def collation(self, collation: Optional[_CollationIn]) -> "Cursor[_DocumentType]": + def collation(self, collation: Optional[_CollationIn]) -> Cursor[_DocumentType]: """Adds a :class:`~pymongo.collation.Collation` to this query. Raises :exc:`TypeError` if `collation` is not an instance of @@ -1253,7 +1253,7 @@ class Cursor(Generic[_DocumentType]): return self.__session return None - def __iter__(self) -> "Cursor[_DocumentType]": + def __iter__(self) -> Cursor[_DocumentType]: return self def next(self) -> _DocumentType: @@ -1267,13 +1267,13 @@ class Cursor(Generic[_DocumentType]): __next__ = next - def __enter__(self) -> "Cursor[_DocumentType]": + def __enter__(self) -> Cursor[_DocumentType]: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def __copy__(self) -> "Cursor[_DocumentType]": + def __copy__(self) -> Cursor[_DocumentType]: """Support function for `copy.copy()`. .. versionadded:: 2.4 @@ -1320,15 +1320,15 @@ class Cursor(Generic[_DocumentType]): for key, value in iterator: if isinstance(value, (dict, list)) and not isinstance(value, SON): - value = self._deepcopy(value, memo) + value = self._deepcopy(value, memo) # noqa: PLW2901 elif not isinstance(value, RE_TYPE): - value = copy.deepcopy(value, memo) + value = copy.deepcopy(value, memo) # noqa: PLW2901 if is_list: y.append(value) # type: ignore[union-attr] else: if not isinstance(key, RE_TYPE): - key = copy.deepcopy(key, memo) + key = copy.deepcopy(key, memo) # noqa: PLW2901 y[key] = value return y diff --git a/pymongo/daemon.py b/pymongo/daemon.py index 974014149..b40384df1 100644 --- a/pymongo/daemon.py +++ b/pymongo/daemon.py @@ -63,7 +63,7 @@ if sys.platform == "win32": try: with open(os.devnull, "r+b") as devnull: popen = subprocess.Popen( - args, + args, # noqa: S603 creationflags=_DETACHED_PROCESS, stdin=devnull, stderr=devnull, @@ -94,7 +94,11 @@ else: try: with open(os.devnull, "r+b") as devnull: return subprocess.Popen( - args, close_fds=True, stdin=devnull, stderr=devnull, stdout=devnull + args, # noqa: S603 + close_fds=True, + stdin=devnull, + stderr=devnull, + stdout=devnull, ) except FileNotFoundError as exc: warnings.warn( @@ -108,7 +112,7 @@ else: """Spawn a daemon process using a double subprocess.Popen.""" spawner_args = [sys.executable, _THIS_FILE] spawner_args.extend(args) - temp_proc = subprocess.Popen(spawner_args, close_fds=True) + temp_proc = subprocess.Popen(spawner_args, close_fds=True) # noqa: S603 # Reap the intermediate child process to avoid creating zombie # processes. _popen_wait(temp_proc, _WAIT_TIMEOUT) diff --git a/pymongo/database.py b/pymongo/database.py index 73cf01d0a..70cdee2dc 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -74,7 +74,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def __init__( self, - client: "MongoClient[_DocumentType]", + client: MongoClient[_DocumentType], name: str, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, @@ -144,7 +144,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): self._timeout = client.options.timeout @property - def client(self) -> "MongoClient[_DocumentType]": + def client(self) -> MongoClient[_DocumentType]: """The client instance for this :class:`Database`.""" return self.__client @@ -224,12 +224,12 @@ class Database(common.BaseObject, Generic[_DocumentType]): """ if name.startswith("_"): raise AttributeError( - "Database has no attribute {!r}. To access the {}" - " collection, use database[{!r}].".format(name, name, name) + f"Database has no attribute {name!r}. To access the {name}" + f" collection, use database[{name!r}]." ) return self.__getitem__(name) - def __getitem__(self, name: str) -> "Collection[_DocumentType]": + def __getitem__(self, name: str) -> Collection[_DocumentType]: """Get a collection of this database by name. Raises InvalidName if an invalid collection name is used. @@ -791,7 +791,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, - codec_options: "Optional[bson.codec_options.CodecOptions[_CodecDocumentType]]" = None, + codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1012,7 +1012,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: _ServerMode, ) -> dict[str, Any]: @@ -1090,7 +1090,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], - server: Server, + _server: Server, conn: Connection, read_preference: _ServerMode, ) -> CommandCursor[MutableMapping[str, Any]]: @@ -1377,7 +1377,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): if dbref.database is not None and dbref.database != self.__name: raise ValueError( "trying to dereference a DBRef that points to " - "another database ({!r} not {!r})".format(dbref.database, self.__name) + f"another database ({dbref.database!r} not {self.__name!r})" ) return self[dbref.collection].find_one( {"_id": dbref.id}, session=session, comment=comment, **kwargs diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 7252d8f64..9e7cfbda3 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -31,13 +31,12 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])): def __new__( cls, name: str, version: Optional[str] = None, platform: Optional[str] = None - ) -> "DriverInfo": + ) -> DriverInfo: self = super().__new__(cls, name, version, platform) for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): raise TypeError( - "Wrong type for DriverInfo {} option, value " - "must be an instance of str".format(key) + f"Wrong type for DriverInfo {key} option, value must be an instance of str" ) return self diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 985f9132b..7a390e50f 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -34,7 +34,7 @@ from typing import ( try: from pymongocrypt.auto_encrypter import AutoEncrypter - from pymongocrypt.errors import MongoCryptError # noqa: F401 + from pymongocrypt.errors import MongoCryptError from pymongocrypt.explicit_encrypter import ExplicitEncrypter from pymongocrypt.mongocrypt import MongoCryptOptions from pymongocrypt.state_machine import MongoCryptCallback @@ -102,7 +102,7 @@ def _wrap_encryption_errors() -> Iterator[None]: # we should propagate them unchanged. raise except Exception as exc: - raise EncryptionError(exc) + raise EncryptionError(exc) from None class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] @@ -177,7 +177,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] raise OSError("KMS connection closed") kms_context.feed(data) except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") + raise socket.timeout("timed out") from None finally: conn.close() except (PyMongoError, MongoCryptError): @@ -414,8 +414,7 @@ class _Encrypter: with _wrap_encryption_errors(): encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) # TODO: PYTHON-1922 avoid decoding the encrypted_cmd. - encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) - return encrypt_cmd + return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) def decrypt(self, response: bytes) -> Optional[bytes]: """Decrypt a MongoDB command response. diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 6faaea2db..cd7d434b0 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -119,7 +119,7 @@ def _index_list( values: list[tuple[str, int]] = [] for item in key_or_list: if isinstance(item, str): - item = (item, ASCENDING) + item = (item, ASCENDING) # noqa: PLW2901 values.append(item) return values @@ -146,7 +146,7 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]: else: for item in index_list: if isinstance(item, str): - item = (item, ASCENDING) + item = (item, ASCENDING) # noqa: PLW2901 key, value = item _validate_index_key_pair(key, value) index[key] = value diff --git a/pymongo/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py index 10c136a43..72edf555b 100644 --- a/pymongo/max_staleness_selectors.py +++ b/pymongo/max_staleness_selectors.py @@ -67,7 +67,7 @@ def _with_primary(max_staleness: int, selection: Selection) -> Selection: for s in selection.server_descriptions: if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. - assert s.last_write_date and primary.last_write_date + assert s.last_write_date and primary.last_write_date # noqa: PT018 staleness = ( (s.last_update_time - s.last_write_date) - (primary.last_update_time - primary.last_write_date) @@ -95,7 +95,7 @@ def _no_primary(max_staleness: int, selection: Selection) -> Selection: for s in selection.server_descriptions: if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. - assert smax.last_write_date and s.last_write_date + assert smax.last_write_date and s.last_write_date # noqa: PT018 staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency if staleness <= max_staleness: diff --git a/pymongo/message.py b/pymongo/message.py index 5190ce23e..468574550 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -106,14 +106,14 @@ _OP_MAP = { } _FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} -_UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions( +_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( unicode_decode_error_handler="replace" ) def _randint() -> int: """Generate a pseudo random 32 bit integer.""" - return random.randint(MIN_INT32, MAX_INT32) + return random.randint(MIN_INT32, MAX_INT32) # noqa: S311 def _maybe_add_read_preference( @@ -731,7 +731,7 @@ def _op_msg_uncompressed( if _use_c: - _op_msg_uncompressed = _cmessage._op_msg # noqa: F811 + _op_msg_uncompressed = _cmessage._op_msg def _op_msg( @@ -833,7 +833,7 @@ def _query_uncompressed( if _use_c: - _query_uncompressed = _cmessage._query_message # noqa: F811 + _query_uncompressed = _cmessage._query_message def _query( @@ -889,7 +889,7 @@ def _get_more_uncompressed( if _use_c: - _get_more_uncompressed = _cmessage._get_more_message # noqa: F811 + _get_more_uncompressed = _cmessage._get_more_message def _get_more( @@ -942,7 +942,7 @@ class _BulkWriteContext: self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() if self.publish else None self.session = session - self.compress = True if conn.compression_context else False + self.compress = bool(conn.compression_context) self.op_type = op_type self.codec = codec @@ -1222,7 +1222,7 @@ def _batched_op_msg_impl( try: buf.write(_OP_MSG_MAP[operation]) except KeyError: - raise InvalidOperation("Unknown command") + raise InvalidOperation("Unknown command") from None to_send = [] idx = 0 @@ -1278,7 +1278,7 @@ def _encode_batched_op_msg( if _use_c: - _encode_batched_op_msg = _cmessage._encode_batched_op_msg # noqa: F811 + _encode_batched_op_msg = _cmessage._encode_batched_op_msg def _batched_op_msg_compressed( @@ -1328,7 +1328,7 @@ def _batched_op_msg( if _use_c: - _batched_op_msg = _cmessage._batched_op_msg # noqa: F811 + _batched_op_msg = _cmessage._batched_op_msg def _do_batched_op_msg( @@ -1371,7 +1371,7 @@ def _encode_batched_write_command( if _use_c: - _encode_batched_write_command = _cmessage._encode_batched_write_command # noqa: F811 + _encode_batched_write_command = _cmessage._encode_batched_write_command def _batched_write_command_impl( @@ -1410,7 +1410,7 @@ def _batched_write_command_impl( try: buf.write(_OP_MAP[operation]) except KeyError: - raise InvalidOperation("Unknown command") + raise InvalidOperation("Unknown command") from None # Where to write list document length list_start = buf.tell() - 4 @@ -1586,7 +1586,7 @@ class _OpMsg: def raw_response( self, cursor_id: Optional[int] = None, - user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006 + user_fields: Optional[Mapping[str, Any]] = {}, ) -> list[Mapping[str, Any]]: """ cursor_id is ignored diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 72ea671fe..422f8924f 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1297,7 +1297,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # We're running a getMore or this session is pinned to a mongos. server = topology.select_server_by_address(address) if not server: - raise AutoReconnect("server %s:%s no longer available" % address) + raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031 else: server = topology.select_server(server_selector) return server @@ -1380,7 +1380,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ) def _cmd( - session: Optional[ClientSession], + _session: Optional[ClientSession], server: Server, conn: Connection, read_preference: _ServerMode, @@ -1579,8 +1579,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ if name.startswith("_"): raise AttributeError( - "MongoClient has no attribute {!r}. To access the {}" - " database, use client[{!r}].".format(name, name, name) + f"MongoClient has no attribute {name!r}. To access the {name}" + f" database, use client[{name!r}]." ) return self.__getitem__(name) @@ -2132,7 +2132,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): write_concern=DEFAULT_WRITE_CONCERN, ) - def __enter__(self) -> "MongoClient[_DocumentType]": + def __enter__(self) -> MongoClient[_DocumentType]: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -2153,8 +2153,7 @@ def _retryable_error_doc(exc: PyMongoError) -> Optional[Mapping[str, Any]]: # Check the last writeConcernError to determine if this # BulkWriteError is retryable. wces = exc.details["writeConcernErrors"] - wce = wces[-1] if wces else None - return wce + return wces[-1] if wces else None if isinstance(exc, (NotPrimaryError, OperationFailure)): return cast(Mapping[str, Any], exc.details) return None diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index d8f370a12..03b3c5318 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -226,7 +226,7 @@ class CommandListener(_EventListener): and `CommandFailedEvent`. """ - def started(self, event: "CommandStartedEvent") -> None: + def started(self, event: CommandStartedEvent) -> None: """Abstract method to handle a `CommandStartedEvent`. :Parameters: @@ -234,7 +234,7 @@ class CommandListener(_EventListener): """ raise NotImplementedError - def succeeded(self, event: "CommandSucceededEvent") -> None: + def succeeded(self, event: CommandSucceededEvent) -> None: """Abstract method to handle a `CommandSucceededEvent`. :Parameters: @@ -242,7 +242,7 @@ class CommandListener(_EventListener): """ raise NotImplementedError - def failed(self, event: "CommandFailedEvent") -> None: + def failed(self, event: CommandFailedEvent) -> None: """Abstract method to handle a `CommandFailedEvent`. :Parameters: @@ -267,7 +267,7 @@ class ConnectionPoolListener(_EventListener): .. versionadded:: 3.9 """ - def pool_created(self, event: "PoolCreatedEvent") -> None: + def pool_created(self, event: PoolCreatedEvent) -> None: """Abstract method to handle a :class:`PoolCreatedEvent`. Emitted when a connection Pool is created. @@ -277,7 +277,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_ready(self, event: "PoolReadyEvent") -> None: + def pool_ready(self, event: PoolReadyEvent) -> None: """Abstract method to handle a :class:`PoolReadyEvent`. Emitted when a connection Pool is marked ready. @@ -289,7 +289,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_cleared(self, event: "PoolClearedEvent") -> None: + def pool_cleared(self, event: PoolClearedEvent) -> None: """Abstract method to handle a `PoolClearedEvent`. Emitted when a connection Pool is cleared. @@ -299,7 +299,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_closed(self, event: "PoolClosedEvent") -> None: + def pool_closed(self, event: PoolClosedEvent) -> None: """Abstract method to handle a `PoolClosedEvent`. Emitted when a connection Pool is closed. @@ -309,7 +309,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_created(self, event: "ConnectionCreatedEvent") -> None: + def connection_created(self, event: ConnectionCreatedEvent) -> None: """Abstract method to handle a :class:`ConnectionCreatedEvent`. Emitted when a connection Pool creates a Connection object. @@ -319,7 +319,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_ready(self, event: "ConnectionReadyEvent") -> None: + def connection_ready(self, event: ConnectionReadyEvent) -> None: """Abstract method to handle a :class:`ConnectionReadyEvent`. Emitted when a connection has finished its setup, and is now ready to @@ -330,7 +330,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_closed(self, event: "ConnectionClosedEvent") -> None: + def connection_closed(self, event: ConnectionClosedEvent) -> None: """Abstract method to handle a :class:`ConnectionClosedEvent`. Emitted when a connection Pool closes a connection. @@ -340,7 +340,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_check_out_started(self, event: "ConnectionCheckOutStartedEvent") -> None: + def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. Emitted when the driver starts attempting to check out a connection. @@ -350,7 +350,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> None: + def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. Emitted when the driver's attempt to check out a connection fails. @@ -360,7 +360,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None: + def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. Emitted when the driver successfully checks out a connection. @@ -370,7 +370,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None: + def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: """Abstract method to handle a :class:`ConnectionCheckedInEvent`. Emitted when the driver checks in a connection back to the connection @@ -391,7 +391,7 @@ class ServerHeartbeatListener(_EventListener): .. versionadded:: 3.3 """ - def started(self, event: "ServerHeartbeatStartedEvent") -> None: + def started(self, event: ServerHeartbeatStartedEvent) -> None: """Abstract method to handle a `ServerHeartbeatStartedEvent`. :Parameters: @@ -399,7 +399,7 @@ class ServerHeartbeatListener(_EventListener): """ raise NotImplementedError - def succeeded(self, event: "ServerHeartbeatSucceededEvent") -> None: + def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: """Abstract method to handle a `ServerHeartbeatSucceededEvent`. :Parameters: @@ -407,7 +407,7 @@ class ServerHeartbeatListener(_EventListener): """ raise NotImplementedError - def failed(self, event: "ServerHeartbeatFailedEvent") -> None: + def failed(self, event: ServerHeartbeatFailedEvent) -> None: """Abstract method to handle a `ServerHeartbeatFailedEvent`. :Parameters: @@ -424,7 +424,7 @@ class TopologyListener(_EventListener): .. versionadded:: 3.3 """ - def opened(self, event: "TopologyOpenedEvent") -> None: + def opened(self, event: TopologyOpenedEvent) -> None: """Abstract method to handle a `TopologyOpenedEvent`. :Parameters: @@ -432,7 +432,7 @@ class TopologyListener(_EventListener): """ raise NotImplementedError - def description_changed(self, event: "TopologyDescriptionChangedEvent") -> None: + def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: """Abstract method to handle a `TopologyDescriptionChangedEvent`. :Parameters: @@ -440,7 +440,7 @@ class TopologyListener(_EventListener): """ raise NotImplementedError - def closed(self, event: "TopologyClosedEvent") -> None: + def closed(self, event: TopologyClosedEvent) -> None: """Abstract method to handle a `TopologyClosedEvent`. :Parameters: @@ -457,7 +457,7 @@ class ServerListener(_EventListener): .. versionadded:: 3.3 """ - def opened(self, event: "ServerOpeningEvent") -> None: + def opened(self, event: ServerOpeningEvent) -> None: """Abstract method to handle a `ServerOpeningEvent`. :Parameters: @@ -465,7 +465,7 @@ class ServerListener(_EventListener): """ raise NotImplementedError - def description_changed(self, event: "ServerDescriptionChangedEvent") -> None: + def description_changed(self, event: ServerDescriptionChangedEvent) -> None: """Abstract method to handle a `ServerDescriptionChangedEvent`. :Parameters: @@ -473,7 +473,7 @@ class ServerListener(_EventListener): """ raise NotImplementedError - def closed(self, event: "ServerClosedEvent") -> None: + def closed(self, event: ServerClosedEvent) -> None: """Abstract method to handle a `ServerClosedEvent`. :Parameters: @@ -496,10 +496,10 @@ def _validate_event_listeners( for listener in listeners: if not isinstance(listener, _EventListener): raise TypeError( - "Listeners for {} must be either a " + f"Listeners for {option} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener.".format(option) + "ConnectionPoolListener." ) return listeners @@ -514,10 +514,10 @@ def register(listener: _EventListener) -> None: """ if not isinstance(listener, _EventListener): raise TypeError( - "Listeners for {} must be either a " + f"Listeners for {listener} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener.".format(listener) + "ConnectionPoolListener." ) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) @@ -1147,11 +1147,7 @@ class _ServerEvent: return self.__topology_id def __repr__(self) -> str: - return "<{} {} topology_id: {}>".format( - self.__class__.__name__, - self.server_address, - self.topology_id, - ) + return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" class ServerDescriptionChangedEvent(_ServerEvent): @@ -1216,7 +1212,7 @@ class ServerClosedEvent(_ServerEvent): class TopologyEvent: """Base class for topology description events.""" - __slots__ = "__topology_id" + __slots__ = ("__topology_id",) def __init__(self, topology_id: ObjectId) -> None: self.__topology_id = topology_id diff --git a/pymongo/network.py b/pymongo/network.py index 14160a516..fb4388121 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -265,8 +265,8 @@ def receive_message( ) if length > max_message_size: raise ProtocolError( - "Message length ({!r}) is larger than server max " - "message size ({!r})".format(length, max_message_size) + f"Message length ({length!r}) is larger than server max " + f"message size ({max_message_size!r})" ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( @@ -279,7 +279,9 @@ def receive_message( try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: - raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}") + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None return unpack_reply(data) @@ -337,8 +339,8 @@ def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[fl conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) chunk_length = conn.conn.recv_into(mv[bytes_read:]) except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") - except OSError as exc: # noqa: B014 + raise socket.timeout("timed out") from None + except OSError as exc: if _errno_from_exception(exc) == errno.EINTR: continue raise diff --git a/pymongo/ocsp_support.py b/pymongo/ocsp_support.py index ccc03ae67..1bda3b4d7 100644 --- a/pymongo/ocsp_support.py +++ b/pymongo/ocsp_support.py @@ -180,7 +180,7 @@ def _public_key_hash(cert: Certificate) -> bytes: pbytes = public_key.public_bytes(_Encoding.X962, _PublicFormat.UncompressedPoint) else: pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.SubjectPublicKeyInfo) - digest = _Hash(_SHA1(), backend=_default_backend()) + digest = _Hash(_SHA1(), backend=_default_backend()) # noqa: S303 digest.update(pbytes) return digest.finalize() @@ -262,7 +262,7 @@ def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> i def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest: # https://cryptography.io/en/latest/x509/ocsp/#creating-requests builder = _OCSPRequestBuilder() - builder = builder.add_certificate(cert, issuer, _SHA1()) + builder = builder.add_certificate(cert, issuer, _SHA1()) # noqa: S303 return builder.build() diff --git a/pymongo/operations.py b/pymongo/operations.py index adaca5707..2a41575a3 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -547,7 +547,7 @@ class IndexModel: class SearchIndexModel: """Represents a search index to create.""" - __slots__ = "__document" + __slots__ = ("__document",) def __init__(self, definition: Mapping[str, Any], name: Optional[str] = None) -> None: """Create a Search Index instance. diff --git a/pymongo/pool.py b/pymongo/pool.py index afe3a4313..d4248722e 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -116,7 +116,7 @@ except ImportError: # Windows, various platforms we don't claim to support # (Jython, IronPython, ...), systems that don't provide # everything we need from fcntl, etc. - def _set_non_inheritable_non_atomic(fd: int) -> None: + def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 """Dummy function for platforms that don't provide fcntl.""" @@ -1076,7 +1076,7 @@ class Connection: # shutdown. try: self.conn.close() - except Exception: + except Exception: # noqa: S110 pass def conn_closed(self) -> bool: @@ -1250,7 +1250,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. # Raise _CertificateError directly like we do after match_hostname # below. raise - except (OSError, SSLError) as exc: # noqa: B014 + except (OSError, SSLError) as exc: sock.close() # We raise AutoReconnect for transient and permanent SSL handshake # failures alike. Permanent handshake failures, like protocol @@ -1811,7 +1811,7 @@ class Pool: return True if self._check_interval_seconds is not None and ( - 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds + self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds ): if conn.conn_closed(): conn.close_conn(ConnectionClosedReason.ERROR) @@ -1847,7 +1847,7 @@ class Pool: ) raise WaitQueueTimeoutError( "Timed out while checking out a connection from connection pool. " - "maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout) + f"maxPoolSize: {self.opts.max_pool_size}, timeout: {timeout}" ) def __del__(self) -> None: diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 4bef96ec7..6657937e9 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -86,7 +86,7 @@ def _is_ip_address(address: Any) -> bool: try: _ip_address(address) return True - except (ValueError, UnicodeError): # noqa: B014 + except (ValueError, UnicodeError): return False @@ -122,8 +122,8 @@ class _sslConn(_SSL.Connection): # Check for closed socket. if self.fileno() == -1: if timeout and _time.monotonic() - start > timeout: - raise _socket.timeout("timed out") - raise SSLError("Underlying socket has been closed") + raise _socket.timeout("timed out") from None + raise SSLError("Underlying socket has been closed") from None if isinstance(exc, _SSL.WantReadError): want_read = True want_write = False @@ -135,7 +135,7 @@ class _sslConn(_SSL.Connection): want_write = True self.socket_checker.select(self, want_read, want_write, timeout) if timeout and _time.monotonic() - start > timeout: - raise _socket.timeout("timed out") + raise _socket.timeout("timed out") from None continue def do_handshake(self, *args: Any, **kwargs: Any) -> None: @@ -169,7 +169,7 @@ class _sslConn(_SSL.Connection): # XXX: It's not clear if this can actually happen. PyOpenSSL # doesn't appear to have any interrupt handling, nor any interrupt # errors for OpenSSL connections. - except OSError as exc: # noqa: B014 + except OSError as exc: if _errno_from_exception(exc) == _EINTR: continue raise @@ -226,10 +226,10 @@ class SSLContext: """Setter for verify_mode.""" def _cb( - connobj: _SSL.Connection, - x509obj: _crypto.X509, - errnum: int, - errdepth: int, + _connobj: _SSL.Connection, + _x509obj: _crypto.X509, + _errnum: int, + _errdepth: int, retcode: int, ) -> bool: # It seems we don't need to do anything here. Twisted doesn't, @@ -295,7 +295,7 @@ class SSLContext: # Password callback MUST be set first or it will be ignored. if password: - def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes: + def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: bytes) -> bytes: # XXX:We could check the password length against what OpenSSL # tells us is the max, but we can't raise an exception, so... # warn? @@ -410,5 +410,5 @@ class SSLContext: else: _verify_hostname(ssl_conn, server_hostname) except (_SICertificateError, _SIVerificationError) as exc: - raise _CertificateError(str(exc)) + raise _CertificateError(str(exc)) from None return ssl_conn diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index f731fdfaf..986cc772b 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -64,9 +64,9 @@ def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: for tags in tag_sets: if not isinstance(tags, abc.Mapping): raise TypeError( - "Tag set {!r} invalid, must be an instance of dict, " + f"Tag set {tags!r} invalid, must be an instance of dict, " "bson.son.SON or other type that inherits from " - "collection.Mapping".format(tags) + "collection.Mapping" ) return list(tag_sets) diff --git a/pymongo/results.py b/pymongo/results.py index ab830c161..266101921 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -32,10 +32,10 @@ class _WriteResult: """Raise an exception on property access if unacknowledged.""" if not self.__acknowledged: raise InvalidOperation( - "A value for {} is not available when " + f"A value for {property_name} is not available when " "the write is unacknowledged. Check the " "acknowledged attribute to avoid this " - "error.".format(property_name) + "error." ) @property diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index 8f9251634..02c845079 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -22,7 +22,9 @@ try: except ImportError: HAVE_STRINGPREP = False - def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> Any: + def saslprep( + data: Any, prohibit_unassigned_code_points: Optional[bool] = True # noqa: ARG001 + ) -> Any: """SASLprep dummy""" if isinstance(data, str): raise TypeError( diff --git a/pymongo/server_api.py b/pymongo/server_api.py index 47812818d..90505bc5a 100644 --- a/pymongo/server_api.py +++ b/pymongo/server_api.py @@ -122,12 +122,12 @@ class ServerApi: if strict is not None and not isinstance(strict, bool): raise TypeError( "Wrong type for ServerApi strict, value must be an instance " - "of bool, not {}".format(type(strict)) + f"of bool, not {type(strict)}" ) if deprecation_errors is not None and not isinstance(deprecation_errors, bool): raise TypeError( "Wrong type for ServerApi deprecation_errors, value must be " - "an instance of bool, not {}".format(type(deprecation_errors)) + f"an instance of bool, not {type(deprecation_errors)}" ) self._version = version self._strict = strict diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 490970581..3b4131f32 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -257,7 +257,7 @@ class ServerDescription: def topology_version(self) -> Optional[Mapping[str, Any]]: return self._topology_version - def to_unknown(self, error: Optional[Exception] = None) -> "ServerDescription": + def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: unknown = ServerDescription(self.address, error=error) unknown._topology_version = self.topology_version return unknown diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index afde6bba6..76c8b5161 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -75,7 +75,7 @@ class _SrvResolver: try: self.__plist = self.__fqdn.split(".")[1:] except Exception: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None self.__slen = len(self.__plist) if self.__slen < 2: raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) @@ -87,7 +87,7 @@ class _SrvResolver: # No TXT records return None except Exception as exc: - raise ConfigurationError(str(exc)) + raise ConfigurationError(str(exc)) from None if len(results) > 1: raise ConfigurationError("Only one TXT record is supported") return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") @@ -102,7 +102,7 @@ class _SrvResolver: # Raise the original error. raise # Else, raise all errors as ConfigurationError. - raise ConfigurationError(str(exc)) + raise ConfigurationError(str(exc)) from None return results def _get_srv_response_and_hosts( @@ -120,7 +120,7 @@ class _SrvResolver: try: nlist = node[0].lower().split(".")[1:][-self.__slen :] except Exception: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None if self.__plist != nlist: raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 4c5e46dfd..3c9ee01ef 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -35,7 +35,7 @@ if HAVE_SSL: # CPython ssl module constants to configure certificate verification # at a high level. This is legacy behavior, but requires us to # import the ssl module even if we're only using it for this purpose. - import ssl as _stdlibssl # noqa + import ssl as _stdlibssl # noqa: F401 from ssl import CERT_NONE, CERT_REQUIRED HAS_SNI = _ssl.HAS_SNI @@ -74,12 +74,14 @@ if HAVE_SSL: try: ctx.load_cert_chain(certfile, None, passphrase) except _ssl.SSLError as exc: - raise ConfigurationError(f"Private key doesn't match certificate: {exc}") + raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None if crlfile is not None: if _ssl.IS_PYOPENSSL: raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL") # Match the server's behavior. - setattr(ctx, "verify_flags", getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0)) # noqa + ctx.verify_flags = getattr( # type:ignore[attr-defined] + _ssl, "VERIFY_CRL_CHECK_LEAF", 0 + ) ctx.load_verify_locations(crlfile) if ca_certs is not None: ctx.load_verify_locations(ca_certs) diff --git a/pymongo/topology.py b/pymongo/topology.py index 45e018c3b..786be3ec9 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -186,7 +186,8 @@ class Topology: "MongoClient opened before fork. May not be entirely fork-safe, " "proceed with caution. See PyMongo's documentation for details: " "https://pymongo.readthedocs.io/en/stable/faq.html#" - "is-pymongo-fork-safe" + "is-pymongo-fork-safe", + stacklevel=2, ) with self._lock: # Close servers and clear the pools. diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 9f9855714..141f74edf 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -168,12 +168,12 @@ class TopologyDescription: def has_server(self, address: _Address) -> bool: return address in self._server_descriptions - def reset_server(self, address: _Address) -> "TopologyDescription": + def reset_server(self, address: _Address) -> TopologyDescription: """A copy of this description, with one server marked Unknown.""" unknown_sd = self._server_descriptions[address].to_unknown() return updated_topology_description(self, unknown_sd) - def reset(self) -> "TopologyDescription": + def reset(self) -> TopologyDescription: """A copy of this description, with all servers marked Unknown.""" if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary @@ -381,7 +381,7 @@ _SERVER_TYPE_TO_TOPOLOGY_TYPE = { def updated_topology_description( topology_description: TopologyDescription, server_description: ServerDescription -) -> "TopologyDescription": +) -> TopologyDescription: """Return an updated copy of a TopologyDescription. :Parameters: @@ -672,5 +672,5 @@ def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: for s in sds.values(): if s.server_type == SERVER_TYPE.RSPrimary: return TOPOLOGY_TYPE.ReplicaSetWithPrimary - else: + else: # noqa: PLW0120 return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index ab86e11bf..2f740ea2d 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -178,7 +178,7 @@ def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionar options.setdefault(key, []).append(value) else: if key in options: - warnings.warn(f"Duplicate URI option '{key}'.") + warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) if key.lower() == "authmechanismproperties": val = value else: @@ -350,7 +350,7 @@ def split_options( else: raise ValueError except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") + raise InvalidURI("MongoDB URI options are key=value pairs.") from None options = _handle_security_options(options) @@ -598,14 +598,14 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict if not isinstance(kms_tls_options, dict): raise TypeError("kms_tls_options must be a dict") contexts = {} - for provider, opts in kms_tls_options.items(): - if not isinstance(opts, dict): + for provider, options in kms_tls_options.items(): + if not isinstance(options, dict): raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - opts.setdefault("tls", True) - opts = _CaseInsensitiveDictionary(opts) + options.setdefault("tls", True) + opts = _CaseInsensitiveDictionary(options) opts = _handle_security_options(opts) opts = _normalize_options(opts) - opts = validate_options(opts) + opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) if ssl_context is None: raise ConfigurationError("TLS is required for KMS providers") @@ -628,7 +628,7 @@ if __name__ == "__main__": import pprint try: - pprint.pprint(parse_uri(sys.argv[1])) + pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 except InvalidURI as exc: - print(exc) + print(exc) # noqa: T201 sys.exit(0) diff --git a/pyproject.toml b/pyproject.toml index 66afea4c6..c4af032eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,3 +95,71 @@ include = ["bson","gridfs", "pymongo"] bson=["py.typed", "*.pyi"] pymongo=["py.typed", "*.pyi"] gridfs=["py.typed", "*.pyi"] + +[tool.ruff] +target-version = "py37" +line-length = 100 +select = [ + "E", "F", "W", # flake8 + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "RET", # flake8-return + "RUF", # Ruff-specific + "S", # flake8-bandit + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable +] +extend-ignore = [ + "PLR", # Design related pylint codes + "E501", # Line too long + "PT004", # Use underscore for non-returning fixture (use usefixture instead) + "UP007", # Use `X | Y` for type annotation + "EM101", # Exception must not use a string literal, assign to variable first + "EM102", # Exception must not use an f-string literal, assign to variable first + "G004", # Logging statement uses f-string" + "UP006", # Use `type` instead of `Type` for type annotation" + "RET505", # Unnecessary `elif` after `return` statement" + "RET506", # Unnecessary `elif` after `raise` statement + "SIM108", # Use ternary operator" + "PTH123", # `open()` should be replaced by `Path.open()`" + "SIM102", # Use a single `if` statement instead of nested `if` statements + "SIM105", # Use `contextlib.suppress(OSError)` instead of `try`-`except`-`pass` + "ARG002", # Unused method argument: + "S101", # Use of `assert` detected + "SIM114", # Combine `if` branches using logical `or` operator + "PGH003", # Use specific rule codes when ignoring type issues + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + "EM103", # Exception must not use a `.format()` string directly, assign to variable first + "C408", # Unnecessary `dict` call (rewrite as a literal) + "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements +] +unfixable = [ + "RUF100", # Unused noqa + "T20", # Removes print statements + "F841", # Removes unused variables +] +exclude = [] +flake8-unused-arguments.ignore-variadic-names = true +isort.required-imports = ["from __future__ import annotations"] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?)|dummy.*)$" + +[tool.ruff.per-file-ignores] +"pymongo/__init__.py" = ["E402"] +"test/*.py" = ["PT", "E402", "PLW", "SIM", "E741", "PTH", "S", "B904", "E722", "T201", + "RET", "ARG", "F405", "B028", "PGH001", "B018", "F403", "RUF015", "E731", "B007", + "UP031", "F401", "B023", "F811"] +"green_framework_test.py" = ["T201"] diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 144ce9850..c8f0f712d --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys import warnings @@ -75,7 +77,8 @@ https://pymongo.readthedocs.io/en/stable/installation.html#osx % ( "Extension modules", "There was an issue with your platform configuration - see above.", - ) + ), + stacklevel=2, ) def build_extension(self, ext): @@ -90,9 +93,10 @@ https://pymongo.readthedocs.io/en/stable/installation.html#osx warnings.warn( self.warning_message % ( - "The %s extension module" % (name,), + "The %s extension module" % (name,), # noqa: UP031 "The output above this warning shows how the compilation failed.", - ) + ), + stacklevel=2, ) diff --git a/test/__init__.py b/test/__init__.py index 5abca5a18..cea27c01f 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """Test suite for pymongo, bson, and gridfs.""" +from __future__ import annotations import base64 import gc @@ -29,7 +30,7 @@ import unittest import warnings try: - import ipaddress # noqa + import ipaddress HAVE_IPADDRESS = True except ImportError: @@ -795,11 +796,12 @@ class ClientContext: return True return False - def require_cluster_type(self, topologies=[]): # noqa + def require_cluster_type(self, topologies=None): """Run a test only if the client is connected to a cluster that conforms to one of the specified topologies. Acceptable topologies are 'single', 'replicaset', and 'sharded'. """ + topologies = topologies or [] def _is_valid_topology(): return self.is_topology_type(topologies) @@ -1169,9 +1171,9 @@ def print_running_topology(topology): if running: print( "WARNING: found Topology with running threads:\n" - " Threads: {}\n" - " Topology: {}\n" - " Creation traceback:\n{}".format(running, topology, topology._settings._stack) + f" Threads: {running}\n" + f" Topology: {topology}\n" + f" Creation traceback:\n{topology._settings._stack}" ) @@ -1238,7 +1240,7 @@ def clear_warning_registry(): """Clear the __warningregistry__ for all modules.""" for _, module in list(sys.modules.items()): if hasattr(module, "__warningregistry__"): - setattr(module, "__warningregistry__", {}) # noqa + module.__warningregistry__ = {} # type:ignore[attr-defined] class SystemCertsPatcher: diff --git a/test/atlas/test_connection.py b/test/atlas/test_connection.py index 036e4772f..2c45241ea 100644 --- a/test/atlas/test_connection.py +++ b/test/atlas/test_connection.py @@ -13,6 +13,7 @@ # limitations under the License. """Test connections to various Atlas cluster types.""" +from __future__ import annotations import os import sys diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index e180d8b06..d0bb41b73 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -13,6 +13,7 @@ # limitations under the License. """Test MONGODB-AWS Authentication.""" +from __future__ import annotations import os import sys diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index b1315455f..29de512da 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -13,6 +13,7 @@ # limitations under the License. """Test MONGODB-OIDC Authentication.""" +from __future__ import annotations import os import sys diff --git a/test/conftest.py b/test/conftest.py index 400fd9ed7..b65c64186 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from test import setup, teardown import pytest diff --git a/test/crud_v2_format.py b/test/crud_v2_format.py index f711a125c..8eadad843 100644 --- a/test/crud_v2_format.py +++ b/test/crud_v2_format.py @@ -16,6 +16,7 @@ https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.rst """ +from __future__ import annotations from test.utils_spec_runner import SpecRunner diff --git a/test/lambda/mongodb/app.py b/test/lambda/mongodb/app.py index d56fbec3a..65e6dc88f 100644 --- a/test/lambda/mongodb/app.py +++ b/test/lambda/mongodb/app.py @@ -4,6 +4,8 @@ Lambda function for Python Driver testing Creates the client that is cached for all requests, subscribes to relevant events, and forces the connection pool to get populated. """ +from __future__ import annotations + import json import os diff --git a/test/mockupdb/operations.py b/test/mockupdb/operations.py index 692f9aef0..34302aa55 100644 --- a/test/mockupdb/operations.py +++ b/test/mockupdb/operations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from collections import namedtuple diff --git a/test/mockupdb/test_auth_recovering_member.py b/test/mockupdb/test_auth_recovering_member.py old mode 100755 new mode 100644 index 33d33da24..2051a24af --- a/test/mockupdb/test_auth_recovering_member.py +++ b/test/mockupdb/test_auth_recovering_member.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_cluster_time.py b/test/mockupdb/test_cluster_time.py index e4f3e12d0..a64804541 100644 --- a/test/mockupdb/test_cluster_time.py +++ b/test/mockupdb/test_cluster_time.py @@ -13,6 +13,7 @@ # limitations under the License. """Test $clusterTime handling.""" +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_cursor_namespace.py b/test/mockupdb/test_cursor_namespace.py index 10788ac0f..e6713abf1 100644 --- a/test/mockupdb/test_cursor_namespace.py +++ b/test/mockupdb/test_cursor_namespace.py @@ -13,6 +13,7 @@ # limitations under the License. """Test list_indexes with more than one batch.""" +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_getmore_sharded.py b/test/mockupdb/test_getmore_sharded.py index 5f5400ab0..b06b48f01 100644 --- a/test/mockupdb/test_getmore_sharded.py +++ b/test/mockupdb/test_getmore_sharded.py @@ -13,6 +13,8 @@ # limitations under the License. """Test PyMongo cursor with a sharded cluster.""" +from __future__ import annotations + import unittest from queue import Queue diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 3d002cbbf..00cae32ee 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import unittest @@ -63,7 +64,7 @@ class TestHandshake(unittest.TestCase): client = MongoClient( "mongodb://" + primary.address_string, appname="my app", # For _check_handshake_data() - **dict([k_map.get((k, v), (k, v)) for k, v in kwargs.items()]) # type: ignore[arg-type] + **dict([k_map.get((k, v), (k, v)) for k, v in kwargs.items()]), # type: ignore[arg-type] ) self.addCleanup(client.close) @@ -236,14 +237,12 @@ class TestHandshake(unittest.TestCase): request.reply( OpMsgReply( **primary_response, - **{ - "payload": b"r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0" - b"1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky" - b"tXdF9r," - b"s=4dcxugMJq2P4hQaDbGXZR8uR3ei" - b"PHrSmh4uhkg==,i=15000", - "saslSupportedMechs": ["SCRAM-SHA-1"], - } + payload=b"r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0" + b"1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky" + b"tXdF9r," + b"s=4dcxugMJq2P4hQaDbGXZR8uR3ei" + b"PHrSmh4uhkg==,i=15000", + saslSupportedMechs=["SCRAM-SHA-1"], ) ) return None diff --git a/test/mockupdb/test_initial_ismaster.py b/test/mockupdb/test_initial_ismaster.py index 155ae6152..97864dd25 100644 --- a/test/mockupdb/test_initial_ismaster.py +++ b/test/mockupdb/test_initial_ismaster.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import time import unittest diff --git a/test/mockupdb/test_list_indexes.py b/test/mockupdb/test_list_indexes.py index 20764e6e5..163c25c37 100644 --- a/test/mockupdb/test_list_indexes.py +++ b/test/mockupdb/test_list_indexes.py @@ -13,6 +13,7 @@ # limitations under the License. """Test list_indexes with more than one batch.""" +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_max_staleness.py b/test/mockupdb/test_max_staleness.py index 02efb6a71..88a3c13e6 100644 --- a/test/mockupdb/test_max_staleness.py +++ b/test/mockupdb/test_max_staleness.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_mixed_version_sharded.py b/test/mockupdb/test_mixed_version_sharded.py index 7813069c9..515c27987 100644 --- a/test/mockupdb/test_mixed_version_sharded.py +++ b/test/mockupdb/test_mixed_version_sharded.py @@ -13,6 +13,7 @@ # limitations under the License. """Test PyMongo with a mixed-version cluster.""" +from __future__ import annotations import time import unittest diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py index 62bd76cf0..dff1288e6 100644 --- a/test/mockupdb/test_mongos_command_read_mode.py +++ b/test/mockupdb/test_mongos_command_read_mode.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import itertools import unittest diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py old mode 100755 new mode 100644 index 936130484..d05cfb531 --- a/test/mockupdb/test_network_disconnect_primary.py +++ b/test/mockupdb/test_network_disconnect_primary.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py old mode 100755 new mode 100644 index e8542e2fe..dd9525496 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import unittest from collections import namedtuple diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index a3aef1541..0fa7b8486 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import copy import itertools diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 7ad4f2afc..529770988 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -13,6 +13,7 @@ # limitations under the License. """Test PyMongo query and read preference with a sharded cluster.""" +from __future__ import annotations import unittest diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py old mode 100755 new mode 100644 index c55449937..12c0bec9a --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import itertools import time diff --git a/test/mockupdb/test_rsghost.py b/test/mockupdb/test_rsghost.py index 354399728..c9f2b89f0 100644 --- a/test/mockupdb/test_rsghost.py +++ b/test/mockupdb/test_rsghost.py @@ -13,6 +13,7 @@ # limitations under the License. """Test connections to RSGhost nodes.""" +from __future__ import annotations import datetime import unittest diff --git a/test/mockupdb/test_slave_okay_rs.py b/test/mockupdb/test_slave_okay_rs.py index 225d8e407..ba5b976d6 100644 --- a/test/mockupdb/test_slave_okay_rs.py +++ b/test/mockupdb/test_slave_okay_rs.py @@ -16,6 +16,7 @@ Just make sure SlaveOkay is *not* set on primary reads. """ +from __future__ import annotations import unittest @@ -50,7 +51,7 @@ class TestSlaveOkayRS(unittest.TestCase): def create_slave_ok_rs_test(operation): def test(self): self.setup_server() - assert not operation.op_type == "always-use-secondary" + assert operation.op_type != "always-use-secondary" client = MongoClient(self.primary.uri, replicaSet="rs") self.addCleanup(client.close) diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py index 5a590bcf1..45b7d51ba 100644 --- a/test/mockupdb/test_slave_okay_sharded.py +++ b/test/mockupdb/test_slave_okay_sharded.py @@ -18,6 +18,8 @@ - A direct connection to a slave. - A direct connection to a mongos. """ +from __future__ import annotations + import itertools import unittest from queue import Queue @@ -43,10 +45,7 @@ class TestSlaveOkaySharded(unittest.TestCase): "ismaster", minWireVersion=2, maxWireVersion=6, ismaster=True, msg="isdbgrid" ) - self.mongoses_uri = "mongodb://{},{}".format( - self.mongos1.address_string, - self.mongos2.address_string, - ) + self.mongoses_uri = f"mongodb://{self.mongos1.address_string},{self.mongos2.address_string}" def create_slave_ok_sharded_test(mode, operation): diff --git a/test/mockupdb/test_slave_okay_single.py b/test/mockupdb/test_slave_okay_single.py index 90b99df49..b03232807 100644 --- a/test/mockupdb/test_slave_okay_single.py +++ b/test/mockupdb/test_slave_okay_single.py @@ -18,6 +18,7 @@ - A direct connection to a slave. - A direct connection to a mongos. """ +from __future__ import annotations import itertools import unittest diff --git a/test/mod_wsgi_test/mod_wsgi_test.py b/test/mod_wsgi_test/mod_wsgi_test.py index 77b5475ab..c5f5c3086 100644 --- a/test/mod_wsgi_test/mod_wsgi_test.py +++ b/test/mod_wsgi_test/mod_wsgi_test.py @@ -14,6 +14,7 @@ """Minimal test of PyMongo in a WSGI application, see bug PYTHON-353 """ +from __future__ import annotations import datetime import os @@ -42,10 +43,10 @@ from pymongo.mongo_client import MongoClient assert bson.has_c() assert pymongo.has_c() -OPTS: "CodecOptions[dict]" = CodecOptions( +OPTS: CodecOptions[dict] = CodecOptions( uuid_representation=STANDARD, datetime_conversion=DatetimeConversion.DATETIME_AUTO ) -client: "MongoClient[dict]" = MongoClient() +client: MongoClient[dict] = MongoClient() # Use a unique collection name for each process: coll_name = f"test-{uuid.uuid4()}" collection = client.test.get_collection(coll_name, codec_options=OPTS) diff --git a/test/mod_wsgi_test/test_client.py b/test/mod_wsgi_test/test_client.py index 9f6f59a7f..63ae88347 100644 --- a/test/mod_wsgi_test/test_client.py +++ b/test/mod_wsgi_test/test_client.py @@ -13,6 +13,7 @@ # limitations under the License. """Test client for mod_wsgi application, see bug PYTHON-353.""" +from __future__ import annotations import _thread as thread import random diff --git a/test/mypy_fails/insert_many_dict.py b/test/mypy_fails/insert_many_dict.py index 7cbabc28f..5f9a2d45a 100644 --- a/test/mypy_fails/insert_many_dict.py +++ b/test/mypy_fails/insert_many_dict.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pymongo import MongoClient client: MongoClient = MongoClient() diff --git a/test/mypy_fails/insert_one_list.py b/test/mypy_fails/insert_one_list.py index 12079ffc6..7c27d5cac 100644 --- a/test/mypy_fails/insert_one_list.py +++ b/test/mypy_fails/insert_one_list.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pymongo import MongoClient client: MongoClient = MongoClient() diff --git a/test/mypy_fails/raw_bson_document.py b/test/mypy_fails/raw_bson_document.py index 0e1722487..49f3659e9 100644 --- a/test/mypy_fails/raw_bson_document.py +++ b/test/mypy_fails/raw_bson_document.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from bson.raw_bson import RawBSONDocument from pymongo import MongoClient diff --git a/test/mypy_fails/typedict_client.py b/test/mypy_fails/typedict_client.py index 6619df10f..37c3f0bfc 100644 --- a/test/mypy_fails/typedict_client.py +++ b/test/mypy_fails/typedict_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TypedDict from pymongo import MongoClient diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index dc2650499..de2714cc0 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -13,6 +13,7 @@ # limitations under the License. """Test OCSP.""" +from __future__ import annotations import logging import os diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 062058e09..2ad4edaf8 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the MongoDB Driver Performance Benchmarking Spec.""" +from __future__ import annotations import multiprocessing as mp import os diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index f804f381b..2b291c7bd 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -13,6 +13,7 @@ # limitations under the License. """Tools for mocking parts of PyMongo to test other parts.""" +from __future__ import annotations import contextlib import weakref diff --git a/test/qcheck.py b/test/qcheck.py index 52e4c46b8..739d4948e 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import datetime import random diff --git a/test/sigstop_sigcont.py b/test/sigstop_sigcont.py index 6f84b6a6a..95a36ad7a 100644 --- a/test/sigstop_sigcont.py +++ b/test/sigstop_sigcont.py @@ -13,6 +13,7 @@ # limitations under the License. """Used by test_client.TestClient.test_sigstop_sigcont.""" +from __future__ import annotations import logging import os diff --git a/test/test_auth.py b/test/test_auth.py index 160e718c0..2240a4b5b 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -13,6 +13,7 @@ # limitations under the License. """Authentication Tests.""" +from __future__ import annotations import os import sys @@ -179,7 +180,7 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].list_collection_names() - uri = uri + f"&replicaSet={str(set_name)}" + uri = uri + f"&replicaSet={set_name!s}" client = MongoClient(uri) client[GSSAPI_DB].list_collection_names() @@ -196,7 +197,7 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].list_collection_names() - mech_uri = mech_uri + f"&replicaSet={str(set_name)}" + mech_uri = mech_uri + f"&replicaSet={set_name!s}" client = MongoClient(mech_uri) client[GSSAPI_DB].list_collection_names() diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 6a8ec35ec..4976a6dd4 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the auth spec tests.""" +from __future__ import annotations import glob import json diff --git a/test/test_binary.py b/test/test_binary.py index 158a99029..fafb6da16 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the Binary wrapper.""" +from __future__ import annotations import array import base64 diff --git a/test/test_bson.py b/test/test_bson.py index 9849e8b6d..749c63bdf 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -14,6 +14,7 @@ # limitations under the License. """Test the bson module.""" +from __future__ import annotations import array import collections diff --git a/test/test_bson_corpus.py b/test/test_bson_corpus.py index 193a6dff3..96ef458ec 100644 --- a/test/test_bson_corpus.py +++ b/test/test_bson_corpus.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the BSON corpus specification tests.""" +from __future__ import annotations import binascii import codecs diff --git a/test/test_bulk.py b/test/test_bulk.py index 6a2af3143..6619d33f4 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the bulk API.""" +from __future__ import annotations import sys import uuid @@ -830,7 +831,7 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') def test_no_results_ordered_failure(self): @@ -845,7 +846,7 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: 3 == self.coll.count_documents({}), "insert 3 documents") + wait_until(lambda: self.coll.count_documents({}) == 3, "insert 3 documents") self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1})) def test_no_results_unordered_success(self): @@ -857,7 +858,7 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') def test_no_results_unordered_failure(self): @@ -872,7 +873,7 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') diff --git a/test/test_change_stream.py b/test/test_change_stream.py index a681de4b4..515ee436c 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the change_stream module.""" +from __future__ import annotations import os import random diff --git a/test/test_client.py b/test/test_client.py index b8f0765ea..c929b7525 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the mongo_client module.""" +from __future__ import annotations import _thread as thread import contextlib @@ -558,7 +559,7 @@ class TestClient(IntegrationTest): # connections could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: 1 <= len(server._pool.conns), "replace stale socket") + wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): @@ -572,7 +573,7 @@ class TestClient(IntegrationTest): # maxPoolSize=1 should prevent two connections from being created. self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: 1 == len(server._pool.conns), "replace stale socket") + wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") client.close() def test_max_idle_time_reaper_removes_stale(self): @@ -588,7 +589,7 @@ class TestClient(IntegrationTest): pass self.assertIs(conn_one, conn_two) wait_until( - lambda: 0 == len(server._pool.conns), + lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) client.close() @@ -603,14 +604,14 @@ class TestClient(IntegrationTest): client = rs_or_single_client(minPoolSize=10) server = client._get_topology().select_server(readable_server_selector) wait_until( - lambda: 10 == len(server._pool.conns), "pool initialized with 10 connections" + lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections" ) # Assert that if a socket is closed, a new one takes its place with server._pool.checkout() as conn: conn.close_conn(None) wait_until( - lambda: 10 == len(server._pool.conns), + lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) self.assertFalse(conn in server._pool.conns) @@ -745,7 +746,7 @@ class TestClient(IntegrationTest): def test_repr(self): # Used to test 'eval' below. - import bson # noqa: F401 + import bson client = MongoClient( # type: ignore[type-var] "mongodb://localhost:27017,localhost:27018/?replicaSet=replset" diff --git a/test/test_client_context.py b/test/test_client_context.py index 72da8dbc3..196647cb0 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import sys @@ -40,9 +41,7 @@ class TestClientContext(unittest.TestCase): self.assertTrue( client_context.connected and client_context.serverless, "client context must be connected to serverless when " - "TEST_SERVERLESS is set. Failed attempts:\n{}".format( - client_context.connection_attempt_info() - ), + f"TEST_SERVERLESS is set. Failed attempts:\n{client_context.connection_attempt_info()}", ) def test_enableTestCommands_is_disabled(self): diff --git a/test/test_cmap.py b/test/test_cmap.py index c0e55dbe8..59757434e 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -13,6 +13,7 @@ # limitations under the License. """Execute Transactions Spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_code.py b/test/test_code.py index 9e44ca496..c564e3e04 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for the Code wrapper.""" +from __future__ import annotations import sys diff --git a/test/test_collation.py b/test/test_collation.py index 7f4bbf475..bedf0a2ea 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the collation module.""" +from __future__ import annotations import functools import warnings diff --git a/test/test_collection.py b/test/test_collection.py index bbaac0123..494719245 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the collection module.""" +from __future__ import annotations import contextlib import re @@ -482,7 +483,7 @@ class TestCollection(IntegrationTest): db.test.drop_indexes() self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) - for _dummy, info in db.test.index_information().items(): + for dummy, info in db.test.index_information().items(): field, idx_type = info["key"][0] if field == "geo" and idx_type == "2dsphere": break @@ -501,7 +502,7 @@ class TestCollection(IntegrationTest): db.test.drop_indexes() self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) - for _dummy, info in db.test.index_information().items(): + for dummy, info in db.test.index_information().items(): field, idx_type = info["key"][0] if field == "a" and idx_type == "hashed": break @@ -715,7 +716,7 @@ class TestCollection(IntegrationTest): self.assertEqual(document["_id"], result.inserted_id) self.assertFalse(result.acknowledged) # The insert failed duplicate key... - wait_until(lambda: 2 == db.test.count_documents({}), "forcing duplicate key error") + wait_until(lambda: db.test.count_documents({}) == 2, "forcing duplicate key error") document = RawBSONDocument(encode({"_id": ObjectId(), "foo": "bar"})) result = db.test.insert_one(document) @@ -816,7 +817,7 @@ class TestCollection(IntegrationTest): self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until(lambda: 0 == db.test.count_documents({}), "delete 1 documents") + wait_until(lambda: db.test.count_documents({}) == 0, "delete 1 documents") def test_delete_many(self): self.db.test.drop() @@ -837,7 +838,7 @@ class TestCollection(IntegrationTest): self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until(lambda: 0 == db.test.count_documents({}), "delete 2 documents") + wait_until(lambda: db.test.count_documents({}) == 0, "delete 2 documents") def test_command_document_too_large(self): large = "*" * (client_context.max_bson_size + _COMMAND_OVERHEAD) @@ -1537,7 +1538,7 @@ class TestCollection(IntegrationTest): while True: cursor.next() n += 1 - if 3 == n: + if n == 3: self.assertFalse(cursor.alive) break @@ -1848,7 +1849,7 @@ class TestCollection(IntegrationTest): unack_coll = db.collection_2.with_options(write_concern=WriteConcern(w=0)) unack_coll.insert_many(insert_second_fails) wait_until( - lambda: 1 == db.collection_2.count_documents({}), "insert 1 document", timeout=60 + lambda: db.collection_2.count_documents({}) == 1, "insert 1 document", timeout=60 ) db.collection_2.drop() @@ -1878,7 +1879,7 @@ class TestCollection(IntegrationTest): # Only the first and third documents are inserted. wait_until( - lambda: 2 == db.collection_4.count_documents({}), "insert 2 documents", timeout=60 + lambda: db.collection_4.count_documents({}) == 2, "insert 2 documents", timeout=60 ) db.collection_4.drop() diff --git a/test/test_collection_management.py b/test/test_collection_management.py index c5e29eda8..0eacde130 100644 --- a/test/test_collection_management.py +++ b/test/test_collection_management.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the collection management unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index c88b7ef81..d2f578824 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the command monitoring unified format spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_common.py b/test/test_common.py index f1769cb21..fdd4513d0 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the pymongo common module.""" +from __future__ import annotations import sys import uuid diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 38bda4e2a..ef8500ae6 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -13,6 +13,7 @@ # limitations under the License. """Test compliance with the connections survive primary step down spec.""" +from __future__ import annotations import sys diff --git a/test/test_create_entities.py b/test/test_create_entities.py index 1e46614da..b7965d4a1 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index cc9a521b3..92a60a47f 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the CRUD unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index 46aab2fba..8eb8f3581 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the collection module.""" +from __future__ import annotations import os import sys diff --git a/test/test_csot.py b/test/test_csot.py index c2a62aa7f..e8ee92d4a 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the CSOT unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_cursor.py b/test/test_cursor.py index 37c0335fa..284a5ac97 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -13,6 +13,8 @@ # limitations under the License. """Test the cursor module.""" +from __future__ import annotations + import copy import gc import itertools @@ -1176,7 +1178,7 @@ class TestCursor(IntegrationTest): while True: cursor.next() n += 1 - if 3 == n: + if n == 3: self.assertFalse(cursor.alive) break @@ -1407,7 +1409,7 @@ class TestRawBatchCursor(IntegrationTest): # The batch is a list of one raw bytes object. self.assertEqual(len(csr["firstBatch"]), 1) - self.assertEqual(decode_all(csr["firstBatch"][0]), [{"_id": i} for i in range(0, 4)]) + self.assertEqual(decode_all(csr["firstBatch"][0]), [{"_id": i} for i in range(4)]) listener.reset() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 7e190483a..da4bf0334 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -13,6 +13,7 @@ # limitations under the License. """Test support for callbacks to encode/decode custom types.""" +from __future__ import annotations import datetime import sys diff --git a/test/test_data_lake.py b/test/test_data_lake.py index ac4a56dc6..283ef074c 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -13,6 +13,7 @@ # limitations under the License. """Test Atlas Data Lake.""" +from __future__ import annotations import os import sys diff --git a/test/test_database.py b/test/test_database.py index 041b339e6..b141bb35f 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the database module.""" +from __future__ import annotations import re import sys diff --git a/test/test_dbref.py b/test/test_dbref.py index 107d95d23..d170f43f5 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the dbref module.""" +from __future__ import annotations import pickle import sys diff --git a/test/test_decimal128.py b/test/test_decimal128.py index b46f94f59..46819dd58 100644 --- a/test/test_decimal128.py +++ b/test/test_decimal128.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for Decimal128.""" +from __future__ import annotations import pickle import sys diff --git a/test/test_default_exports.py b/test/test_default_exports.py index 8a85110a7..4b02e0e31 100644 --- a/test/test_default_exports.py +++ b/test/test_default_exports.py @@ -13,6 +13,8 @@ # limitations under the License. """Test the default exports of the top level packages.""" +from __future__ import annotations + import inspect import unittest diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 21e2e3a05..7053f20e1 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the topology module.""" +from __future__ import annotations import os import sys diff --git a/test/test_dns.py b/test/test_dns.py index 9312e37e0..0fe57a4fe 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the SRV support tests.""" +from __future__ import annotations import glob import json diff --git a/test/test_encryption.py b/test/test_encryption.py index 2f61b52ff..2ffb6d493 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -13,6 +13,7 @@ # limitations under the License. """Test client side encryption spec.""" +from __future__ import annotations import base64 import copy @@ -2658,7 +2659,7 @@ class TestRangeQueryProse(EncryptionIntegrationTest): EncryptionError, "expected matching 'min' and value type. Got range option" ): self.client_encryption.encrypt( - int(6) if cast_func != int else float(6), + 6 if cast_func != int else float(6), key_id=self.key1_id, algorithm=Algorithm.RANGEPREVIEW, contention_factor=0, diff --git a/test/test_errors.py b/test/test_errors.py index 747da4847..2cee7c15d 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import pickle import sys diff --git a/test/test_examples.py b/test/test_examples.py index b9508d4f1..e003d8459 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -13,6 +13,7 @@ # limitations under the License. """MongoDB documentation examples in Python.""" +from __future__ import annotations import datetime import sys diff --git a/test/test_fork.py b/test/test_fork.py index 422cd89f2..7b19e4cd8 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -13,6 +13,7 @@ # limitations under the License. """Test that pymongo resets its own locks after a fork.""" +from __future__ import annotations import os import sys diff --git a/test/test_grid_file.py b/test/test_grid_file.py index ec88501a5..344a248b4 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for the grid_file module.""" +from __future__ import annotations import datetime import io diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 4ba8467d2..f94736708 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for the gridfs package.""" +from __future__ import annotations import datetime import sys diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index e5695f2c3..53e5cad54 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -14,6 +14,8 @@ # limitations under the License. """Tests for the gridfs package.""" +from __future__ import annotations + import datetime import itertools import threading diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index d080c05c4..6840b6ae0 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the GridFS unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index a14ab9a3a..5c75ab01d 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the monitoring of the server heartbeats.""" +from __future__ import annotations import sys diff --git a/test/test_json_util.py b/test/test_json_util.py index b7960a16e..a35d736eb 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -13,6 +13,7 @@ # limitations under the License. """Test some utilities for working with JSON and PyMongo.""" +from __future__ import annotations import datetime import json diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index fa4acc3a2..0fb877652 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the Load Balancer unified spec tests.""" +from __future__ import annotations import gc import os diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1596c1682..a87cbad58 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -13,6 +13,7 @@ # limitations under the License. """Test maxStalenessSeconds support.""" +from __future__ import annotations import os import sys diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 9e83e879a..a1e243884 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -13,6 +13,7 @@ # limitations under the License. """Test MongoClient's mongos load balancing using a mock.""" +from __future__ import annotations import sys import threading @@ -71,7 +72,7 @@ class TestMongosLoadBalancing(MockClientTest): mongoses=["a:1", "b:2", "c:3"], host="a:1,b:2,c:3", connect=False, - **kwargs + **kwargs, ) self.addCleanup(mock_client.close) diff --git a/test/test_monitor.py b/test/test_monitor.py index 9ee3c52ff..0495a8cbc 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the monitor module.""" +from __future__ import annotations import gc import sys diff --git a/test/test_monitoring.py b/test/test_monitoring.py index e135a52e7..6880a30dc 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import copy import datetime diff --git a/test/test_objectid.py b/test/test_objectid.py index cb96feaf3..771ba0942 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the objectid module.""" +from __future__ import annotations import datetime import pickle diff --git a/test/test_ocsp_cache.py b/test/test_ocsp_cache.py index 7fff4fd90..1cc025ccb 100644 --- a/test/test_ocsp_cache.py +++ b/test/test_ocsp_cache.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the pymongo ocsp_support module.""" +from __future__ import annotations import random import sys diff --git a/test/test_on_demand_csfle.py b/test/test_on_demand_csfle.py index 6e3f32c74..bfd07a83e 100644 --- a/test/test_on_demand_csfle.py +++ b/test/test_on_demand_csfle.py @@ -13,6 +13,8 @@ # limitations under the License. """Test client side encryption with on demand credentials.""" +from __future__ import annotations + import os import sys import unittest diff --git a/test/test_pooling.py b/test/test_pooling.py index 81e530a07..007aefa2e 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -13,6 +13,7 @@ # limitations under the License. """Test built in connection-pooling with threads.""" +from __future__ import annotations import gc import random diff --git a/test/test_pymongo.py b/test/test_pymongo.py index 7ec32e16a..d4203ed5c 100644 --- a/test/test_pymongo.py +++ b/test/test_pymongo.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the pymongo module itself.""" +from __future__ import annotations import sys diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index d82e5104c..38b4dd197 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import datetime import sys diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 682fe03e7..97855872c 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the read_concern module.""" +from __future__ import annotations import sys import unittest diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 814874a26..986785faf 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the replica_set_connection module.""" +from __future__ import annotations import contextlib import copy @@ -119,8 +120,8 @@ class TestReadPreferencesBase(IntegrationTest): return "secondary" else: self.fail( - "Cursor used address {}, expected either primary " - "{} or secondaries {}".format(address, client.primary, client.secondaries) + f"Cursor used address {address}, expected either primary " + f"{client.primary} or secondaries {client.secondaries}" ) return None @@ -272,7 +273,7 @@ class TestReadPreferences(TestReadPreferencesBase): self.assertFalse( not_used, "Expected to use primary and all secondaries for mode NEAREST," - " but didn't use {}\nlatencies: {}".format(not_used, latencies), + f" but didn't use {not_used}\nlatencies: {latencies}", ) diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index b27e9fa03..939f05faf 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the read and write concern tests.""" +from __future__ import annotations import json import os diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index bdeaeb06a..1dae0aea8 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -13,6 +13,7 @@ # limitations under the License. """Test clients and replica set configuration changes, using mocks.""" +from __future__ import annotations import sys @@ -168,7 +169,7 @@ class TestSecondaryAdded(MockClientTest): ) self.addCleanup(c.close) - wait_until(lambda: ("a", 1) == c.primary, "discover the primary") + wait_until(lambda: c.primary == ("a", 1), "discover the primary") wait_until(lambda: {("b", 2)} == c.secondaries, "discover the secondary") # C is added. diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index df173ac27..8779ea1ed 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -13,6 +13,7 @@ # limitations under the License. """Test retryable reads spec.""" +from __future__ import annotations import os import pprint diff --git a/test/test_retryable_reads_unified.py b/test/test_retryable_reads_unified.py index 6bf415776..69bee081a 100644 --- a/test/test_retryable_reads_unified.py +++ b/test/test_retryable_reads_unified.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the Retryable Reads unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 89507b33c..2da6f53f4 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -13,6 +13,7 @@ # limitations under the License. """Test retryable writes.""" +from __future__ import annotations import copy import os diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index 4e97c14d4..da16166ec 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the Retryable Writes unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_run_command.py b/test/test_run_command.py index 848fd2cb9..486a4c7e3 100644 --- a/test/test_run_command.py +++ b/test/test_run_command.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest from test.unified_format import generate_test_classes diff --git a/test/test_saslprep.py b/test/test_saslprep.py index c07870dad..e825cafa3 100644 --- a/test/test_saslprep.py +++ b/test/test_saslprep.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import sys @@ -24,7 +25,7 @@ from pymongo.saslprep import saslprep class TestSASLprep(unittest.TestCase): def test_saslprep(self): try: - import stringprep # noqa + import stringprep except ImportError: self.assertRaises(TypeError, saslprep, "anything...") # Bytes strings are ignored. diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 2587ae796..f687eab31 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the sdam monitoring spec tests.""" +from __future__ import annotations import json import os @@ -44,8 +45,8 @@ _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sdam_mon def compare_server_descriptions(expected, actual): - if (not expected["address"] == "{}:{}".format(*actual.address)) or ( - not server_name_to_type(expected["type"]) == actual.server_type + if (expected["address"] != "{}:{}".format(*actual.address)) or ( + server_name_to_type(expected["type"]) != actual.server_type ): return False expected_hosts = set(expected["arbiters"] + expected["passives"] + expected["hosts"]) @@ -53,7 +54,7 @@ def compare_server_descriptions(expected, actual): def compare_topology_descriptions(expected, actual): - if not (TOPOLOGY_TYPE.__getattribute__(expected["topologyType"]) == actual.topology_type): + if TOPOLOGY_TYPE.__getattribute__(expected["topologyType"]) != actual.topology_type: return False expected = expected["servers"] actual = actual.server_descriptions() @@ -79,7 +80,7 @@ def compare_events(expected_dict, actual): if expected_type == "server_opening_event": if not isinstance(actual, monitoring.ServerOpeningEvent): return False, "Expected ServerOpeningEvent, got %s" % (actual.__class__) - if not expected["address"] == "{}:{}".format(*actual.server_address): + if expected["address"] != "{}:{}".format(*actual.server_address): return ( False, "ServerOpeningEvent published with wrong address (expected" @@ -90,7 +91,7 @@ def compare_events(expected_dict, actual): if not isinstance(actual, monitoring.ServerDescriptionChangedEvent): return (False, "Expected ServerDescriptionChangedEvent, got %s" % (actual.__class__)) - if not expected["address"] == "{}:{}".format(*actual.server_address): + if expected["address"] != "{}:{}".format(*actual.server_address): return ( False, "ServerDescriptionChangedEvent has wrong address" @@ -110,7 +111,7 @@ def compare_events(expected_dict, actual): elif expected_type == "server_closed_event": if not isinstance(actual, monitoring.ServerClosedEvent): return False, "Expected ServerClosedEvent, got %s" % (actual.__class__) - if not expected["address"] == "{}:{}".format(*actual.server_address): + if expected["address"] != "{}:{}".format(*actual.server_address): return ( False, "ServerClosedEvent published with wrong address" diff --git a/test/test_server.py b/test/test_server.py index 58e39edd7..1d71a614d 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the server module.""" +from __future__ import annotations import sys diff --git a/test/test_server_description.py b/test/test_server_description.py index bb49141d2..ee05e95cf 100644 --- a/test/test_server_description.py +++ b/test/test_server_description.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the server_description module.""" +from __future__ import annotations import sys diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 30b82769b..01f19ad87 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations import os import sys diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index d97c4b4e8..52873882f 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations import os import threading diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index 5c2a8a6fb..a129af458 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the topology module.""" +from __future__ import annotations import json import os diff --git a/test/test_session.py b/test/test_session.py index 18d0122da..c95691be1 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the client_session module.""" +from __future__ import annotations import copy import sys diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index 8a6b8bc9b..c51b4642e 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the Sessions unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_son.py b/test/test_son.py index 5e62ffb17..579d765d8 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the son module.""" +from __future__ import annotations import copy import pickle @@ -150,8 +151,8 @@ class TestSON(unittest.TestCase): self.assertIn(1, test_son) self.assertTrue(2 in test_son, "in failed") self.assertFalse(22 in test_son, "in succeeded when it shouldn't") - self.assertTrue(test_son.has_key(2), "has_key failed") # noqa - self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") # noqa + self.assertTrue(test_son.has_key(2), "has_key failed") + self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") def test_clears(self): """Test clear()""" diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 8bf81f4de..d7e910662 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the SRV support tests.""" +from __future__ import annotations import sys from time import sleep diff --git a/test/test_ssl.py b/test/test_ssl.py index e6df2a1c2..bde385138 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -13,11 +13,11 @@ # limitations under the License. """Tests for SSL support.""" +from __future__ import annotations import os import socket import sys -from typing import Any sys.path[0:0] = [""] @@ -40,12 +40,12 @@ from pymongo.write_concern import WriteConcern _HAVE_PYOPENSSL = False try: # All of these must be available to use PyOpenSSL - import OpenSSL # noqa - import requests # noqa - import service_identity # noqa + import OpenSSL + import requests + import service_identity # Ensure service_identity>=18.1 is installed - from service_identity.pyopenssl import verify_ip_address # noqa + from service_identity.pyopenssl import verify_ip_address from pymongo.ocsp_support import _load_trusted_ca_certs @@ -185,7 +185,7 @@ class TestSSL(IntegrationTest): tlsCertificateKeyFilePassword="qwerty", tlsCAFile=CA_PEM, serverSelectionTimeoutMS=5000, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -317,7 +317,7 @@ class TestSSL(IntegrationTest): tlsAllowInvalidCertificates=False, tlsCAFile=CA_PEM, serverSelectionTimeoutMS=500, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -330,7 +330,7 @@ class TestSSL(IntegrationTest): tlsCAFile=CA_PEM, tlsAllowInvalidHostnames=True, serverSelectionTimeoutMS=500, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -345,7 +345,7 @@ class TestSSL(IntegrationTest): tlsAllowInvalidCertificates=False, tlsCAFile=CA_PEM, serverSelectionTimeoutMS=500, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -359,7 +359,7 @@ class TestSSL(IntegrationTest): tlsCAFile=CA_PEM, tlsAllowInvalidHostnames=True, serverSelectionTimeoutMS=500, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -383,7 +383,7 @@ class TestSSL(IntegrationTest): ssl=True, tlsCAFile=CA_PEM, serverSelectionTimeoutMS=1000, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -395,7 +395,7 @@ class TestSSL(IntegrationTest): tlsCAFile=CA_PEM, tlsCRLFile=CRL_PEM, serverSelectionTimeoutMS=1000, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -435,7 +435,7 @@ class TestSSL(IntegrationTest): ssl=True, tlsAllowInvalidHostnames=True, serverSelectionTimeoutMS=1000, - **self.credentials # type: ignore[arg-type] + **self.credentials, # type: ignore[arg-type] ) ) @@ -458,7 +458,6 @@ class TestSSL(IntegrationTest): ): raise SkipTest("Can't test when system CA certificates are loadable.") - ssl_support: Any have_certifi = ssl_support.HAVE_CERTIFI have_wincertstore = ssl_support.HAVE_WINCERTSTORE # Force the test regardless of environment. @@ -477,7 +476,6 @@ class TestSSL(IntegrationTest): # with SSLContext and SSLContext provides no information # about ca_certs. raise SkipTest("Can't test when SSLContext available.") - ssl_support: Any if not ssl_support.HAVE_CERTIFI: raise SkipTest("Need certifi to test certifi support.") diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 9da5a550a..44e673822 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the database module.""" +from __future__ import annotations import sys import time diff --git a/test/test_threads.py b/test/test_threads.py index b948bf924..b3dadbb1a 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -13,6 +13,7 @@ # limitations under the License. """Test that pymongo is thread safe.""" +from __future__ import annotations import threading from test import IntegrationTest, client_context, unittest diff --git a/test/test_timestamp.py b/test/test_timestamp.py index 3602fe280..7495d2ec9 100644 --- a/test/test_timestamp.py +++ b/test/test_timestamp.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the Timestamp class.""" +from __future__ import annotations import copy import datetime diff --git a/test/test_topology.py b/test/test_topology.py index a7bfeb766..88c99d2a2 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the topology module.""" +from __future__ import annotations import sys @@ -647,13 +648,13 @@ class TestMultiServerTopology(TopologyTest): ) self.assertEqual( repr(t.description), - ", " ", " "]>".format(t._topology_id), + " rtt: None>]>", ) def test_unexpected_load_balancer(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index a2901ce45..64b93f0b5 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -13,6 +13,7 @@ # limitations under the License. """Execute Transactions Spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 4f3aa233f..6de4902a8 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the Transactions unified spec tests.""" +from __future__ import annotations import os import sys diff --git a/test/test_typing.py b/test/test_typing.py index b2db4b93b..3d6156ce2 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -15,6 +15,8 @@ """Test that each file in mypy_fails/ actually fails mypy, and test some sample client code that uses PyMongo typings. """ +from __future__ import annotations + import os import sys import tempfile @@ -36,7 +38,7 @@ try: year: int class ImplicitMovie(TypedDict): - _id: NotRequired[ObjectId] + _id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues] name: str year: int diff --git a/test/test_typing_strict.py b/test/test_typing_strict.py index 55cb1454b..4b03b2bfd 100644 --- a/test/test_typing_strict.py +++ b/test/test_typing_strict.py @@ -13,6 +13,8 @@ # limitations under the License. """Test typings in strict mode.""" +from __future__ import annotations + import unittest from typing import TYPE_CHECKING, Any, Dict diff --git a/test/test_unified_format.py b/test/test_unified_format.py index 8a6e3da54..bc6dbcc5c 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import sys diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index e2dd17ec2..e1e59eb65 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -13,6 +13,7 @@ # limitations under the License. """Test the pymongo uri_parser module.""" +from __future__ import annotations import copy import sys diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 5896b5e12..ad48fe787 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -15,6 +15,7 @@ """Test that the pymongo.uri_parser module is compliant with the connection string and uri options specifications. """ +from __future__ import annotations import json import os diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 3372c1a91..cb25c3f66 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import sys diff --git a/test/test_write_concern.py b/test/test_write_concern.py index 822f3a4d1..e22c7e7a8 100644 --- a/test/test_write_concern.py +++ b/test/test_write_concern.py @@ -13,6 +13,7 @@ # limitations under the License. """Run the unit tests for WriteConcern.""" +from __future__ import annotations import collections import unittest @@ -37,7 +38,7 @@ class TestWriteConcern(unittest.TestCase): concern = WriteConcern() self.assertNotEqual(concern, None) # Explicitly use the != operator. - self.assertTrue(concern != None) # noqa + self.assertTrue(concern != None) # noqa: E711 def test_equality_compatible_type(self): class _FakeWriteConcern: diff --git a/test/unicode/test_utf8.py b/test/unicode/test_utf8.py index 7ce2936b7..fd7fb2154 100644 --- a/test/unicode/test_utf8.py +++ b/test/unicode/test_utf8.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys sys.path[0:0] = [""] diff --git a/test/unified_format.py b/test/unified_format.py index a6676c601..68ce36e6f 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -16,6 +16,8 @@ https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst """ +from __future__ import annotations + import binascii import collections import copy diff --git a/test/utils.py b/test/utils.py index 51a7903c4..c8f9197c6 100644 --- a/test/utils.py +++ b/test/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Utilities for testing pymongo""" +from __future__ import annotations import contextlib import copy @@ -542,7 +543,7 @@ class SpecTestCreator: def _connection_string(h): if h.startswith(("mongodb://", "mongodb+srv://")): return h - return f"mongodb://{str(h)}" + return f"mongodb://{h!s}" def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 6967544f0..7952a2862 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -13,6 +13,7 @@ # limitations under the License. """Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations import datetime import os diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 21cc3e6d8..eea96aa1d 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -13,6 +13,7 @@ # limitations under the License. """Utilities for testing driver specs.""" +from __future__ import annotations import functools import threading diff --git a/test/version.py b/test/version.py index 1dd1bec5f..043c760cf 100644 --- a/test/version.py +++ b/test/version.py @@ -13,6 +13,7 @@ # limitations under the License. """Some tools for running tests based on MongoDB server version.""" +from __future__ import annotations class Version(tuple): diff --git a/tools/clean.py b/tools/clean.py index 0ea31fc3d..15db9a411 100644 --- a/tools/clean.py +++ b/tools/clean.py @@ -16,20 +16,21 @@ Only really intended to be used by internal build scripts. """ +from __future__ import annotations -import os import sys +from pathlib import Path try: - os.remove("pymongo/_cmessage.so") - os.remove("bson/_cbson.so") -except BaseException: + Path("pymongo/_cmessage.so").unlink() + Path("bson/_cbson.so").unlink() +except BaseException: # noqa: S110 pass try: - os.remove("pymongo/_cmessage.pyd") - os.remove("bson/_cbson.pyd") -except BaseException: + Path("pymongo/_cmessage.pyd").unlink() + Path("bson/_cbson.pyd").unlink() +except BaseException: # noqa: S110 pass try: diff --git a/tools/ensure_future_annotations_import.py b/tools/ensure_future_annotations_import.py index 03e1bd036..3e7e60bfd 100644 --- a/tools/ensure_future_annotations_import.py +++ b/tools/ensure_future_annotations_import.py @@ -14,16 +14,16 @@ """Ensure that 'from __future__ import annotations' is used in all package files """ +from __future__ import annotations -import glob -import os import sys +from pathlib import Path pattern = "from __future__ import annotations" missing = [] for dirname in ["pymongo", "bson", "gridfs"]: - for path in glob.glob(f"{dirname}/*.py"): - if os.path.basename(path) in ["_version.py", "errors.py"]: + for path in Path(dirname).glob("*.py"): + if Path(path).name in ["_version.py", "errors.py"]: continue found = False with open(path) as fid: @@ -35,7 +35,7 @@ for dirname in ["pymongo", "bson", "gridfs"]: missing.append(path) if missing: - print(f"Missing '{pattern}' import in:") + print(f"Missing '{pattern}' import in:") # noqa: T201 for item in missing: - print(item) + print(item) # noqa: T201 sys.exit(1) diff --git a/tools/fail_if_no_c.py b/tools/fail_if_no_c.py index 60fed0ee8..2b59521c7 100644 --- a/tools/fail_if_no_c.py +++ b/tools/fail_if_no_c.py @@ -16,11 +16,12 @@ Only really intended to be used by internal build scripts. """ +from __future__ import annotations -import glob import os import subprocess import sys +from pathlib import Path sys.path[0:0] = [""] @@ -31,11 +32,11 @@ if not pymongo.has_c() or not bson.has_c(): sys.exit("could not load C extensions") if os.environ.get("ENSURE_UNIVERSAL2") == "1": - parent_dir = os.path.dirname(pymongo.__path__[0]) + parent_dir = Path(pymongo.__path__[0]).parent for pkg in ["pymongo", "bson", "grifs"]: - for so_file in glob.glob(f"{parent_dir}/{pkg}/*.so"): - print(f"Checking universal2 compatibility in {so_file}...") - output = subprocess.check_output(["file", so_file]) + for so_file in Path(f"{parent_dir}/{pkg}").glob("*.so"): + print(f"Checking universal2 compatibility in {so_file}...") # noqa: T201 + output = subprocess.check_output(["file", so_file]) # noqa: S603, S607 if "arm64" not in output.decode("utf-8"): sys.exit("Universal wheel was not compiled with arm64 support") if "x86_64" not in output.decode("utf-8"): diff --git a/tools/ocsptest.py b/tools/ocsptest.py index 702f15ee9..521d048f7 100644 --- a/tools/ocsptest.py +++ b/tools/ocsptest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import annotations import argparse import logging