PYTHON-4005 Replace flake8 and isort with ruff (#1399)

This commit is contained in:
Steven Silvester 2023-10-19 11:56:22 -05:00 committed by GitHub
parent 1f7b74f37d
commit 992d1507e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
189 changed files with 634 additions and 454 deletions

30
.flake8
View File

@ -1,30 +0,0 @@
[flake8]
max-line-length = 100
enable-extensions = G
extend-ignore =
G200, G202,
# black adds spaces around ':'
E203,
# E501 line too long (let black handle line length)
E501
# B305 `.next()` is not a thing on Python 3
B305
per-file-ignores =
# E402 module level import not at top of file
pymongo/__init__.py: E402
# G004 Logging statement uses f-string
pymongo/event_loggers.py: G004
# E402 module level import not at top of file
# B011 Do not call assert False since python -O removes these calls
# F405 'Foo' may be undefined, or defined from star imports
# E741 ambiguous variable name
# B007 Loop control variable 'foo' not used within the loop body
# F403 'from foo import *' used; unable to detect undefined names
# B001 Do not use bare `except:`
# E722 do not use bare 'except'
# E731 do not assign a lambda expression, use a def
# F811 redefinition of unused 'foo' from line XXX
# F841 local variable 'foo' is assigned to but never used
test/*: E402, B011, F405, E741, B007, F403, B001, E722, E731, F811, F841

View File

@ -24,12 +24,12 @@ repos:
files: \.py$
args: [--line-length=100]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.0
hooks:
- id: isort
files: \.py$
args: [--profile=black]
- id: ruff
args: ["--fix", "--show-fixes"]
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.13.0"
@ -38,18 +38,6 @@ repos:
additional_dependencies:
- black==22.3.0
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
files: \.py$
additional_dependencies: [
'flake8-bugbear==20.1.4',
'flake8-logging-format==0.6.0',
'flake8-implicit-str-concat==0.2.0',
]
stages: [manual]
# We use the Python version instead of the original version which seems to require Docker
# https://github.com/koalaman/shellcheck-precommit
- repo: https://github.com/shellcheck-py/shellcheck-py

View File

@ -244,7 +244,7 @@ def _raise_unknown_type(element_type: int, element_name: str) -> NoReturn:
def _get_int(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[int, int]:
"""Decode a BSON int32 to python int."""
return _UNPACK_INT_FROM(data, position)[0], position + 4
@ -257,7 +257,7 @@ def _get_c_string(data: Any, view: Any, position: int, opts: CodecOptions[Any])
def _get_float(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[float, int]:
"""Decode a BSON double to python float."""
return _UNPACK_FLOAT_FROM(data, position)[0], position + 8
@ -282,7 +282,7 @@ def _get_object_size(data: Any, position: int, obj_end: int) -> Tuple[int, int]:
try:
obj_size = _UNPACK_INT_FROM(data, position)[0]
except struct.error as exc:
raise InvalidBSON(str(exc))
raise InvalidBSON(str(exc)) from None
end = position + obj_size - 1
if data[end] != 0:
raise InvalidBSON("bad eoo")
@ -358,7 +358,7 @@ def _get_array(
def _get_binary(
data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy1: Any
data: Any, _view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy1: Any
) -> Tuple[Union[Binary, uuid.UUID], int]:
"""Decode a BSON binary to bson.binary.Binary or python UUID."""
length, subtype = _UNPACK_LENGTH_SUBTYPE_FROM(data, position)
@ -395,7 +395,7 @@ def _get_binary(
def _get_oid(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[ObjectId, int]:
"""Decode a BSON ObjectId to bson.objectid.ObjectId."""
end = position + 12
@ -403,7 +403,7 @@ def _get_oid(
def _get_boolean(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[bool, int]:
"""Decode a BSON true/false to python True/False."""
end = position + 1
@ -416,7 +416,7 @@ def _get_boolean(
def _get_date(
data: Any, view: Any, position: int, dummy0: int, opts: CodecOptions[Any], dummy1: Any
data: Any, _view: Any, position: int, dummy0: int, opts: CodecOptions[Any], dummy1: Any
) -> Tuple[Union[datetime.datetime, DatetimeMS], int]:
"""Decode a BSON datetime to python datetime.datetime."""
return _millis_to_datetime(_UNPACK_LONG_FROM(data, position)[0], opts), position + 8
@ -431,7 +431,7 @@ def _get_code(
def _get_code_w_scope(
data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], element_name: str
data: Any, view: Any, position: int, _obj_end: int, opts: CodecOptions[Any], element_name: str
) -> Tuple[Code, int]:
"""Decode a BSON code_w_scope to bson.code.Code."""
code_end = position + _UNPACK_INT_FROM(data, position)[0]
@ -462,7 +462,7 @@ def _get_ref(
def _get_timestamp(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[Timestamp, int]:
"""Decode a BSON timestamp to bson.timestamp.Timestamp."""
inc, timestamp = _UNPACK_TIMESTAMP_FROM(data, position)
@ -470,14 +470,14 @@ def _get_timestamp(
def _get_int64(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[Int64, int]:
"""Decode a BSON int64 to bson.int64.Int64."""
return Int64(_UNPACK_LONG_FROM(data, position)[0]), position + 8
def _get_decimal128(
data: Any, view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any
) -> Tuple[Decimal128, int]:
"""Decode a BSON decimal128 to bson.decimal128.Decimal128."""
end = position + 16
@ -496,11 +496,11 @@ _ELEMENT_GETTER: dict[int, Callable[..., Tuple[Any, int]]] = {
ord(BSONOBJ): _get_object,
ord(BSONARR): _get_array,
ord(BSONBIN): _get_binary,
ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # Deprecated undefined
ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005 # Deprecated undefined
ord(BSONOID): _get_oid,
ord(BSONBOO): _get_boolean,
ord(BSONDAT): _get_date,
ord(BSONNUL): lambda u, v, w, x, y, z: (None, w),
ord(BSONNUL): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005
ord(BSONRGX): _get_regex,
ord(BSONREF): _get_ref, # Deprecated DBPointer
ord(BSONCOD): _get_code,
@ -510,8 +510,8 @@ _ELEMENT_GETTER: dict[int, Callable[..., Tuple[Any, int]]] = {
ord(BSONTIM): _get_timestamp,
ord(BSONLON): _get_int64,
ord(BSONDEC): _get_decimal128,
ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w),
ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w),
ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w), # noqa: ARG005
ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w), # noqa: ARG005
}
@ -519,7 +519,7 @@ if _USE_C:
def _element_to_dict(
data: Any,
view: Any,
view: Any, # noqa: ARG001
position: int,
obj_end: int,
opts: CodecOptions[Any],
@ -615,11 +615,11 @@ def _bson_to_dict(data: Any, opts: CodecOptions[_DocumentType]) -> _DocumentType
except Exception:
# Change exception type to InvalidBSON but preserve traceback.
_, exc_value, exc_tb = sys.exc_info()
raise InvalidBSON(str(exc_value)).with_traceback(exc_tb)
raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None
if _USE_C:
_bson_to_dict = _cbson._bson_to_dict # noqa: F811
_bson_to_dict = _cbson._bson_to_dict
_PACK_FLOAT = struct.Struct("<d").pack
@ -653,7 +653,9 @@ def _make_c_string_check(string: Union[str, bytes]) -> bytes:
_utf_8_decode(string, None, True)
return string + b"\x00"
except UnicodeError:
raise InvalidStringData("strings in documents must be valid UTF-8: %r" % string)
raise InvalidStringData(
"strings in documents must be valid UTF-8: %r" % string
) from None
else:
if "\x00" in string:
raise InvalidDocument("BSON keys / regex patterns must not contain a NUL character")
@ -667,7 +669,9 @@ def _make_c_string(string: Union[str, bytes]) -> bytes:
_utf_8_decode(string, None, True)
return string + b"\x00"
except UnicodeError:
raise InvalidStringData("strings in documents must be valid UTF-8: %r" % string)
raise InvalidStringData(
"strings in documents must be valid UTF-8: %r" % string
) from None
else:
return _utf_8_encode(string)[0] + b"\x00"
@ -817,7 +821,7 @@ def _encode_int(name: bytes, value: int, dummy0: Any, dummy1: Any) -> bytes:
try:
return b"\x12" + name + _PACK_LONG(value)
except struct.error:
raise OverflowError("BSON can only handle up to 8-byte ints")
raise OverflowError("BSON can only handle up to 8-byte ints") from None
def _encode_timestamp(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes:
@ -830,7 +834,7 @@ def _encode_long(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes:
try:
return b"\x12" + name + _PACK_LONG(value)
except struct.error:
raise OverflowError("BSON can only handle up to 8-byte ints")
raise OverflowError("BSON can only handle up to 8-byte ints") from None
def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any) -> bytes:
@ -995,14 +999,14 @@ def _dict_to_bson(
if not top_level or key != "_id":
elements.append(_element_to_bson(key, value, check_keys, opts))
except AttributeError:
raise TypeError(f"encoder expected a mapping type but got: {doc!r}")
raise TypeError(f"encoder expected a mapping type but got: {doc!r}") from None
encoded = b"".join(elements)
return _PACK_INT(len(encoded) + 5) + encoded + b"\x00"
if _USE_C:
_dict_to_bson = _cbson._dict_to_bson # noqa: F811
_dict_to_bson = _cbson._dict_to_bson
_CODEC_OPTIONS_TYPE_ERROR = TypeError("codec_options must be an instance of CodecOptions")
@ -1110,11 +1114,11 @@ def _decode_all(data: _ReadableBuffer, opts: CodecOptions[_DocumentType]) -> lis
except Exception:
# Change exception type to InvalidBSON but preserve traceback.
_, exc_value, exc_tb = sys.exc_info()
raise InvalidBSON(str(exc_value)).with_traceback(exc_tb)
raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None
if _USE_C:
_decode_all = _cbson._decode_all # noqa: F811
_decode_all = _cbson._decode_all
@overload
@ -1207,7 +1211,7 @@ def _array_of_documents_to_buffer(view: memoryview) -> bytes:
if _USE_C:
_array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer # noqa: F811
_array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer
def _convert_raw_document_lists_to_streams(document: Any) -> None:

View File

@ -242,7 +242,7 @@ class Binary(bytes):
@classmethod
def from_uuid(
cls: Type[Binary], uuid: UUID, uuid_representation: int = UuidRepresentation.STANDARD
) -> "Binary":
) -> Binary:
"""Create a BSON Binary object from a Python UUID.
Creates a :class:`~bson.binary.Binary` object from a

View File

@ -168,7 +168,7 @@ class TypeRegistry:
if issubclass(cast(TypeCodec, codec).python_type, pytype):
err_msg = (
"TypeEncoders cannot change how built-in types are "
"encoded (encoder {} transforms type {})".format(codec, pytype)
f"encoded (encoder {codec} transforms type {pytype})"
)
raise TypeError(err_msg)

View File

@ -75,10 +75,10 @@ class DatetimeMS:
def __repr__(self) -> str:
return type(self).__name__ + "(" + str(self._value) + ")"
def __lt__(self, other: Union["DatetimeMS", int]) -> bool:
def __lt__(self, other: Union[DatetimeMS, int]) -> bool:
return self._value < other
def __le__(self, other: Union["DatetimeMS", int]) -> bool:
def __le__(self, other: Union[DatetimeMS, int]) -> bool:
return self._value <= other
def __eq__(self, other: Any) -> bool:
@ -91,10 +91,10 @@ class DatetimeMS:
return self._value != other._value
return True
def __gt__(self, other: Union["DatetimeMS", int]) -> bool:
def __gt__(self, other: Union[DatetimeMS, int]) -> bool:
return self._value > other
def __ge__(self, other: Union["DatetimeMS", int]) -> bool:
def __ge__(self, other: Union[DatetimeMS, int]) -> bool:
return self._value >= other
_type_marker = 9

View File

@ -89,7 +89,7 @@ class DBRef:
try:
return self.__kwargs[key]
except KeyError:
raise AttributeError(key)
raise AttributeError(key) from None
def as_doc(self) -> SON[str, Any]:
"""Get the SON document representation of this DBRef.

View File

@ -298,7 +298,7 @@ class Decimal128:
return str(dec)
def __repr__(self) -> str:
return f"Decimal128('{str(self)}')"
return f"Decimal128('{self!s}')"
def __setstate__(self, value: Tuple[int, int]) -> None:
self.__high, self.__low = value

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Exceptions raised by the BSON package."""
from __future__ import annotations
class BSONError(Exception):

View File

@ -737,8 +737,7 @@ def _parse_canonical_regex(doc: Any) -> Regex[str]:
raise TypeError(f"Bad $regularExpression, extra field(s): {doc}")
if len(regex) != 2:
raise TypeError(
'Bad $regularExpression must include only "pattern"'
'and "options" components: {}'.format(doc)
f'Bad $regularExpression must include only "pattern and "options" components: {doc}'
)
opts = regex["options"]
if not isinstance(opts, str):
@ -812,7 +811,7 @@ def _parse_canonical_decimal128(doc: Any) -> Decimal128:
def _parse_canonical_minkey(doc: Any) -> MinKey:
"""Decode a JSON MinKey to bson.min_key.MinKey."""
if type(doc["$minKey"]) is not int or doc["$minKey"] != 1:
if type(doc["$minKey"]) is not int or doc["$minKey"] != 1: # noqa: E721
raise TypeError(f"$minKey value must be 1: {doc}")
if len(doc) != 1:
raise TypeError(f"Bad $minKey, extra field(s): {doc}")
@ -821,7 +820,7 @@ def _parse_canonical_minkey(doc: Any) -> MinKey:
def _parse_canonical_maxkey(doc: Any) -> MaxKey:
"""Decode a JSON MaxKey to bson.max_key.MaxKey."""
if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1:
if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1: # noqa: E721
raise TypeError("$maxKey value must be 1: %s", (doc,))
if len(doc) != 1:
raise TypeError(f"Bad $minKey, extra field(s): {doc}")

View File

@ -57,7 +57,7 @@ class ObjectId:
_type_marker = 7
def __init__(self, oid: Optional[Union[str, "ObjectId", bytes]] = None) -> None:
def __init__(self, oid: Optional[Union[str, ObjectId, bytes]] = None) -> None:
"""Initialize a new ObjectId.
An ObjectId is a 12-byte unique identifier consisting of:
@ -103,7 +103,7 @@ class ObjectId:
self.__validate(oid)
@classmethod
def from_datetime(cls: Type["ObjectId"], generation_time: datetime.datetime) -> "ObjectId":
def from_datetime(cls: Type[ObjectId], generation_time: datetime.datetime) -> ObjectId:
"""Create a dummy ObjectId instance with a specific generation time.
This method is useful for doing range queries on a field
@ -138,7 +138,7 @@ class ObjectId:
return cls(oid)
@classmethod
def is_valid(cls: Type["ObjectId"], oid: Any) -> bool:
def is_valid(cls: Type[ObjectId], oid: Any) -> bool:
"""Checks if a `oid` string is valid or not.
:Parameters:
@ -245,7 +245,7 @@ class ObjectId:
return binascii.hexlify(self.__id).decode()
def __repr__(self) -> str:
return f"ObjectId('{str(self)}')"
return f"ObjectId('{self!s}')"
def __eq__(self, other: Any) -> bool:
if isinstance(other, ObjectId):

View File

@ -55,9 +55,8 @@ from __future__ import annotations
from typing import Any, ItemsView, Iterator, Mapping, MutableMapping, Optional
from bson import _get_object_size, _raw_to_dict
from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER
from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER, CodecOptions
from bson.codec_options import DEFAULT_CODEC_OPTIONS as DEFAULT
from bson.codec_options import CodecOptions
from bson.son import SON
@ -135,7 +134,7 @@ class RawBSONDocument(Mapping[str, Any]):
elif not issubclass(codec_options.document_class, RawBSONDocument):
raise TypeError(
"RawBSONDocument cannot use CodecOptions with document "
"class {}".format(codec_options.document_class)
f"class {codec_options.document_class}"
)
self.__codec_options = codec_options
# Validate the bson object size.
@ -180,11 +179,7 @@ class RawBSONDocument(Mapping[str, Any]):
return NotImplemented
def __repr__(self) -> str:
return "{}({!r}, codec_options={!r})".format(
self.__class__.__name__,
self.raw,
self.__codec_options,
)
return f"{self.__class__.__name__}({self.raw!r}, codec_options={self.__codec_options!r})"
class _RawArrayBSONDocument(RawBSONDocument):

View File

@ -139,7 +139,7 @@ class SON(Dict[_Key, _Value]):
try:
k, v = next(iter(self.items()))
except StopIteration:
raise KeyError("container is empty")
raise KeyError("container is empty") from None
del self[k]
return (k, v)
@ -151,7 +151,7 @@ class SON(Dict[_Key, _Value]):
for k, v in other.items():
self[k] = v
elif hasattr(other, "keys"):
for k in other.keys():
for k in other:
self[k] = other[k]
else:
for k, v in other:
@ -204,6 +204,6 @@ class SON(Dict[_Key, _Value]):
memo[val_id] = out
for k, v in self.items():
if not isinstance(v, RE_TYPE):
v = copy.deepcopy(v, memo)
v = copy.deepcopy(v, memo) # noqa: PLW2901
out[k] = v
return out

View File

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
#
# PyMongo documentation build configuration file
#
# This file is execfile()d with the current directory set to its containing dir.
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [os.path.abspath("..")]
sys.path[0:0] = [Path("..").resolve()]
import pymongo # noqa
import pymongo # noqa: E402
# -- General configuration -----------------------------------------------------
@ -26,7 +26,7 @@ extensions = [
# Add optional extensions
try:
import sphinxcontrib.shellcheck # noqa
import sphinxcontrib.shellcheck # noqa: F401
extensions += ["sphinxcontrib.shellcheck"]
except ImportError:
@ -94,7 +94,7 @@ linkcheck_ignore = [
# -- Options for extensions ----------------------------------------------------
autoclass_content = "init"
doctest_path = [os.path.abspath("..")]
doctest_path = [Path("..").resolve()]
doctest_test_doctest_blocks = ""
@ -108,7 +108,7 @@ db = client.doctest_test
# -- Options for HTML output ---------------------------------------------------
try:
import furo # noqa
import furo # noqa: F401
html_theme = "furo"
except ImportError:

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test PyMongo with a variety of greenlet-based monkey-patching frameworks."""
from __future__ import annotations
import getopt
import sys
@ -65,13 +66,10 @@ def run(framework_name, *args):
def main():
"""Parse options and run tests."""
usage = """python %s FRAMEWORK_NAME
usage = f"""python {sys.argv[0]} FRAMEWORK_NAME
Test PyMongo with a variety of greenlet-based monkey-patching frameworks. See
python %s --help-frameworks.""" % (
sys.argv[0],
sys.argv[0],
)
python {sys.argv[0]} --help-frameworks."""
try:
opts, args = getopt.getopt(sys.argv[1:], "h", ["help", "help-frameworks"])

View File

@ -224,7 +224,7 @@ class GridFS:
doc = next(cursor)
return GridOut(self.__collection, file_document=doc, session=session)
except StopIteration:
raise NoFile("no version %d for filename %r" % (version, filename))
raise NoFile("no version %d for filename %r" % (version, filename)) from None
def get_last_version(
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
@ -932,7 +932,7 @@ class GridFSBucket:
grid_file = next(cursor)
return GridOut(self._collection, file_document=grid_file, session=session)
except StopIteration:
raise NoFile("no version %d for filename %r" % (revision, filename))
raise NoFile("no version %d for filename %r" % (revision, filename)) from None
@_csot.apply
def download_to_stream_by_name(

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Exceptions raised by the :mod:`gridfs` package"""
from __future__ import annotations
from pymongo.errors import PyMongoError

View File

@ -357,12 +357,14 @@ class GridIn:
except AttributeError:
# string
if not isinstance(data, (str, bytes)):
raise TypeError("can only write strings or file-like objects")
raise TypeError("can only write strings or file-like objects") from None
if isinstance(data, str):
try:
data = data.encode(self.encoding)
except AttributeError:
raise TypeError("must specify an encoding for file in order to write str")
raise TypeError(
"must specify an encoding for file in order to write str"
) from None
read = io.BytesIO(data).read
if self._buffer.tell() > 0:
@ -395,7 +397,7 @@ class GridIn:
def writeable(self) -> bool:
return True
def __enter__(self) -> "GridIn":
def __enter__(self) -> GridIn:
"""Support for the context manager protocol."""
return self
@ -671,7 +673,7 @@ class GridOut(io.IOBase):
def seekable(self) -> bool:
return True
def __iter__(self) -> "GridOut":
def __iter__(self) -> GridOut:
"""Return an iterator over all of this file's data.
The iterator will return lines (delimited by ``b'\\n'``) of
@ -708,7 +710,7 @@ class GridOut(io.IOBase):
def writable(self) -> bool:
return False
def __enter__(self) -> "GridOut":
def __enter__(self) -> GridOut:
"""Makes it possible to use :class:`GridOut` files
with the context manager protocol.
"""
@ -773,7 +775,7 @@ class _GridOutChunkIterator:
return self._chunk_size
return self._length - (self._chunk_size * (self._num_chunks - 1))
def __iter__(self) -> "_GridOutChunkIterator":
def __iter__(self) -> _GridOutChunkIterator:
return self
def _create_cursor(self) -> None:
@ -806,7 +808,7 @@ class _GridOutChunkIterator:
except StopIteration:
if self._next_chunk >= self._num_chunks:
raise
raise CorruptGridFile("no chunk #%d" % self._next_chunk)
raise CorruptGridFile("no chunk #%d" % self._next_chunk) from None
if chunk["n"] != self._next_chunk:
self.close()
@ -847,7 +849,7 @@ class GridOutIterator:
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)
def __iter__(self) -> "GridOutIterator":
def __iter__(self) -> GridOutIterator:
return self
def next(self) -> bytes:
@ -914,6 +916,6 @@ class GridOutCursor(Cursor):
def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Method does not exist for GridOutCursor")
def _clone_base(self, session: Optional[ClientSession]) -> "GridOutCursor":
def _clone_base(self, session: Optional[ClientSession]) -> GridOutCursor:
"""Creates an empty GridOutCursor for information to be copied into."""
return GridOutCursor(self.__root_collection, session=session)

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Current version of PyMongo."""
from __future__ import annotations
from typing import Tuple, Union
version_tuple: Tuple[Union[int, str], ...] = (4, 6, 0, ".dev0")

View File

@ -325,7 +325,7 @@ def _password_digest(username: str, password: str) -> str:
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5()
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
@ -334,7 +334,7 @@ def _password_digest(username: str, password: str) -> str:
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5()
md5hash = hashlib.md5() # noqa: S324
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
@ -469,7 +469,7 @@ def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc))
raise OperationFailure(str(exc)) from None
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:

View File

@ -35,7 +35,7 @@ try:
set_use_cached_credentials(True)
except ImportError:
def set_cached_credentials(creds: Optional[AwsCredential]) -> None:
def set_cached_credentials(_creds: Optional[AwsCredential]) -> None:
pass
@ -110,7 +110,9 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure(f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})")
raise OperationFailure(
f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
) from None
except Exception:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)

View File

@ -144,7 +144,7 @@ class _OIDCAuthenticator:
if principal_name:
payload["n"] = principal_name
cmd = SON(
return SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
@ -152,7 +152,6 @@ class _OIDCAuthenticator:
("autoAuthorize", 1),
]
)
return cmd
def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]:
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
@ -207,7 +206,7 @@ class _OIDCAuthenticator:
self.idp_info = server_resp
# Handle the case of changed idp info.
if not self.idp_info == prev_idp_info:
if self.idp_info != prev_idp_info:
self.access_token = None
self.refresh_token = None

View File

@ -221,7 +221,7 @@ class _Bulk:
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: dict[str, Any] = dict(
cmd: dict[str, Any] = dict( # noqa: C406
[("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]
)
if collation is not None:

View File

@ -225,7 +225,7 @@ class ChangeStream(Generic[_DocumentType]):
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
"response : {!r}".format(result)
f"response : {result!r}"
)
def _run_aggregation_cmd(
@ -264,7 +264,7 @@ class ChangeStream(Generic[_DocumentType]):
self._closed = True
self._cursor.close()
def __iter__(self) -> "ChangeStream[_DocumentType]":
def __iter__(self) -> ChangeStream[_DocumentType]:
return self
@property
@ -406,7 +406,7 @@ class ChangeStream(Generic[_DocumentType]):
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume token is missing."
)
) from None
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.

View File

@ -200,7 +200,7 @@ class SessionOptions:
def __init__(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional["TransactionOptions"] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
) -> None:
if snapshot:
@ -227,7 +227,7 @@ class SessionOptions:
return self._causal_consistency
@property
def default_transaction_options(self) -> Optional["TransactionOptions"]:
def default_transaction_options(self) -> Optional[TransactionOptions]:
"""The default TransactionOptions to use for transactions started on
this session.
@ -287,25 +287,25 @@ class TransactionOptions:
if not isinstance(read_concern, ReadConcern):
raise TypeError(
"read_concern must be an instance of "
"pymongo.read_concern.ReadConcern, not: {!r}".format(read_concern)
f"pymongo.read_concern.ReadConcern, not: {read_concern!r}"
)
if write_concern is not None:
if not isinstance(write_concern, WriteConcern):
raise TypeError(
"write_concern must be an instance of "
"pymongo.write_concern.WriteConcern, not: {!r}".format(write_concern)
f"pymongo.write_concern.WriteConcern, not: {write_concern!r}"
)
if not write_concern.acknowledged:
raise ConfigurationError(
"transactions do not support unacknowledged write concern"
": {!r}".format(write_concern)
f": {write_concern!r}"
)
if read_preference is not None:
if not isinstance(read_preference, _ServerMode):
raise TypeError(
"{!r} is not valid for read_preference. See "
f"{read_preference!r} is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options.".format(read_preference)
"options."
)
if max_commit_time_ms is not None:
if not isinstance(max_commit_time_ms, int):
@ -354,7 +354,7 @@ def _validate_session_write_concern(
else:
raise ConfigurationError(
"Explicit sessions are incompatible with "
"unacknowledged write concern: {!r}".format(write_concern)
f"unacknowledged write concern: {write_concern!r}"
)
return session
@ -535,7 +535,7 @@ class ClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def __enter__(self) -> "ClientSession":
def __enter__(self) -> ClientSession:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
@ -585,7 +585,7 @@ class ClientSession:
def with_transaction(
self,
callback: Callable[["ClientSession"], _T],
callback: Callable[[ClientSession], _T],
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
@ -838,7 +838,7 @@ class ClientSession:
"""
def func(
session: Optional[ClientSession], conn: Connection, retryable: bool
_session: Optional[ClientSession], conn: Connection, _retryable: bool
) -> dict[str, Any]:
return self._finish_transaction(conn, command_name)
@ -1002,8 +1002,7 @@ class ClientSession:
if self.in_transaction:
if read_preference != ReadPreference.PRIMARY:
raise InvalidOperation(
"read preference in a transaction must be primary, not: "
"{!r}".format(read_preference)
f"read preference in a transaction must be primary, not: {read_preference!r}"
)
if self._transaction.state == _TxnState.STARTING:

View File

@ -368,8 +368,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if name.startswith("_"):
full_name = f"{self.__name}.{name}"
raise AttributeError(
"Collection has no attribute {!r}. To access the {}"
" collection, use database['{}'].".format(name, full_name, full_name)
f"Collection has no attribute {name!r}. To access the {full_name}"
f" collection, use database['{full_name}']."
)
return self.__getitem__(name)
@ -563,7 +563,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
try:
request._add_to_bulk(blk)
except AttributeError:
raise TypeError(f"{request!r} is not a valid request")
raise TypeError(f"{request!r} is not a valid request") from None
write_concern = self._write_concern_for(session)
bulk_api_result = blk.execute(write_concern, session)
@ -1812,7 +1812,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> int:
@ -1901,7 +1901,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> int:
@ -2277,7 +2277,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[MutableMapping[str, Any]]:
@ -2348,7 +2348,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
info = {}
for index in cursor:
index["key"] = list(index["key"].items())
index = dict(index)
index = dict(index) # noqa: PLW2901
info[index.pop("name")] = index
return info
@ -3038,7 +3038,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> list:

View File

@ -75,7 +75,7 @@ class CommandCursor(Generic[_DocumentType]):
if self.__killed:
self.__end_session(True)
if "ns" in cursor_info:
if "ns" in cursor_info: # noqa: SIM401
self.__ns = cursor_info["ns"]
else:
self.__ns = collection.full_name
@ -121,7 +121,7 @@ class CommandCursor(Generic[_DocumentType]):
"""Explicitly close / kill this cursor."""
self.__die(True)
def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]":
def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
@ -340,7 +340,7 @@ class CommandCursor(Generic[_DocumentType]):
"""
return self._try_next(get_more_allowed=True)
def __enter__(self) -> "CommandCursor[_DocumentType]":
def __enter__(self) -> CommandCursor[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

View File

@ -195,7 +195,7 @@ def validate_integer(option: str, value: Any) -> int:
try:
return int(value)
except ValueError:
raise ValueError(f"The value of {option} must be an integer")
raise ValueError(f"The value of {option} must be an integer") from None
raise TypeError(f"Wrong type for {option}, value must be an integer")
@ -287,9 +287,9 @@ def validate_positive_float(option: str, value: Any) -> float:
try:
value = float(value)
except ValueError:
raise ValueError(errmsg)
raise ValueError(errmsg) from None
except TypeError:
raise TypeError(errmsg)
raise TypeError(errmsg) from None
# float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at
# one billion - this is a reasonable approximation for infinity
@ -388,10 +388,10 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int:
return _UUID_REPRESENTATIONS[value]
except KeyError:
raise ValueError(
"{} is an invalid UUID representation. "
f"{value} is an invalid UUID representation. "
"Must be one of "
"{}".format(value, tuple(_UUID_REPRESENTATIONS))
)
f"{tuple(_UUID_REPRESENTATIONS)}"
) from None
def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]:
@ -411,7 +411,7 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]
tags[unquote_plus(key)] = unquote_plus(val)
tag_sets.append(tags)
except Exception:
raise ValueError(f"{tag_set!r} not a valid value for {name}")
raise ValueError(f"{tag_set!r} not a valid value for {name}") from None
return tag_sets
@ -432,7 +432,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
if not isinstance(value, str):
if not isinstance(value, dict):
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
for key, value in value.items():
for key, value in value.items(): # noqa: B020
if isinstance(value, str):
props[key] = value
elif isinstance(value, bool):
@ -462,20 +462,20 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
except ValueError:
# Try not to leak the token.
if "AWS_SESSION_TOKEN" in opt:
opt = (
opt = ( # noqa: PLW2901
"AWS_SESSION_TOKEN:<redacted token>, did you forget "
"to percent-escape the token with quote_plus?"
)
raise ValueError(
"auth mechanism properties must be "
"key:value pairs like SERVICE_NAME:"
"mongodb, not {}.".format(opt)
)
f"mongodb, not {opt}."
) from None
if key not in _MECHANISM_PROPS:
raise ValueError(
"{} is not a supported auth "
f"{key} is not a supported auth "
"mechanism property. Must be one of "
"{}.".format(key, tuple(_MECHANISM_PROPS))
f"{tuple(_MECHANISM_PROPS)}."
)
if key == "CANONICALIZE_HOST_NAME":
props[key] = validate_boolean_or_string(key, val)
@ -499,9 +499,9 @@ def validate_document_class(
is_mapping = issubclass(value.__origin__, abc.MutableMapping)
if not is_mapping and not issubclass(value, RawBSONDocument):
raise TypeError(
"{} must be dict, bson.son.SON, "
f"{option} must be dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or a "
"subclass of collections.MutableMapping".format(option)
"subclass of collections.MutableMapping"
)
return value
@ -531,9 +531,9 @@ def validate_list_or_mapping(option: Any, value: Any) -> None:
"""Validates that 'value' is a list or a document."""
if not isinstance(value, (abc.Mapping, list)):
raise TypeError(
"{} must either be a list or an instance of dict, "
f"{option} must either be a list or an instance of dict, "
"bson.son.SON, or any other type that inherits from "
"collections.Mapping".format(option)
"collections.Mapping"
)
@ -541,9 +541,9 @@ def validate_is_mapping(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a document."""
if not isinstance(value, abc.Mapping):
raise TypeError(
"{} must be an instance of dict, bson.son.SON, or "
f"{option} must be an instance of dict, bson.son.SON, or "
"any other type that inherits from "
"collections.Mapping".format(option)
"collections.Mapping"
)
@ -551,10 +551,10 @@ def validate_is_document_type(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a MongoDB document."""
if not isinstance(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError(
"{} must be an instance of dict, bson.son.SON, "
f"{option} must be an instance of dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or "
"a type that inherits from "
"collections.MutableMapping".format(option)
"collections.MutableMapping"
)
@ -626,9 +626,9 @@ def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str:
"""Validate the Unicode decode error handler option of CodecOptions."""
if value not in _UNICODE_DECODE_ERROR_HANDLERS:
raise ValueError(
"{} is an invalid Unicode decode error handler. "
f"{value} is an invalid Unicode decode error handler. "
"Must be one of "
"{}".format(value, tuple(_UNICODE_DECODE_ERROR_HANDLERS))
f"{tuple(_UNICODE_DECODE_ERROR_HANDLERS)}"
)
return value
@ -841,28 +841,28 @@ def get_validated_options(
validated_options = _CaseInsensitiveDictionary()
def get_normed_key(x: str) -> str:
return x # noqa: E731
return x
def get_setter_key(x: str) -> str:
return options.cased_key(x) # type: ignore[attr-defined] # noqa: E731
return options.cased_key(x) # type: ignore[attr-defined]
else:
validated_options = {}
def get_normed_key(x: str) -> str:
return x.lower() # noqa: E731
return x.lower()
def get_setter_key(x: str) -> str:
return x # noqa: E731
return x
for opt, value in options.items():
normed_key = get_normed_key(opt)
try:
validator = URI_OPTIONS_VALIDATOR_MAP.get(normed_key, raise_config_error)
value = validator(opt, value)
value = validator(opt, value) # noqa: PLW2901
except (ValueError, TypeError, ConfigurationError) as exc:
if warn:
warnings.warn(str(exc))
warnings.warn(str(exc), stacklevel=2)
else:
raise
else:
@ -902,9 +902,9 @@ class BaseObject:
if not isinstance(read_preference, _ServerMode):
raise TypeError(
"{!r} is not valid for read_preference. See "
f"{read_preference!r} is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options.".format(read_preference)
"options."
)
self.__read_preference = read_preference
@ -1004,7 +1004,7 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
return NotImplemented
if len(self) != len(other):
return False
for key in other:
for key in other: # noqa: SIM110
if self[key] != other[key]:
return False

View File

@ -58,24 +58,27 @@ def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[s
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn(f"Unsupported compressor: {compressor}")
warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2)
elif compressor == "snappy" and not _HAVE_SNAPPY:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support."
"You must install the python-snappy module for snappy support.",
stacklevel=2,
)
elif compressor == "zlib" and not _HAVE_ZLIB:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available."
"The zlib module is not available.",
stacklevel=2,
)
elif compressor == "zstd" and not _HAVE_ZSTD:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the zstandard module for zstandard support."
"You must install the zstandard module for zstandard support.",
stacklevel=2,
)
return compressors
@ -84,7 +87,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int:
try:
level = int(value)
except Exception:
raise TypeError(f"{option} must be an integer, not {value!r}.")
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
if level < -1 or level > 9:
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
return level

View File

@ -335,7 +335,7 @@ class Cursor(Generic[_DocumentType]):
def __del__(self) -> None:
self.__die()
def rewind(self) -> "Cursor[_DocumentType]":
def rewind(self) -> Cursor[_DocumentType]:
"""Rewind this cursor to its unevaluated state.
Reset this cursor if it has been partially or completely evaluated.
@ -353,7 +353,7 @@ class Cursor(Generic[_DocumentType]):
return self
def clone(self) -> "Cursor[_DocumentType]":
def clone(self) -> Cursor[_DocumentType]:
"""Get a clone of this cursor.
Returns a new Cursor instance with options matching those that have
@ -505,7 +505,7 @@ class Cursor(Generic[_DocumentType]):
if self.__retrieved or self.__id is not None:
raise InvalidOperation("cannot set options after executing query")
def add_option(self, mask: int) -> "Cursor[_DocumentType]":
def add_option(self, mask: int) -> Cursor[_DocumentType]:
"""Set arbitrary query flags using a bitmask.
To set the tailable flag:
@ -525,7 +525,7 @@ class Cursor(Generic[_DocumentType]):
self.__query_flags |= mask
return self
def remove_option(self, mask: int) -> "Cursor[_DocumentType]":
def remove_option(self, mask: int) -> Cursor[_DocumentType]:
"""Unset arbitrary query flags using a bitmask.
To unset the tailable flag:
@ -541,7 +541,7 @@ class Cursor(Generic[_DocumentType]):
self.__query_flags &= ~mask
return self
def allow_disk_use(self, allow_disk_use: bool) -> "Cursor[_DocumentType]":
def allow_disk_use(self, allow_disk_use: bool) -> Cursor[_DocumentType]:
"""Specifies whether MongoDB can use temporary disk files while
processing a blocking sort operation.
@ -563,7 +563,7 @@ class Cursor(Generic[_DocumentType]):
self.__allow_disk_use = allow_disk_use
return self
def limit(self, limit: int) -> "Cursor[_DocumentType]":
def limit(self, limit: int) -> Cursor[_DocumentType]:
"""Limits the number of results to be returned by this cursor.
Raises :exc:`TypeError` if `limit` is not an integer. Raises
@ -586,7 +586,7 @@ class Cursor(Generic[_DocumentType]):
self.__limit = limit
return self
def batch_size(self, batch_size: int) -> "Cursor[_DocumentType]":
def batch_size(self, batch_size: int) -> Cursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
@ -614,7 +614,7 @@ class Cursor(Generic[_DocumentType]):
self.__batch_size = batch_size
return self
def skip(self, skip: int) -> "Cursor[_DocumentType]":
def skip(self, skip: int) -> Cursor[_DocumentType]:
"""Skips the first `skip` results of this cursor.
Raises :exc:`TypeError` if `skip` is not an integer. Raises
@ -635,7 +635,7 @@ class Cursor(Generic[_DocumentType]):
self.__skip = skip
return self
def max_time_ms(self, max_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
def max_time_ms(self, max_time_ms: Optional[int]) -> Cursor[_DocumentType]:
"""Specifies a time limit for a query operation. If the specified
time is exceeded, the operation will be aborted and
:exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms`
@ -655,7 +655,7 @@ class Cursor(Generic[_DocumentType]):
self.__max_time_ms = max_time_ms
return self
def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> Cursor[_DocumentType]:
"""Specifies a time limit for a getMore operation on a
:attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other
types of cursor max_await_time_ms is ignored.
@ -687,7 +687,7 @@ class Cursor(Generic[_DocumentType]):
...
@overload
def __getitem__(self, index: slice) -> "Cursor[_DocumentType]":
def __getitem__(self, index: slice) -> Cursor[_DocumentType]:
...
def __getitem__(self, index: Union[int, slice]) -> Union[_DocumentType, Cursor[_DocumentType]]:
@ -770,7 +770,7 @@ class Cursor(Generic[_DocumentType]):
raise IndexError("no such item for Cursor instance")
raise TypeError("index %r cannot be applied to Cursor instances" % index)
def max_scan(self, max_scan: Optional[int]) -> "Cursor[_DocumentType]":
def max_scan(self, max_scan: Optional[int]) -> Cursor[_DocumentType]:
"""**DEPRECATED** - Limit the number of documents to scan when
performing the query.
@ -790,7 +790,7 @@ class Cursor(Generic[_DocumentType]):
self.__max_scan = max_scan
return self
def max(self, spec: _Sort) -> "Cursor[_DocumentType]":
def max(self, spec: _Sort) -> Cursor[_DocumentType]:
"""Adds ``max`` operator that specifies upper bound for specific index.
When using ``max``, :meth:`~hint` should also be configured to ensure
@ -813,7 +813,7 @@ class Cursor(Generic[_DocumentType]):
self.__max = SON(spec)
return self
def min(self, spec: _Sort) -> "Cursor[_DocumentType]":
def min(self, spec: _Sort) -> Cursor[_DocumentType]:
"""Adds ``min`` operator that specifies lower bound for specific index.
When using ``min``, :meth:`~hint` should also be configured to ensure
@ -838,7 +838,7 @@ class Cursor(Generic[_DocumentType]):
def sort(
self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> "Cursor[_DocumentType]":
) -> Cursor[_DocumentType]:
"""Sorts this cursor's results.
Pass a field name and a direction, either
@ -944,7 +944,7 @@ class Cursor(Generic[_DocumentType]):
else:
self.__hint = helpers._index_document(index)
def hint(self, index: Optional[_Hint]) -> "Cursor[_DocumentType]":
def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]:
"""Adds a 'hint', telling Mongo the proper index to use for the query.
Judicious use of hints can greatly improve query
@ -969,7 +969,7 @@ class Cursor(Generic[_DocumentType]):
self.__set_hint(index)
return self
def comment(self, comment: Any) -> "Cursor[_DocumentType]":
def comment(self, comment: Any) -> Cursor[_DocumentType]:
"""Adds a 'comment' to the cursor.
http://mongodb.com/docs/manual/reference/operator/comment/
@ -984,7 +984,7 @@ class Cursor(Generic[_DocumentType]):
self.__comment = comment
return self
def where(self, code: Union[str, Code]) -> "Cursor[_DocumentType]":
def where(self, code: Union[str, Code]) -> Cursor[_DocumentType]:
"""Adds a `$where`_ clause to this query.
The `code` argument must be an instance of :class:`str` or
@ -1027,7 +1027,7 @@ class Cursor(Generic[_DocumentType]):
self.__spec = spec
return self
def collation(self, collation: Optional[_CollationIn]) -> "Cursor[_DocumentType]":
def collation(self, collation: Optional[_CollationIn]) -> Cursor[_DocumentType]:
"""Adds a :class:`~pymongo.collation.Collation` to this query.
Raises :exc:`TypeError` if `collation` is not an instance of
@ -1253,7 +1253,7 @@ class Cursor(Generic[_DocumentType]):
return self.__session
return None
def __iter__(self) -> "Cursor[_DocumentType]":
def __iter__(self) -> Cursor[_DocumentType]:
return self
def next(self) -> _DocumentType:
@ -1267,13 +1267,13 @@ class Cursor(Generic[_DocumentType]):
__next__ = next
def __enter__(self) -> "Cursor[_DocumentType]":
def __enter__(self) -> Cursor[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def __copy__(self) -> "Cursor[_DocumentType]":
def __copy__(self) -> Cursor[_DocumentType]:
"""Support function for `copy.copy()`.
.. versionadded:: 2.4
@ -1320,15 +1320,15 @@ class Cursor(Generic[_DocumentType]):
for key, value in iterator:
if isinstance(value, (dict, list)) and not isinstance(value, SON):
value = self._deepcopy(value, memo)
value = self._deepcopy(value, memo) # noqa: PLW2901
elif not isinstance(value, RE_TYPE):
value = copy.deepcopy(value, memo)
value = copy.deepcopy(value, memo) # noqa: PLW2901
if is_list:
y.append(value) # type: ignore[union-attr]
else:
if not isinstance(key, RE_TYPE):
key = copy.deepcopy(key, memo)
key = copy.deepcopy(key, memo) # noqa: PLW2901
y[key] = value
return y

View File

@ -63,7 +63,7 @@ if sys.platform == "win32":
try:
with open(os.devnull, "r+b") as devnull:
popen = subprocess.Popen(
args,
args, # noqa: S603
creationflags=_DETACHED_PROCESS,
stdin=devnull,
stderr=devnull,
@ -94,7 +94,11 @@ else:
try:
with open(os.devnull, "r+b") as devnull:
return subprocess.Popen(
args, close_fds=True, stdin=devnull, stderr=devnull, stdout=devnull
args, # noqa: S603
close_fds=True,
stdin=devnull,
stderr=devnull,
stdout=devnull,
)
except FileNotFoundError as exc:
warnings.warn(
@ -108,7 +112,7 @@ else:
"""Spawn a daemon process using a double subprocess.Popen."""
spawner_args = [sys.executable, _THIS_FILE]
spawner_args.extend(args)
temp_proc = subprocess.Popen(spawner_args, close_fds=True)
temp_proc = subprocess.Popen(spawner_args, close_fds=True) # noqa: S603
# Reap the intermediate child process to avoid creating zombie
# processes.
_popen_wait(temp_proc, _WAIT_TIMEOUT)

View File

@ -74,7 +74,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def __init__(
self,
client: "MongoClient[_DocumentType]",
client: MongoClient[_DocumentType],
name: str,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
@ -144,7 +144,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
self._timeout = client.options.timeout
@property
def client(self) -> "MongoClient[_DocumentType]":
def client(self) -> MongoClient[_DocumentType]:
"""The client instance for this :class:`Database`."""
return self.__client
@ -224,12 +224,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
"""
if name.startswith("_"):
raise AttributeError(
"Database has no attribute {!r}. To access the {}"
" collection, use database[{!r}].".format(name, name, name)
f"Database has no attribute {name!r}. To access the {name}"
f" collection, use database[{name!r}]."
)
return self.__getitem__(name)
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
def __getitem__(self, name: str) -> Collection[_DocumentType]:
"""Get a collection of this database by name.
Raises InvalidName if an invalid collection name is used.
@ -791,7 +791,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: "Optional[bson.codec_options.CodecOptions[_CodecDocumentType]]" = None,
codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -1012,7 +1012,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> dict[str, Any]:
@ -1090,7 +1090,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[MutableMapping[str, Any]]:
@ -1377,7 +1377,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if dbref.database is not None and dbref.database != self.__name:
raise ValueError(
"trying to dereference a DBRef that points to "
"another database ({!r} not {!r})".format(dbref.database, self.__name)
f"another database ({dbref.database!r} not {self.__name!r})"
)
return self[dbref.collection].find_one(
{"_id": dbref.id}, session=session, comment=comment, **kwargs

View File

@ -31,13 +31,12 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])):
def __new__(
cls, name: str, version: Optional[str] = None, platform: Optional[str] = None
) -> "DriverInfo":
) -> DriverInfo:
self = super().__new__(cls, name, version, platform)
for key, value in self._asdict().items():
if value is not None and not isinstance(value, str):
raise TypeError(
"Wrong type for DriverInfo {} option, value "
"must be an instance of str".format(key)
f"Wrong type for DriverInfo {key} option, value must be an instance of str"
)
return self

View File

@ -34,7 +34,7 @@ from typing import (
try:
from pymongocrypt.auto_encrypter import AutoEncrypter
from pymongocrypt.errors import MongoCryptError # noqa: F401
from pymongocrypt.errors import MongoCryptError
from pymongocrypt.explicit_encrypter import ExplicitEncrypter
from pymongocrypt.mongocrypt import MongoCryptOptions
from pymongocrypt.state_machine import MongoCryptCallback
@ -102,7 +102,7 @@ def _wrap_encryption_errors() -> Iterator[None]:
# we should propagate them unchanged.
raise
except Exception as exc:
raise EncryptionError(exc)
raise EncryptionError(exc) from None
class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
@ -177,7 +177,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
raise socket.timeout("timed out") from None
finally:
conn.close()
except (PyMongoError, MongoCryptError):
@ -414,8 +414,7 @@ class _Encrypter:
with _wrap_encryption_errors():
encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
return encrypt_cmd
return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
def decrypt(self, response: bytes) -> Optional[bytes]:
"""Decrypt a MongoDB command response.

View File

@ -119,7 +119,7 @@ def _index_list(
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
item = (item, ASCENDING)
item = (item, ASCENDING) # noqa: PLW2901
values.append(item)
return values
@ -146,7 +146,7 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]:
else:
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING)
item = (item, ASCENDING) # noqa: PLW2901
key, value = item
_validate_index_key_pair(key, value)
index[key] = value

View File

@ -67,7 +67,7 @@ def _with_primary(max_staleness: int, selection: Selection) -> Selection:
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert s.last_write_date and primary.last_write_date
assert s.last_write_date and primary.last_write_date # noqa: PT018
staleness = (
(s.last_update_time - s.last_write_date)
- (primary.last_update_time - primary.last_write_date)
@ -95,7 +95,7 @@ def _no_primary(max_staleness: int, selection: Selection) -> Selection:
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert smax.last_write_date and s.last_write_date
assert smax.last_write_date and s.last_write_date # noqa: PT018
staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency
if staleness <= max_staleness:

View File

@ -106,14 +106,14 @@ _OP_MAP = {
}
_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"}
_UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions(
_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
unicode_decode_error_handler="replace"
)
def _randint() -> int:
"""Generate a pseudo random 32 bit integer."""
return random.randint(MIN_INT32, MAX_INT32)
return random.randint(MIN_INT32, MAX_INT32) # noqa: S311
def _maybe_add_read_preference(
@ -731,7 +731,7 @@ def _op_msg_uncompressed(
if _use_c:
_op_msg_uncompressed = _cmessage._op_msg # noqa: F811
_op_msg_uncompressed = _cmessage._op_msg
def _op_msg(
@ -833,7 +833,7 @@ def _query_uncompressed(
if _use_c:
_query_uncompressed = _cmessage._query_message # noqa: F811
_query_uncompressed = _cmessage._query_message
def _query(
@ -889,7 +889,7 @@ def _get_more_uncompressed(
if _use_c:
_get_more_uncompressed = _cmessage._get_more_message # noqa: F811
_get_more_uncompressed = _cmessage._get_more_message
def _get_more(
@ -942,7 +942,7 @@ class _BulkWriteContext:
self.field = _FIELD_MAP[self.name]
self.start_time = datetime.datetime.now() if self.publish else None
self.session = session
self.compress = True if conn.compression_context else False
self.compress = bool(conn.compression_context)
self.op_type = op_type
self.codec = codec
@ -1222,7 +1222,7 @@ def _batched_op_msg_impl(
try:
buf.write(_OP_MSG_MAP[operation])
except KeyError:
raise InvalidOperation("Unknown command")
raise InvalidOperation("Unknown command") from None
to_send = []
idx = 0
@ -1278,7 +1278,7 @@ def _encode_batched_op_msg(
if _use_c:
_encode_batched_op_msg = _cmessage._encode_batched_op_msg # noqa: F811
_encode_batched_op_msg = _cmessage._encode_batched_op_msg
def _batched_op_msg_compressed(
@ -1328,7 +1328,7 @@ def _batched_op_msg(
if _use_c:
_batched_op_msg = _cmessage._batched_op_msg # noqa: F811
_batched_op_msg = _cmessage._batched_op_msg
def _do_batched_op_msg(
@ -1371,7 +1371,7 @@ def _encode_batched_write_command(
if _use_c:
_encode_batched_write_command = _cmessage._encode_batched_write_command # noqa: F811
_encode_batched_write_command = _cmessage._encode_batched_write_command
def _batched_write_command_impl(
@ -1410,7 +1410,7 @@ def _batched_write_command_impl(
try:
buf.write(_OP_MAP[operation])
except KeyError:
raise InvalidOperation("Unknown command")
raise InvalidOperation("Unknown command") from None
# Where to write list document length
list_start = buf.tell() - 4
@ -1586,7 +1586,7 @@ class _OpMsg:
def raw_response(
self,
cursor_id: Optional[int] = None,
user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006
user_fields: Optional[Mapping[str, Any]] = {},
) -> list[Mapping[str, Any]]:
"""
cursor_id is ignored

View File

@ -1297,7 +1297,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# We're running a getMore or this session is pinned to a mongos.
server = topology.select_server_by_address(address)
if not server:
raise AutoReconnect("server %s:%s no longer available" % address)
raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031
else:
server = topology.select_server(server_selector)
return server
@ -1380,7 +1380,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
)
def _cmd(
session: Optional[ClientSession],
_session: Optional[ClientSession],
server: Server,
conn: Connection,
read_preference: _ServerMode,
@ -1579,8 +1579,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
if name.startswith("_"):
raise AttributeError(
"MongoClient has no attribute {!r}. To access the {}"
" database, use client[{!r}].".format(name, name, name)
f"MongoClient has no attribute {name!r}. To access the {name}"
f" database, use client[{name!r}]."
)
return self.__getitem__(name)
@ -2132,7 +2132,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
write_concern=DEFAULT_WRITE_CONCERN,
)
def __enter__(self) -> "MongoClient[_DocumentType]":
def __enter__(self) -> MongoClient[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
@ -2153,8 +2153,7 @@ def _retryable_error_doc(exc: PyMongoError) -> Optional[Mapping[str, Any]]:
# Check the last writeConcernError to determine if this
# BulkWriteError is retryable.
wces = exc.details["writeConcernErrors"]
wce = wces[-1] if wces else None
return wce
return wces[-1] if wces else None
if isinstance(exc, (NotPrimaryError, OperationFailure)):
return cast(Mapping[str, Any], exc.details)
return None

View File

@ -226,7 +226,7 @@ class CommandListener(_EventListener):
and `CommandFailedEvent`.
"""
def started(self, event: "CommandStartedEvent") -> None:
def started(self, event: CommandStartedEvent) -> None:
"""Abstract method to handle a `CommandStartedEvent`.
:Parameters:
@ -234,7 +234,7 @@ class CommandListener(_EventListener):
"""
raise NotImplementedError
def succeeded(self, event: "CommandSucceededEvent") -> None:
def succeeded(self, event: CommandSucceededEvent) -> None:
"""Abstract method to handle a `CommandSucceededEvent`.
:Parameters:
@ -242,7 +242,7 @@ class CommandListener(_EventListener):
"""
raise NotImplementedError
def failed(self, event: "CommandFailedEvent") -> None:
def failed(self, event: CommandFailedEvent) -> None:
"""Abstract method to handle a `CommandFailedEvent`.
:Parameters:
@ -267,7 +267,7 @@ class ConnectionPoolListener(_EventListener):
.. versionadded:: 3.9
"""
def pool_created(self, event: "PoolCreatedEvent") -> None:
def pool_created(self, event: PoolCreatedEvent) -> None:
"""Abstract method to handle a :class:`PoolCreatedEvent`.
Emitted when a connection Pool is created.
@ -277,7 +277,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_ready(self, event: "PoolReadyEvent") -> None:
def pool_ready(self, event: PoolReadyEvent) -> None:
"""Abstract method to handle a :class:`PoolReadyEvent`.
Emitted when a connection Pool is marked ready.
@ -289,7 +289,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_cleared(self, event: "PoolClearedEvent") -> None:
def pool_cleared(self, event: PoolClearedEvent) -> None:
"""Abstract method to handle a `PoolClearedEvent`.
Emitted when a connection Pool is cleared.
@ -299,7 +299,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_closed(self, event: "PoolClosedEvent") -> None:
def pool_closed(self, event: PoolClosedEvent) -> None:
"""Abstract method to handle a `PoolClosedEvent`.
Emitted when a connection Pool is closed.
@ -309,7 +309,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_created(self, event: "ConnectionCreatedEvent") -> None:
def connection_created(self, event: ConnectionCreatedEvent) -> None:
"""Abstract method to handle a :class:`ConnectionCreatedEvent`.
Emitted when a connection Pool creates a Connection object.
@ -319,7 +319,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_ready(self, event: "ConnectionReadyEvent") -> None:
def connection_ready(self, event: ConnectionReadyEvent) -> None:
"""Abstract method to handle a :class:`ConnectionReadyEvent`.
Emitted when a connection has finished its setup, and is now ready to
@ -330,7 +330,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_closed(self, event: "ConnectionClosedEvent") -> None:
def connection_closed(self, event: ConnectionClosedEvent) -> None:
"""Abstract method to handle a :class:`ConnectionClosedEvent`.
Emitted when a connection Pool closes a connection.
@ -340,7 +340,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_check_out_started(self, event: "ConnectionCheckOutStartedEvent") -> None:
def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None:
"""Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`.
Emitted when the driver starts attempting to check out a connection.
@ -350,7 +350,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> None:
def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None:
"""Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`.
Emitted when the driver's attempt to check out a connection fails.
@ -360,7 +360,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None:
def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None:
"""Abstract method to handle a :class:`ConnectionCheckedOutEvent`.
Emitted when the driver successfully checks out a connection.
@ -370,7 +370,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None:
def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None:
"""Abstract method to handle a :class:`ConnectionCheckedInEvent`.
Emitted when the driver checks in a connection back to the connection
@ -391,7 +391,7 @@ class ServerHeartbeatListener(_EventListener):
.. versionadded:: 3.3
"""
def started(self, event: "ServerHeartbeatStartedEvent") -> None:
def started(self, event: ServerHeartbeatStartedEvent) -> None:
"""Abstract method to handle a `ServerHeartbeatStartedEvent`.
:Parameters:
@ -399,7 +399,7 @@ class ServerHeartbeatListener(_EventListener):
"""
raise NotImplementedError
def succeeded(self, event: "ServerHeartbeatSucceededEvent") -> None:
def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None:
"""Abstract method to handle a `ServerHeartbeatSucceededEvent`.
:Parameters:
@ -407,7 +407,7 @@ class ServerHeartbeatListener(_EventListener):
"""
raise NotImplementedError
def failed(self, event: "ServerHeartbeatFailedEvent") -> None:
def failed(self, event: ServerHeartbeatFailedEvent) -> None:
"""Abstract method to handle a `ServerHeartbeatFailedEvent`.
:Parameters:
@ -424,7 +424,7 @@ class TopologyListener(_EventListener):
.. versionadded:: 3.3
"""
def opened(self, event: "TopologyOpenedEvent") -> None:
def opened(self, event: TopologyOpenedEvent) -> None:
"""Abstract method to handle a `TopologyOpenedEvent`.
:Parameters:
@ -432,7 +432,7 @@ class TopologyListener(_EventListener):
"""
raise NotImplementedError
def description_changed(self, event: "TopologyDescriptionChangedEvent") -> None:
def description_changed(self, event: TopologyDescriptionChangedEvent) -> None:
"""Abstract method to handle a `TopologyDescriptionChangedEvent`.
:Parameters:
@ -440,7 +440,7 @@ class TopologyListener(_EventListener):
"""
raise NotImplementedError
def closed(self, event: "TopologyClosedEvent") -> None:
def closed(self, event: TopologyClosedEvent) -> None:
"""Abstract method to handle a `TopologyClosedEvent`.
:Parameters:
@ -457,7 +457,7 @@ class ServerListener(_EventListener):
.. versionadded:: 3.3
"""
def opened(self, event: "ServerOpeningEvent") -> None:
def opened(self, event: ServerOpeningEvent) -> None:
"""Abstract method to handle a `ServerOpeningEvent`.
:Parameters:
@ -465,7 +465,7 @@ class ServerListener(_EventListener):
"""
raise NotImplementedError
def description_changed(self, event: "ServerDescriptionChangedEvent") -> None:
def description_changed(self, event: ServerDescriptionChangedEvent) -> None:
"""Abstract method to handle a `ServerDescriptionChangedEvent`.
:Parameters:
@ -473,7 +473,7 @@ class ServerListener(_EventListener):
"""
raise NotImplementedError
def closed(self, event: "ServerClosedEvent") -> None:
def closed(self, event: ServerClosedEvent) -> None:
"""Abstract method to handle a `ServerClosedEvent`.
:Parameters:
@ -496,10 +496,10 @@ def _validate_event_listeners(
for listener in listeners:
if not isinstance(listener, _EventListener):
raise TypeError(
"Listeners for {} must be either a "
f"Listeners for {option} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener.".format(option)
"ConnectionPoolListener."
)
return listeners
@ -514,10 +514,10 @@ def register(listener: _EventListener) -> None:
"""
if not isinstance(listener, _EventListener):
raise TypeError(
"Listeners for {} must be either a "
f"Listeners for {listener} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener.".format(listener)
"ConnectionPoolListener."
)
if isinstance(listener, CommandListener):
_LISTENERS.command_listeners.append(listener)
@ -1147,11 +1147,7 @@ class _ServerEvent:
return self.__topology_id
def __repr__(self) -> str:
return "<{} {} topology_id: {}>".format(
self.__class__.__name__,
self.server_address,
self.topology_id,
)
return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>"
class ServerDescriptionChangedEvent(_ServerEvent):
@ -1216,7 +1212,7 @@ class ServerClosedEvent(_ServerEvent):
class TopologyEvent:
"""Base class for topology description events."""
__slots__ = "__topology_id"
__slots__ = ("__topology_id",)
def __init__(self, topology_id: ObjectId) -> None:
self.__topology_id = topology_id

View File

@ -265,8 +265,8 @@ def receive_message(
)
if length > max_message_size:
raise ProtocolError(
"Message length ({!r}) is larger than server max "
"message size ({!r})".format(length, max_message_size)
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
@ -279,7 +279,9 @@ def receive_message(
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}")
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
@ -337,8 +339,8 @@ def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[fl
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
except OSError as exc: # noqa: B014
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise

View File

@ -180,7 +180,7 @@ def _public_key_hash(cert: Certificate) -> bytes:
pbytes = public_key.public_bytes(_Encoding.X962, _PublicFormat.UncompressedPoint)
else:
pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.SubjectPublicKeyInfo)
digest = _Hash(_SHA1(), backend=_default_backend())
digest = _Hash(_SHA1(), backend=_default_backend()) # noqa: S303
digest.update(pbytes)
return digest.finalize()
@ -262,7 +262,7 @@ def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> i
def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest:
# https://cryptography.io/en/latest/x509/ocsp/#creating-requests
builder = _OCSPRequestBuilder()
builder = builder.add_certificate(cert, issuer, _SHA1())
builder = builder.add_certificate(cert, issuer, _SHA1()) # noqa: S303
return builder.build()

View File

@ -547,7 +547,7 @@ class IndexModel:
class SearchIndexModel:
"""Represents a search index to create."""
__slots__ = "__document"
__slots__ = ("__document",)
def __init__(self, definition: Mapping[str, Any], name: Optional[str] = None) -> None:
"""Create a Search Index instance.

View File

@ -116,7 +116,7 @@ except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ...), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd: int) -> None:
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
"""Dummy function for platforms that don't provide fcntl."""
@ -1076,7 +1076,7 @@ class Connection:
# shutdown.
try:
self.conn.close()
except Exception:
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
@ -1250,7 +1250,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc: # noqa: B014
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
@ -1811,7 +1811,7 @@ class Pool:
return True
if self._check_interval_seconds is not None and (
0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
):
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)
@ -1847,7 +1847,7 @@ class Pool:
)
raise WaitQueueTimeoutError(
"Timed out while checking out a connection from connection pool. "
"maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout)
f"maxPoolSize: {self.opts.max_pool_size}, timeout: {timeout}"
)
def __del__(self) -> None:

View File

@ -86,7 +86,7 @@ def _is_ip_address(address: Any) -> bool:
try:
_ip_address(address)
return True
except (ValueError, UnicodeError): # noqa: B014
except (ValueError, UnicodeError):
return False
@ -122,8 +122,8 @@ class _sslConn(_SSL.Connection):
# Check for closed socket.
if self.fileno() == -1:
if timeout and _time.monotonic() - start > timeout:
raise _socket.timeout("timed out")
raise SSLError("Underlying socket has been closed")
raise _socket.timeout("timed out") from None
raise SSLError("Underlying socket has been closed") from None
if isinstance(exc, _SSL.WantReadError):
want_read = True
want_write = False
@ -135,7 +135,7 @@ class _sslConn(_SSL.Connection):
want_write = True
self.socket_checker.select(self, want_read, want_write, timeout)
if timeout and _time.monotonic() - start > timeout:
raise _socket.timeout("timed out")
raise _socket.timeout("timed out") from None
continue
def do_handshake(self, *args: Any, **kwargs: Any) -> None:
@ -169,7 +169,7 @@ class _sslConn(_SSL.Connection):
# XXX: It's not clear if this can actually happen. PyOpenSSL
# doesn't appear to have any interrupt handling, nor any interrupt
# errors for OpenSSL connections.
except OSError as exc: # noqa: B014
except OSError as exc:
if _errno_from_exception(exc) == _EINTR:
continue
raise
@ -226,10 +226,10 @@ class SSLContext:
"""Setter for verify_mode."""
def _cb(
connobj: _SSL.Connection,
x509obj: _crypto.X509,
errnum: int,
errdepth: int,
_connobj: _SSL.Connection,
_x509obj: _crypto.X509,
_errnum: int,
_errdepth: int,
retcode: int,
) -> bool:
# It seems we don't need to do anything here. Twisted doesn't,
@ -295,7 +295,7 @@ class SSLContext:
# Password callback MUST be set first or it will be ignored.
if password:
def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes:
def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: bytes) -> bytes:
# XXX:We could check the password length against what OpenSSL
# tells us is the max, but we can't raise an exception, so...
# warn?
@ -410,5 +410,5 @@ class SSLContext:
else:
_verify_hostname(ssl_conn, server_hostname)
except (_SICertificateError, _SIVerificationError) as exc:
raise _CertificateError(str(exc))
raise _CertificateError(str(exc)) from None
return ssl_conn

View File

@ -64,9 +64,9 @@ def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
"Tag set {!r} invalid, must be an instance of dict, "
f"Tag set {tags!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping".format(tags)
"collection.Mapping"
)
return list(tag_sets)

View File

@ -32,10 +32,10 @@ class _WriteResult:
"""Raise an exception on property access if unacknowledged."""
if not self.__acknowledged:
raise InvalidOperation(
"A value for {} is not available when "
f"A value for {property_name} is not available when "
"the write is unacknowledged. Check the "
"acknowledged attribute to avoid this "
"error.".format(property_name)
"error."
)
@property

View File

@ -22,7 +22,9 @@ try:
except ImportError:
HAVE_STRINGPREP = False
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> Any:
def saslprep(
data: Any, prohibit_unassigned_code_points: Optional[bool] = True # noqa: ARG001
) -> Any:
"""SASLprep dummy"""
if isinstance(data, str):
raise TypeError(

View File

@ -122,12 +122,12 @@ class ServerApi:
if strict is not None and not isinstance(strict, bool):
raise TypeError(
"Wrong type for ServerApi strict, value must be an instance "
"of bool, not {}".format(type(strict))
f"of bool, not {type(strict)}"
)
if deprecation_errors is not None and not isinstance(deprecation_errors, bool):
raise TypeError(
"Wrong type for ServerApi deprecation_errors, value must be "
"an instance of bool, not {}".format(type(deprecation_errors))
f"an instance of bool, not {type(deprecation_errors)}"
)
self._version = version
self._strict = strict

View File

@ -257,7 +257,7 @@ class ServerDescription:
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error: Optional[Exception] = None) -> "ServerDescription":
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown

View File

@ -75,7 +75,7 @@ class _SrvResolver:
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
@ -87,7 +87,7 @@ class _SrvResolver:
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc))
raise ConfigurationError(str(exc)) from None
if len(results) > 1:
raise ConfigurationError("Only one TXT record is supported")
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")
@ -102,7 +102,7 @@ class _SrvResolver:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc))
raise ConfigurationError(str(exc)) from None
return results
def _get_srv_response_and_hosts(
@ -120,7 +120,7 @@ class _SrvResolver:
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:

View File

@ -35,7 +35,7 @@ if HAVE_SSL:
# CPython ssl module constants to configure certificate verification
# at a high level. This is legacy behavior, but requires us to
# import the ssl module even if we're only using it for this purpose.
import ssl as _stdlibssl # noqa
import ssl as _stdlibssl # noqa: F401
from ssl import CERT_NONE, CERT_REQUIRED
HAS_SNI = _ssl.HAS_SNI
@ -74,12 +74,14 @@ if HAVE_SSL:
try:
ctx.load_cert_chain(certfile, None, passphrase)
except _ssl.SSLError as exc:
raise ConfigurationError(f"Private key doesn't match certificate: {exc}")
raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None
if crlfile is not None:
if _ssl.IS_PYOPENSSL:
raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL")
# Match the server's behavior.
setattr(ctx, "verify_flags", getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0)) # noqa
ctx.verify_flags = getattr( # type:ignore[attr-defined]
_ssl, "VERIFY_CRL_CHECK_LEAF", 0
)
ctx.load_verify_locations(crlfile)
if ca_certs is not None:
ctx.load_verify_locations(ca_certs)

View File

@ -186,7 +186,8 @@ class Topology:
"MongoClient opened before fork. May not be entirely fork-safe, "
"proceed with caution. See PyMongo's documentation for details: "
"https://pymongo.readthedocs.io/en/stable/faq.html#"
"is-pymongo-fork-safe"
"is-pymongo-fork-safe",
stacklevel=2,
)
with self._lock:
# Close servers and clear the pools.

View File

@ -168,12 +168,12 @@ class TopologyDescription:
def has_server(self, address: _Address) -> bool:
return address in self._server_descriptions
def reset_server(self, address: _Address) -> "TopologyDescription":
def reset_server(self, address: _Address) -> TopologyDescription:
"""A copy of this description, with one server marked Unknown."""
unknown_sd = self._server_descriptions[address].to_unknown()
return updated_topology_description(self, unknown_sd)
def reset(self) -> "TopologyDescription":
def reset(self) -> TopologyDescription:
"""A copy of this description, with all servers marked Unknown."""
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
@ -381,7 +381,7 @@ _SERVER_TYPE_TO_TOPOLOGY_TYPE = {
def updated_topology_description(
topology_description: TopologyDescription, server_description: ServerDescription
) -> "TopologyDescription":
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
:Parameters:
@ -672,5 +672,5 @@ def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int:
for s in sds.values():
if s.server_type == SERVER_TYPE.RSPrimary:
return TOPOLOGY_TYPE.ReplicaSetWithPrimary
else:
else: # noqa: PLW0120
return TOPOLOGY_TYPE.ReplicaSetNoPrimary

View File

@ -178,7 +178,7 @@ def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionar
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn(f"Duplicate URI option '{key}'.")
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
if key.lower() == "authmechanismproperties":
val = value
else:
@ -350,7 +350,7 @@ def split_options(
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs.")
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
options = _handle_security_options(options)
@ -598,14 +598,14 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict
if not isinstance(kms_tls_options, dict):
raise TypeError("kms_tls_options must be a dict")
contexts = {}
for provider, opts in kms_tls_options.items():
if not isinstance(opts, dict):
for provider, options in kms_tls_options.items():
if not isinstance(options, dict):
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
opts.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(opts)
options.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(options)
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = validate_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
@ -628,7 +628,7 @@ if __name__ == "__main__":
import pprint
try:
pprint.pprint(parse_uri(sys.argv[1]))
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
except InvalidURI as exc:
print(exc)
print(exc) # noqa: T201
sys.exit(0)

View File

@ -95,3 +95,71 @@ include = ["bson","gridfs", "pymongo"]
bson=["py.typed", "*.pyi"]
pymongo=["py.typed", "*.pyi"]
gridfs=["py.typed", "*.pyi"]
[tool.ruff]
target-version = "py37"
line-length = 100
select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"S", # flake8-bandit
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
]
extend-ignore = [
"PLR", # Design related pylint codes
"E501", # Line too long
"PT004", # Use underscore for non-returning fixture (use usefixture instead)
"UP007", # Use `X | Y` for type annotation
"EM101", # Exception must not use a string literal, assign to variable first
"EM102", # Exception must not use an f-string literal, assign to variable first
"G004", # Logging statement uses f-string"
"UP006", # Use `type` instead of `Type` for type annotation"
"RET505", # Unnecessary `elif` after `return` statement"
"RET506", # Unnecessary `elif` after `raise` statement
"SIM108", # Use ternary operator"
"PTH123", # `open()` should be replaced by `Path.open()`"
"SIM102", # Use a single `if` statement instead of nested `if` statements
"SIM105", # Use `contextlib.suppress(OSError)` instead of `try`-`except`-`pass`
"ARG002", # Unused method argument:
"S101", # Use of `assert` detected
"SIM114", # Combine `if` branches using logical `or` operator
"PGH003", # Use specific rule codes when ignoring type issues
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
"EM103", # Exception must not use a `.format()` string directly, assign to variable first
"C408", # Unnecessary `dict` call (rewrite as a literal)
"SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements
]
unfixable = [
"RUF100", # Unused noqa
"T20", # Removes print statements
"F841", # Removes unused variables
]
exclude = []
flake8-unused-arguments.ignore-variadic-names = true
isort.required-imports = ["from __future__ import annotations"]
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?)|dummy.*)$"
[tool.ruff.per-file-ignores]
"pymongo/__init__.py" = ["E402"]
"test/*.py" = ["PT", "E402", "PLW", "SIM", "E741", "PTH", "S", "B904", "E722", "T201",
"RET", "ARG", "F405", "B028", "PGH001", "B018", "F403", "RUF015", "E731", "B007",
"UP031", "F401", "B023", "F811"]
"green_framework_test.py" = ["T201"]

10
setup.py Executable file → Normal file
View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os
import sys
import warnings
@ -75,7 +77,8 @@ https://pymongo.readthedocs.io/en/stable/installation.html#osx
% (
"Extension modules",
"There was an issue with your platform configuration - see above.",
)
),
stacklevel=2,
)
def build_extension(self, ext):
@ -90,9 +93,10 @@ https://pymongo.readthedocs.io/en/stable/installation.html#osx
warnings.warn(
self.warning_message
% (
"The %s extension module" % (name,),
"The %s extension module" % (name,), # noqa: UP031
"The output above this warning shows how the compilation failed.",
)
),
stacklevel=2,
)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test suite for pymongo, bson, and gridfs."""
from __future__ import annotations
import base64
import gc
@ -29,7 +30,7 @@ import unittest
import warnings
try:
import ipaddress # noqa
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
@ -795,11 +796,12 @@ class ClientContext:
return True
return False
def require_cluster_type(self, topologies=[]): # noqa
def require_cluster_type(self, topologies=None):
"""Run a test only if the client is connected to a cluster that
conforms to one of the specified topologies. Acceptable topologies
are 'single', 'replicaset', and 'sharded'.
"""
topologies = topologies or []
def _is_valid_topology():
return self.is_topology_type(topologies)
@ -1169,9 +1171,9 @@ def print_running_topology(topology):
if running:
print(
"WARNING: found Topology with running threads:\n"
" Threads: {}\n"
" Topology: {}\n"
" Creation traceback:\n{}".format(running, topology, topology._settings._stack)
f" Threads: {running}\n"
f" Topology: {topology}\n"
f" Creation traceback:\n{topology._settings._stack}"
)
@ -1238,7 +1240,7 @@ def clear_warning_registry():
"""Clear the __warningregistry__ for all modules."""
for _, module in list(sys.modules.items()):
if hasattr(module, "__warningregistry__"):
setattr(module, "__warningregistry__", {}) # noqa
module.__warningregistry__ = {} # type:ignore[attr-defined]
class SystemCertsPatcher:

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test connections to various Atlas cluster types."""
from __future__ import annotations
import os
import sys

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test MONGODB-AWS Authentication."""
from __future__ import annotations
import os
import sys

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test MONGODB-OIDC Authentication."""
from __future__ import annotations
import os
import sys

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from test import setup, teardown
import pytest

View File

@ -16,6 +16,7 @@
https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.rst
"""
from __future__ import annotations
from test.utils_spec_runner import SpecRunner

View File

@ -4,6 +4,8 @@ Lambda function for Python Driver testing
Creates the client that is cached for all requests, subscribes to
relevant events, and forces the connection pool to get populated.
"""
from __future__ import annotations
import json
import os

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import namedtuple

1
test/mockupdb/test_auth_recovering_member.py Executable file → Normal file
View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test $clusterTime handling."""
from __future__ import annotations
import unittest

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test list_indexes with more than one batch."""
from __future__ import annotations
import unittest

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Test PyMongo cursor with a sharded cluster."""
from __future__ import annotations
import unittest
from queue import Queue

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
@ -63,7 +64,7 @@ class TestHandshake(unittest.TestCase):
client = MongoClient(
"mongodb://" + primary.address_string,
appname="my app", # For _check_handshake_data()
**dict([k_map.get((k, v), (k, v)) for k, v in kwargs.items()]) # type: ignore[arg-type]
**dict([k_map.get((k, v), (k, v)) for k, v in kwargs.items()]), # type: ignore[arg-type]
)
self.addCleanup(client.close)
@ -236,14 +237,12 @@ class TestHandshake(unittest.TestCase):
request.reply(
OpMsgReply(
**primary_response,
**{
"payload": b"r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0"
b"1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky"
b"tXdF9r,"
b"s=4dcxugMJq2P4hQaDbGXZR8uR3ei"
b"PHrSmh4uhkg==,i=15000",
"saslSupportedMechs": ["SCRAM-SHA-1"],
}
payload=b"r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0"
b"1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky"
b"tXdF9r,"
b"s=4dcxugMJq2P4hQaDbGXZR8uR3ei"
b"PHrSmh4uhkg==,i=15000",
saslSupportedMechs=["SCRAM-SHA-1"],
)
)
return None

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import time
import unittest

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test list_indexes with more than one batch."""
from __future__ import annotations
import unittest

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test PyMongo with a mixed-version cluster."""
from __future__ import annotations
import time
import unittest

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import itertools
import unittest

1
test/mockupdb/test_network_disconnect_primary.py Executable file → Normal file
View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest

1
test/mockupdb/test_op_msg.py Executable file → Normal file
View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from collections import namedtuple

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import itertools

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test PyMongo query and read preference with a sharded cluster."""
from __future__ import annotations
import unittest

1
test/mockupdb/test_reset_and_request_check.py Executable file → Normal file
View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import itertools
import time

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test connections to RSGhost nodes."""
from __future__ import annotations
import datetime
import unittest

View File

@ -16,6 +16,7 @@
Just make sure SlaveOkay is *not* set on primary reads.
"""
from __future__ import annotations
import unittest
@ -50,7 +51,7 @@ class TestSlaveOkayRS(unittest.TestCase):
def create_slave_ok_rs_test(operation):
def test(self):
self.setup_server()
assert not operation.op_type == "always-use-secondary"
assert operation.op_type != "always-use-secondary"
client = MongoClient(self.primary.uri, replicaSet="rs")
self.addCleanup(client.close)

View File

@ -18,6 +18,8 @@
- A direct connection to a slave.
- A direct connection to a mongos.
"""
from __future__ import annotations
import itertools
import unittest
from queue import Queue
@ -43,10 +45,7 @@ class TestSlaveOkaySharded(unittest.TestCase):
"ismaster", minWireVersion=2, maxWireVersion=6, ismaster=True, msg="isdbgrid"
)
self.mongoses_uri = "mongodb://{},{}".format(
self.mongos1.address_string,
self.mongos2.address_string,
)
self.mongoses_uri = f"mongodb://{self.mongos1.address_string},{self.mongos2.address_string}"
def create_slave_ok_sharded_test(mode, operation):

View File

@ -18,6 +18,7 @@
- A direct connection to a slave.
- A direct connection to a mongos.
"""
from __future__ import annotations
import itertools
import unittest

View File

@ -14,6 +14,7 @@
"""Minimal test of PyMongo in a WSGI application, see bug PYTHON-353
"""
from __future__ import annotations
import datetime
import os
@ -42,10 +43,10 @@ from pymongo.mongo_client import MongoClient
assert bson.has_c()
assert pymongo.has_c()
OPTS: "CodecOptions[dict]" = CodecOptions(
OPTS: CodecOptions[dict] = CodecOptions(
uuid_representation=STANDARD, datetime_conversion=DatetimeConversion.DATETIME_AUTO
)
client: "MongoClient[dict]" = MongoClient()
client: MongoClient[dict] = MongoClient()
# Use a unique collection name for each process:
coll_name = f"test-{uuid.uuid4()}"
collection = client.test.get_collection(coll_name, codec_options=OPTS)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test client for mod_wsgi application, see bug PYTHON-353."""
from __future__ import annotations
import _thread as thread
import random

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pymongo import MongoClient
client: MongoClient = MongoClient()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pymongo import MongoClient
client: MongoClient = MongoClient()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from bson.raw_bson import RawBSONDocument
from pymongo import MongoClient

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import TypedDict
from pymongo import MongoClient

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test OCSP."""
from __future__ import annotations
import logging
import os

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for the MongoDB Driver Performance Benchmarking Spec."""
from __future__ import annotations
import multiprocessing as mp
import os

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tools for mocking parts of PyMongo to test other parts."""
from __future__ import annotations
import contextlib
import weakref

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import random

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Used by test_client.TestClient.test_sigstop_sigcont."""
from __future__ import annotations
import logging
import os

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Authentication Tests."""
from __future__ import annotations
import os
import sys
@ -179,7 +180,7 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
uri = uri + f"&replicaSet={str(set_name)}"
uri = uri + f"&replicaSet={set_name!s}"
client = MongoClient(uri)
client[GSSAPI_DB].list_collection_names()
@ -196,7 +197,7 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
mech_uri = mech_uri + f"&replicaSet={str(set_name)}"
mech_uri = mech_uri + f"&replicaSet={set_name!s}"
client = MongoClient(mech_uri)
client[GSSAPI_DB].list_collection_names()

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Run the auth spec tests."""
from __future__ import annotations
import glob
import json

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for the Binary wrapper."""
from __future__ import annotations
import array
import base64

View File

@ -14,6 +14,7 @@
# limitations under the License.
"""Test the bson module."""
from __future__ import annotations
import array
import collections

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Run the BSON corpus specification tests."""
from __future__ import annotations
import binascii
import codecs

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Test the bulk API."""
from __future__ import annotations
import sys
import uuid
@ -830,7 +831,7 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests)
self.assertFalse(result.acknowledged)
wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents")
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
def test_no_results_ordered_failure(self):
@ -845,7 +846,7 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests)
self.assertFalse(result.acknowledged)
wait_until(lambda: 3 == self.coll.count_documents({}), "insert 3 documents")
wait_until(lambda: self.coll.count_documents({}) == 3, "insert 3 documents")
self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1}))
def test_no_results_unordered_success(self):
@ -857,7 +858,7 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests, ordered=False)
self.assertFalse(result.acknowledged)
wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents")
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
def test_no_results_unordered_failure(self):
@ -872,7 +873,7 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests, ordered=False)
self.assertFalse(result.acknowledged)
wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents")
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')

Some files were not shown because too many files have changed in this diff Show More