PYTHON-2504 Run pyupgrade 3.4.0 and ruff 0.0.265 (#1196)

pyupgrade --py37-plus bson/*.py pymongo/*.py gridfs/*.py test/*.py tools/*.py test/*/*.py
ruff --fix-only --select ALL --fixable ALL --target-version py37 --line-length=100 --unfixable COM812,D400,D415,ERA001,RUF100,SIM108,D211,D212,SIM105,SIM,PT,ANN204,EM bson/*.py pymongo/*.py gridfs/*.py test/*.py test/*/*.py
This commit is contained in:
Shane Harvey 2023-05-11 15:27:17 -07:00 committed by GitHub
parent afd7e1c2cd
commit 0092b0af79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
146 changed files with 1234 additions and 1241 deletions

View File

@ -237,8 +237,8 @@ def get_data_and_view(data: Any) -> Tuple[Any, memoryview]:
def _raise_unknown_type(element_type: int, element_name: str) -> NoReturn:
"""Unknown type helper."""
raise InvalidBSON(
"Detected unknown BSON type %r for fieldname '%s'. Are "
"you using the latest driver version?" % (chr(element_type).encode(), element_name)
"Detected unknown BSON type {!r} for fieldname '{}'. Are "
"you using the latest driver version?".format(chr(element_type).encode(), element_name)
)
@ -626,8 +626,7 @@ def gen_list_name() -> Generator[bytes, None, None]:
The first 1000 keys are returned from a pre-built cache. All
subsequent keys are generated on the fly.
"""
for name in _LIST_NAMES:
yield name
yield from _LIST_NAMES
counter = itertools.count(1000)
while True:
@ -942,18 +941,18 @@ def _name_value_to_bson(
name, fallback_encoder(value), check_keys, opts, in_fallback_call=True
)
raise InvalidDocument("cannot encode object: %r, of type: %r" % (value, type(value)))
raise InvalidDocument(f"cannot encode object: {value!r}, of type: {type(value)!r}")
def _element_to_bson(key: Any, value: Any, check_keys: bool, opts: CodecOptions) -> bytes:
"""Encode a single key, value pair."""
if not isinstance(key, str):
raise InvalidDocument("documents must have only string keys, key was %r" % (key,))
raise InvalidDocument(f"documents must have only string keys, key was {key!r}")
if check_keys:
if key.startswith("$"):
raise InvalidDocument("key %r must not start with '$'" % (key,))
raise InvalidDocument(f"key {key!r} must not start with '$'")
if "." in key:
raise InvalidDocument("key %r must not contain '.'" % (key,))
raise InvalidDocument(f"key {key!r} must not contain '.'")
name = _make_name(key)
return _name_value_to_bson(name, value, check_keys, opts)
@ -971,7 +970,7 @@ def _dict_to_bson(doc: Any, check_keys: bool, opts: CodecOptions, top_level: boo
if not top_level or key != "_id":
elements.append(_element_to_bson(key, value, check_keys, opts))
except AttributeError:
raise TypeError("encoder expected a mapping type but got: %r" % (doc,))
raise TypeError(f"encoder expected a mapping type but got: {doc!r}")
encoded = b"".join(elements)
return _PACK_INT(len(encoded) + 5) + encoded + b"\x00"

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""Setstate and getstate functions for objects with __slots__, allowing
compatibility with default pickling protocol
compatibility with default pickling protocol
"""
from typing import Any, Mapping
@ -33,7 +33,7 @@ def _mangle_name(name: str, prefix: str) -> str:
def _getstate_slots(self: Any) -> Mapping[Any, Any]:
prefix = self.__class__.__name__
ret = dict()
ret = {}
for name in self.__slots__:
mangled_name = _mangle_name(name, prefix)
if hasattr(self, mangled_name):

View File

@ -306,7 +306,7 @@ class Binary(bytes):
.. versionadded:: 3.11
"""
if self.subtype not in ALL_UUID_SUBTYPES:
raise ValueError("cannot decode subtype %s as a uuid" % (self.subtype,))
raise ValueError(f"cannot decode subtype {self.subtype} as a uuid")
if uuid_representation not in ALL_UUID_REPRESENTATIONS:
raise ValueError(
@ -330,8 +330,7 @@ class Binary(bytes):
return UUID(bytes=self)
raise ValueError(
"cannot decode subtype %s to %s"
% (self.subtype, UUID_REPRESENTATION_NAMES[uuid_representation])
f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}"
)
@property
@ -341,7 +340,7 @@ class Binary(bytes):
def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override]
# Work around http://bugs.python.org/issue7382
data = super(Binary, self).__getnewargs__()[0]
data = super().__getnewargs__()[0]
if not isinstance(data, bytes):
data = data.encode("latin-1")
return data, self.__subtype
@ -355,10 +354,10 @@ class Binary(bytes):
return False
def __hash__(self) -> int:
return super(Binary, self).__hash__() ^ hash(self.__subtype)
return super().__hash__() ^ hash(self.__subtype)
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self):
return "Binary(%s, %s)" % (bytes.__repr__(self), self.__subtype)
return f"Binary({bytes.__repr__(self)}, {self.__subtype})"

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for representing JavaScript code in BSON.
"""
"""Tools for representing JavaScript code in BSON."""
from collections.abc import Mapping as _Mapping
from typing import Any, Mapping, Optional, Type, Union
@ -54,7 +53,7 @@ class Code(str):
cls: Type["Code"],
code: Union[str, "Code"],
scope: Optional[Mapping[str, Any]] = None,
**kwargs: Any
**kwargs: Any,
) -> "Code":
if not isinstance(code, str):
raise TypeError("code must be an instance of str")
@ -88,7 +87,7 @@ class Code(str):
return self.__scope
def __repr__(self):
return "Code(%s, %r)" % (str.__repr__(self), self.__scope)
return f"Code({str.__repr__(self)}, {self.__scope!r})"
def __eq__(self, other: Any) -> bool:
if isinstance(other, Code):

View File

@ -63,12 +63,10 @@ class TypeEncoder(abc.ABC):
@abc.abstractproperty
def python_type(self) -> Any:
"""The Python type to be converted into something serializable."""
pass
@abc.abstractmethod
def transform_python(self, value: Any) -> Any:
"""Convert the given Python object into something serializable."""
pass
class TypeDecoder(abc.ABC):
@ -84,12 +82,10 @@ class TypeDecoder(abc.ABC):
@abc.abstractproperty
def bson_type(self) -> Any:
"""The BSON type to be converted into our own type."""
pass
@abc.abstractmethod
def transform_bson(self, value: Any) -> Any:
"""Convert the given BSON value into our own type."""
pass
class TypeCodec(TypeEncoder, TypeDecoder):
@ -105,14 +101,12 @@ class TypeCodec(TypeEncoder, TypeDecoder):
See :ref:`custom-type-type-codec` documentation for an example.
"""
pass
_Codec = Union[TypeEncoder, TypeDecoder, TypeCodec]
_Fallback = Callable[[Any], Any]
class TypeRegistry(object):
class TypeRegistry:
"""Encapsulates type codecs used in encoding and / or decoding BSON, as
well as the fallback encoder. Type registries cannot be modified after
instantiation.
@ -164,8 +158,7 @@ class TypeRegistry(object):
self._decoder_map[codec.bson_type] = codec.transform_bson
if not is_valid_codec:
raise TypeError(
"Expected an instance of %s, %s, or %s, got %r instead"
% (TypeEncoder.__name__, TypeDecoder.__name__, TypeCodec.__name__, codec)
f"Expected an instance of {TypeEncoder.__name__}, {TypeDecoder.__name__}, or {TypeCodec.__name__}, got {codec!r} instead"
)
def _validate_type_encoder(self, codec: _Codec) -> None:
@ -175,12 +168,12 @@ class TypeRegistry(object):
if issubclass(cast(TypeCodec, codec).python_type, pytype):
err_msg = (
"TypeEncoders cannot change how built-in types are "
"encoded (encoder %s transforms type %s)" % (codec, pytype)
"encoded (encoder {} transforms type {})".format(codec, pytype)
)
raise TypeError(err_msg)
def __repr__(self):
return "%s(type_codecs=%r, fallback_encoder=%r)" % (
return "{}(type_codecs={!r}, fallback_encoder={!r})".format(
self.__class__.__name__,
self.__type_codecs,
self._fallback_encoder,
@ -446,10 +439,9 @@ else:
)
return (
"document_class=%s, tz_aware=%r, uuid_representation=%s, "
"unicode_decode_error_handler=%r, tzinfo=%r, "
"type_registry=%r, datetime_conversion=%s"
% (
"document_class={}, tz_aware={!r}, uuid_representation={}, "
"unicode_decode_error_handler={!r}, tzinfo={!r}, "
"type_registry={!r}, datetime_conversion={!s}".format(
document_class_repr,
self.tz_aware,
uuid_rep_repr,
@ -474,7 +466,7 @@ else:
}
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, self._arguments_repr())
return f"{self.__class__.__name__}({self._arguments_repr()})"
def with_options(self, **kwargs: Any) -> "CodecOptions":
"""Make a copy of this CodecOptions, overriding some options::

View File

@ -21,7 +21,7 @@ from bson._helpers import _getstate_slots, _setstate_slots
from bson.son import SON
class DBRef(object):
class DBRef:
"""A reference to a document stored in MongoDB."""
__slots__ = "__collection", "__id", "__database", "__kwargs"
@ -36,7 +36,7 @@ class DBRef(object):
id: Any,
database: Optional[str] = None,
_extra: Optional[Mapping[str, Any]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
"""Initialize a new :class:`DBRef`.
@ -102,10 +102,10 @@ class DBRef(object):
return doc
def __repr__(self):
extra = "".join([", %s=%r" % (k, v) for k, v in self.__kwargs.items()])
extra = "".join([f", {k}={v!r}" for k, v in self.__kwargs.items()])
if self.database is None:
return "DBRef(%r, %r%s)" % (self.collection, self.id, extra)
return "DBRef(%r, %r, %r%s)" % (self.collection, self.id, self.database, extra)
return f"DBRef({self.collection!r}, {self.id!r}{extra})"
return f"DBRef({self.collection!r}, {self.id!r}, {self.database!r}{extra})"
def __eq__(self, other: Any) -> bool:
if isinstance(other, DBRef):

View File

@ -115,7 +115,7 @@ def _decimal_to_128(value: _VALUE_OPTIONS) -> Tuple[int, int]:
return high, low
class Decimal128(object):
class Decimal128:
"""BSON Decimal128 type::
>>> Decimal128(Decimal("0.0005"))
@ -226,7 +226,7 @@ class Decimal128(object):
)
self.__high, self.__low = value # type: ignore
else:
raise TypeError("Cannot convert %r to Decimal128" % (value,))
raise TypeError(f"Cannot convert {value!r} to Decimal128")
def to_decimal(self) -> decimal.Decimal:
"""Returns an instance of :class:`decimal.Decimal` for this
@ -297,7 +297,7 @@ class Decimal128(object):
return str(dec)
def __repr__(self):
return "Decimal128('%s')" % (str(self),)
return f"Decimal128('{str(self)}')"
def __setstate__(self, value: Tuple[int, int]) -> None:
self.__high, self.__low = value

View File

@ -288,7 +288,7 @@ class JSONOptions(CodecOptions):
strict_uuid: Optional[bool] = None,
json_mode: int = JSONMode.RELAXED,
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> "JSONOptions":
kwargs["tz_aware"] = kwargs.get("tz_aware", False)
if kwargs["tz_aware"]:
@ -303,7 +303,7 @@ class JSONOptions(CodecOptions):
"JSONOptions.datetime_representation must be one of LEGACY, "
"NUMBERLONG, or ISO8601 from DatetimeRepresentation."
)
self = cast(JSONOptions, super(JSONOptions, cls).__new__(cls, *args, **kwargs))
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs))
if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL):
raise ValueError(
"JSONOptions.json_mode must be one of LEGACY, RELAXED, "
@ -350,21 +350,20 @@ class JSONOptions(CodecOptions):
def _arguments_repr(self) -> str:
return (
"strict_number_long=%r, "
"datetime_representation=%r, "
"strict_uuid=%r, json_mode=%r, %s"
% (
"strict_number_long={!r}, "
"datetime_representation={!r}, "
"strict_uuid={!r}, json_mode={!r}, {}".format(
self.strict_number_long,
self.datetime_representation,
self.strict_uuid,
self.json_mode,
super(JSONOptions, self)._arguments_repr(),
super()._arguments_repr(),
)
)
def _options_dict(self) -> Dict[Any, Any]:
# TODO: PYTHON-2442 use _asdict() instead
options_dict = super(JSONOptions, self)._options_dict()
options_dict = super()._options_dict()
options_dict.update(
{
"strict_number_long": self.strict_number_long,
@ -492,7 +491,7 @@ def _json_convert(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) ->
if hasattr(obj, "items"):
return SON(((k, _json_convert(v, json_options)) for k, v in obj.items()))
elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)):
return list((_json_convert(v, json_options) for v in obj))
return [_json_convert(v, json_options) for v in obj]
try:
return default(obj, json_options)
except TypeError:
@ -568,9 +567,9 @@ def _parse_legacy_regex(doc: Any) -> Any:
def _parse_legacy_uuid(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]:
"""Decode a JSON legacy $uuid to Python UUID."""
if len(doc) != 1:
raise TypeError("Bad $uuid, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $uuid, extra field(s): {doc}")
if not isinstance(doc["$uuid"], str):
raise TypeError("$uuid must be a string: %s" % (doc,))
raise TypeError(f"$uuid must be a string: {doc}")
if json_options.uuid_representation == UuidRepresentation.UNSPECIFIED:
return Binary.from_uuid(uuid.UUID(doc["$uuid"]))
else:
@ -613,11 +612,11 @@ def _parse_canonical_binary(doc: Any, json_options: JSONOptions) -> Union[Binary
b64 = binary["base64"]
subtype = binary["subType"]
if not isinstance(b64, str):
raise TypeError("$binary base64 must be a string: %s" % (doc,))
raise TypeError(f"$binary base64 must be a string: {doc}")
if not isinstance(subtype, str) or len(subtype) > 2:
raise TypeError("$binary subType must be a string at most 2 characters: %s" % (doc,))
raise TypeError(f"$binary subType must be a string at most 2 characters: {doc}")
if len(binary) != 2:
raise TypeError('$binary must include only "base64" and "subType" components: %s' % (doc,))
raise TypeError(f'$binary must include only "base64" and "subType" components: {doc}')
data = base64.b64decode(b64.encode())
return _binary_or_uuid(data, int(subtype, 16), json_options)
@ -629,7 +628,7 @@ def _parse_canonical_datetime(
"""Decode a JSON datetime to python datetime.datetime."""
dtm = doc["$date"]
if len(doc) != 1:
raise TypeError("Bad $date, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $date, extra field(s): {doc}")
# mongoexport 2.6 and newer
if isinstance(dtm, str):
# Parse offset
@ -692,7 +691,7 @@ def _parse_canonical_datetime(
def _parse_canonical_oid(doc: Any) -> ObjectId:
"""Decode a JSON ObjectId to bson.objectid.ObjectId."""
if len(doc) != 1:
raise TypeError("Bad $oid, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $oid, extra field(s): {doc}")
return ObjectId(doc["$oid"])
@ -700,7 +699,7 @@ def _parse_canonical_symbol(doc: Any) -> str:
"""Decode a JSON symbol to Python string."""
symbol = doc["$symbol"]
if len(doc) != 1:
raise TypeError("Bad $symbol, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $symbol, extra field(s): {doc}")
return str(symbol)
@ -708,7 +707,7 @@ def _parse_canonical_code(doc: Any) -> Code:
"""Decode a JSON code to bson.code.Code."""
for key in doc:
if key not in ("$code", "$scope"):
raise TypeError("Bad $code, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $code, extra field(s): {doc}")
return Code(doc["$code"], scope=doc.get("$scope"))
@ -716,11 +715,11 @@ def _parse_canonical_regex(doc: Any) -> Regex:
"""Decode a JSON regex to bson.regex.Regex."""
regex = doc["$regularExpression"]
if len(doc) != 1:
raise TypeError("Bad $regularExpression, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $regularExpression, extra field(s): {doc}")
if len(regex) != 2:
raise TypeError(
'Bad $regularExpression must include only "pattern"'
'and "options" components: %s' % (doc,)
'and "options" components: {}'.format(doc)
)
opts = regex["options"]
if not isinstance(opts, str):
@ -739,28 +738,28 @@ def _parse_canonical_dbpointer(doc: Any) -> Any:
"""Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef."""
dbref = doc["$dbPointer"]
if len(doc) != 1:
raise TypeError("Bad $dbPointer, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $dbPointer, extra field(s): {doc}")
if isinstance(dbref, DBRef):
dbref_doc = dbref.as_doc()
# DBPointer must not contain $db in its value.
if dbref.database is not None:
raise TypeError("Bad $dbPointer, extra field $db: %s" % (dbref_doc,))
raise TypeError(f"Bad $dbPointer, extra field $db: {dbref_doc}")
if not isinstance(dbref.id, ObjectId):
raise TypeError("Bad $dbPointer, $id must be an ObjectId: %s" % (dbref_doc,))
raise TypeError(f"Bad $dbPointer, $id must be an ObjectId: {dbref_doc}")
if len(dbref_doc) != 2:
raise TypeError("Bad $dbPointer, extra field(s) in DBRef: %s" % (dbref_doc,))
raise TypeError(f"Bad $dbPointer, extra field(s) in DBRef: {dbref_doc}")
return dbref
else:
raise TypeError("Bad $dbPointer, expected a DBRef: %s" % (doc,))
raise TypeError(f"Bad $dbPointer, expected a DBRef: {doc}")
def _parse_canonical_int32(doc: Any) -> int:
"""Decode a JSON int32 to python int."""
i_str = doc["$numberInt"]
if len(doc) != 1:
raise TypeError("Bad $numberInt, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $numberInt, extra field(s): {doc}")
if not isinstance(i_str, str):
raise TypeError("$numberInt must be string: %s" % (doc,))
raise TypeError(f"$numberInt must be string: {doc}")
return int(i_str)
@ -768,7 +767,7 @@ def _parse_canonical_int64(doc: Any) -> Int64:
"""Decode a JSON int64 to bson.int64.Int64."""
l_str = doc["$numberLong"]
if len(doc) != 1:
raise TypeError("Bad $numberLong, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $numberLong, extra field(s): {doc}")
return Int64(l_str)
@ -776,9 +775,9 @@ def _parse_canonical_double(doc: Any) -> float:
"""Decode a JSON double to python float."""
d_str = doc["$numberDouble"]
if len(doc) != 1:
raise TypeError("Bad $numberDouble, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $numberDouble, extra field(s): {doc}")
if not isinstance(d_str, str):
raise TypeError("$numberDouble must be string: %s" % (doc,))
raise TypeError(f"$numberDouble must be string: {doc}")
return float(d_str)
@ -786,18 +785,18 @@ def _parse_canonical_decimal128(doc: Any) -> Decimal128:
"""Decode a JSON decimal128 to bson.decimal128.Decimal128."""
d_str = doc["$numberDecimal"]
if len(doc) != 1:
raise TypeError("Bad $numberDecimal, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $numberDecimal, extra field(s): {doc}")
if not isinstance(d_str, str):
raise TypeError("$numberDecimal must be string: %s" % (doc,))
raise TypeError(f"$numberDecimal must be string: {doc}")
return Decimal128(d_str)
def _parse_canonical_minkey(doc: Any) -> MinKey:
"""Decode a JSON MinKey to bson.min_key.MinKey."""
if type(doc["$minKey"]) is not int or doc["$minKey"] != 1:
raise TypeError("$minKey value must be 1: %s" % (doc,))
raise TypeError(f"$minKey value must be 1: {doc}")
if len(doc) != 1:
raise TypeError("Bad $minKey, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $minKey, extra field(s): {doc}")
return MinKey()
@ -806,7 +805,7 @@ def _parse_canonical_maxkey(doc: Any) -> MaxKey:
if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1:
raise TypeError("$maxKey value must be 1: %s", (doc,))
if len(doc) != 1:
raise TypeError("Bad $minKey, extra field(s): %s" % (doc,))
raise TypeError(f"Bad $minKey, extra field(s): {doc}")
return MaxKey()
@ -839,7 +838,7 @@ def default(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any:
millis = int(obj.microsecond / 1000)
fracsecs = ".%03d" % (millis,) if millis else ""
return {
"$date": "%s%s%s" % (obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)
"$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)
}
millis = _datetime_to_millis(obj)

View File

@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Representation for the MongoDB internal MaxKey type.
"""
"""Representation for the MongoDB internal MaxKey type."""
from typing import Any
class MaxKey(object):
class MaxKey:
"""MongoDB internal MaxKey type."""
__slots__ = ()

View File

@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Representation for the MongoDB internal MinKey type.
"""
"""Representation for the MongoDB internal MinKey type."""
from typing import Any
class MinKey(object):
class MinKey:
"""MongoDB internal MinKey type."""
__slots__ = ()

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with MongoDB ObjectIds.
"""
"""Tools for working with MongoDB ObjectIds."""
import binascii
import calendar
@ -43,7 +42,7 @@ def _random_bytes() -> bytes:
return os.urandom(5)
class ObjectId(object):
class ObjectId:
"""A MongoDB ObjectId."""
_pid = os.getpid()
@ -166,7 +165,6 @@ class ObjectId(object):
def __generate(self) -> None:
"""Generate a new value for this ObjectId."""
# 4 bytes current time
oid = struct.pack(">I", int(time.time()))
@ -202,9 +200,7 @@ class ObjectId(object):
else:
_raise_invalid_id(oid)
else:
raise TypeError(
"id must be an instance of (bytes, str, ObjectId), not %s" % (type(oid),)
)
raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}")
@property
def binary(self) -> bytes:
@ -224,13 +220,13 @@ class ObjectId(object):
return datetime.datetime.fromtimestamp(timestamp, utc)
def __getstate__(self) -> bytes:
"""return value of object for pickling.
"""Return value of object for pickling.
needed explicitly because __slots__() defined.
"""
return self.__id
def __setstate__(self, value: Any) -> None:
"""explicit state set from pickling"""
"""Explicit state set from pickling"""
# Provide backwards compatibility with OIDs
# pickled with pymongo-1.9 or older.
if isinstance(value, dict):
@ -249,7 +245,7 @@ class ObjectId(object):
return binascii.hexlify(self.__id).decode()
def __repr__(self):
return "ObjectId('%s')" % (str(self),)
return f"ObjectId('{str(self)}')"
def __eq__(self, other: Any) -> bool:
if isinstance(other, ObjectId):

View File

@ -131,7 +131,7 @@ class RawBSONDocument(Mapping[str, Any]):
elif not issubclass(codec_options.document_class, RawBSONDocument):
raise TypeError(
"RawBSONDocument cannot use CodecOptions with document "
"class %s" % (codec_options.document_class,)
"class {}".format(codec_options.document_class)
)
self.__codec_options = codec_options
# Validate the bson object size.
@ -174,7 +174,7 @@ class RawBSONDocument(Mapping[str, Any]):
return NotImplemented
def __repr__(self):
return "%s(%r, codec_options=%r)" % (
return "{}({!r}, codec_options={!r})".format(
self.__class__.__name__,
self.raw,
self.__codec_options,

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for representing MongoDB regular expressions.
"""
"""Tools for representing MongoDB regular expressions."""
import re
from typing import Any, Generic, Pattern, Type, TypeVar, Union
@ -117,7 +116,7 @@ class Regex(Generic[_T]):
return not self == other
def __repr__(self):
return "Regex(%r, %r)" % (self.pattern, self.flags)
return f"Regex({self.pattern!r}, {self.flags!r})"
def try_compile(self) -> "Pattern[_T]":
"""Compile this :class:`Regex` as a Python regular expression.

View File

@ -16,7 +16,8 @@
Regular dictionaries can be used instead of SON objects, but not when the order
of keys is important. A SON object can be used just like a normal Python
dictionary."""
dictionary.
"""
import copy
import re
@ -58,7 +59,7 @@ class SON(Dict[_Key, _Value]):
def __init__(
self,
data: Optional[Union[Mapping[_Key, _Value], Iterable[Tuple[_Key, _Value]]]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self.__keys = []
dict.__init__(self)
@ -66,14 +67,14 @@ class SON(Dict[_Key, _Value]):
self.update(kwargs)
def __new__(cls: Type["SON[_Key, _Value]"], *args: Any, **kwargs: Any) -> "SON[_Key, _Value]":
instance = super(SON, cls).__new__(cls, *args, **kwargs) # type: ignore[type-var]
instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var]
instance.__keys = []
return instance
def __repr__(self):
result = []
for key in self.__keys:
result.append("(%r, %r)" % (key, self[key]))
result.append(f"({key!r}, {self[key]!r})")
return "SON([%s])" % ", ".join(result)
def __setitem__(self, key: _Key, value: _Value) -> None:
@ -94,8 +95,7 @@ class SON(Dict[_Key, _Value]):
# efficient.
# second level definitions support higher levels
def __iter__(self) -> Iterator[_Key]:
for k in self.__keys:
yield k
yield from self.__keys
def has_key(self, key: _Key) -> bool:
return key in self.__keys
@ -113,7 +113,7 @@ class SON(Dict[_Key, _Value]):
def clear(self) -> None:
self.__keys = []
super(SON, self).clear()
super().clear()
def setdefault(self, key: _Key, default: _Value) -> _Value:
try:
@ -189,7 +189,7 @@ class SON(Dict[_Key, _Value]):
if isinstance(value, list):
return [transform_value(v) for v in value]
elif isinstance(value, _Mapping):
return dict([(k, transform_value(v)) for k, v in value.items()])
return {k: transform_value(v) for k, v in value.items()}
else:
return value

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for representing MongoDB internal Timestamps.
"""
"""Tools for representing MongoDB internal Timestamps."""
import calendar
import datetime
@ -25,7 +24,7 @@ from bson.tz_util import utc
UPPERBOUND = 4294967296
class Timestamp(object):
class Timestamp:
"""MongoDB internal timestamps used in the opLog."""
__slots__ = ("__time", "__inc")
@ -113,7 +112,7 @@ class Timestamp(object):
return NotImplemented
def __repr__(self):
return "Timestamp(%s, %s)" % (self.__time, self.__inc)
return f"Timestamp({self.__time}, {self.__inc})"
def as_datetime(self) -> datetime.datetime:
"""Return a :class:`~datetime.datetime` instance corresponding

View File

@ -53,7 +53,7 @@ __all__ = [
]
class GridFS(object):
class GridFS:
"""An instance of GridFS on top of a single Database."""
def __init__(self, database: Database, collection: str = "fs"):
@ -141,7 +141,6 @@ class GridFS(object):
.. versionchanged:: 3.0
w=0 writes to GridFS are now prohibited.
"""
with GridIn(self.__collection, **kwargs) as grid_file:
grid_file.write(data)
return grid_file._id
@ -449,7 +448,7 @@ class GridFS(object):
return f is not None
class GridFSBucket(object):
class GridFSBucket:
"""An instance of GridFS on top of a single Database."""
def __init__(

View File

@ -76,7 +76,7 @@ def _grid_in_property(
if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
docstring = "%s\n\n%s" % (
docstring = "{}\n\n{}".format(
docstring,
"This attribute is read-only and "
"can only be read after :meth:`close` "
@ -114,7 +114,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None:
raise InvalidOperation("GridFS does not support multi-document transactions")
class GridIn(object):
class GridIn:
"""Class to write data to GridFS."""
def __init__(
@ -497,7 +497,7 @@ class GridOut(io.IOBase):
self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session)
if not self._file:
raise NoFile(
"no file in gridfs collection %r with _id %r" % (self.__files, self.__file_id)
f"no file in gridfs collection {self.__files!r} with _id {self.__file_id!r}"
)
def __getattr__(self, name: str) -> Any:
@ -640,10 +640,10 @@ class GridOut(io.IOBase):
elif whence == _SEEK_END:
new_pos = int(self.length) + pos
else:
raise IOError(22, "Invalid value for `whence`")
raise OSError(22, "Invalid value for `whence`")
if new_pos < 0:
raise IOError(22, "Invalid value for `pos` - must be positive")
raise OSError(22, "Invalid value for `pos` - must be positive")
# Optimization, continue using the same buffer and chunk iterator.
if new_pos == self.__position:
@ -732,7 +732,7 @@ class GridOut(io.IOBase):
pass
class _GridOutChunkIterator(object):
class _GridOutChunkIterator:
"""Iterates over a file's chunks using a single cursor.
Raises CorruptGridFile when encountering any truncated, missing, or extra
@ -832,7 +832,7 @@ class _GridOutChunkIterator(object):
self._cursor = None
class GridOutIterator(object):
class GridOutIterator:
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)
@ -878,7 +878,7 @@ class GridOutCursor(Cursor):
# Hold on to the base "fs" collection to create GridOut objects later.
self.__root_collection = collection
super(GridOutCursor, self).__init__(
super().__init__(
collection.files,
filter,
skip=skip,
@ -892,7 +892,7 @@ class GridOutCursor(Cursor):
def next(self) -> GridOut:
"""Get next GridOut object from cursor."""
_disallow_transactions(self.session)
next_file = super(GridOutCursor, self).next()
next_file = super().next()
return GridOut(self.__root_collection, file_document=next_file, session=self.session)
__next__ = next

View File

@ -57,7 +57,7 @@ def clamp_remaining(max_timeout: float) -> float:
return min(timeout, max_timeout)
class _TimeoutContext(object):
class _TimeoutContext:
"""Internal timeout context manager.
Use :func:`pymongo.timeout` instead::

View File

@ -21,7 +21,7 @@ from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
class _AggregationCommand(object):
class _AggregationCommand:
"""The internal abstract base class for aggregation cursors.
Should not be called directly by application developers. Use
@ -202,7 +202,7 @@ class _CollectionAggregationCommand(_AggregationCommand):
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
def __init__(self, *args, **kwargs):
super(_CollectionRawAggregationCommand, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
# For raw-batches, we set the initial batchSize for the cursor to 0.
if not self._performs_write:
@ -216,7 +216,7 @@ class _DatabaseAggregationCommand(_AggregationCommand):
@property
def _cursor_namespace(self):
return "%s.$cmd.aggregate" % (self._target.name,)
return f"{self._target.name}.$cmd.aggregate"
@property
def _database(self):

View File

@ -61,7 +61,7 @@ MECHANISMS = frozenset(
"""The authentication mechanisms supported by PyMongo."""
class _Cache(object):
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
@ -104,7 +104,7 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError("%s requires a username." % (mech,))
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
@ -297,7 +297,7 @@ def _password_digest(username, password):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5()
data = "%s:mongo:%s" % (username, password)
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
@ -306,7 +306,7 @@ def _auth_key(nonce, username, password):
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5()
data = "%s%s%s" % (nonce, username, digest)
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
@ -448,7 +448,7 @@ def _authenticate_plain(credentials, sock_info):
source = credentials.source
username = credentials.username
password = credentials.password
payload = ("\x00%s\x00%s" % (username, password)).encode("utf-8")
payload = (f"\x00{username}\x00{password}").encode()
cmd = SON(
[
("saslStart", 1),
@ -518,7 +518,7 @@ _AUTH_MAP: Mapping[str, Callable] = {
}
class _AuthContext(object):
class _AuthContext:
def __init__(self, credentials, address):
self.credentials = credentials
self.speculative_authenticate = None
@ -543,7 +543,7 @@ class _AuthContext(object):
class _ScramContext(_AuthContext):
def __init__(self, credentials, address, mechanism):
super(_ScramContext, self).__init__(credentials, address)
super().__init__(credentials, address)
self.scram_data = None
self.mechanism = mechanism
@ -569,7 +569,7 @@ class _OIDCContext(_AuthContext):
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.auth_start_cmd(False)
if cmd is None:
return
return None
cmd["db"] = self.credentials.source
return cmd

View File

@ -21,7 +21,7 @@ try:
_HAVE_MONGODB_AWS = True
except ImportError:
class AwsSaslContext(object): # type: ignore
class AwsSaslContext: # type: ignore
def __init__(self, credentials):
pass
@ -102,9 +102,7 @@ def _authenticate_aws(credentials, sock_info):
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure(
"%s (pymongo-auth-aws version %s)" % (exc, pymongo_auth_aws.__version__)
)
raise OperationFailure(f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})")
except Exception:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)

View File

@ -131,11 +131,11 @@ class _OIDCAuthenticator:
refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
refresh_token = refresh_token or ""
context = dict(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=refresh_token,
)
context = {
"timeout_seconds": timeout,
"version": CALLBACK_VERSION,
"refresh_token": refresh_token,
}
if self.idp_resp is None or refresh_cb is None:
self.idp_resp = request_cb(self.idp_info, context)
@ -181,7 +181,7 @@ class _OIDCAuthenticator:
aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
with open(aws_identity_file) as fid:
token = fid.read().strip()
payload = dict(jwt=token)
payload = {"jwt": token}
cmd = SON(
[
("saslStart", 1),
@ -203,7 +203,7 @@ class _OIDCAuthenticator:
if self.idp_info is None:
# Send the SASL start with the optional principal name.
payload = dict()
payload = {}
if principal_name:
payload["n"] = principal_name
@ -221,7 +221,7 @@ class _OIDCAuthenticator:
token = self.get_current_token(use_callbacks)
if not token:
return None
bin_payload = Binary(bson.encode(dict(jwt=token)))
bin_payload = Binary(bson.encode({"jwt": token}))
return SON(
[
("saslStart", 1),
@ -268,7 +268,7 @@ class _OIDCAuthenticator:
if resp["done"]:
sock_info.oidc_token_gen_id = self.token_gen_id
return
return None
server_resp: Dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
@ -278,7 +278,7 @@ class _OIDCAuthenticator:
conversation_id = resp["conversationId"]
token = self.get_current_token()
sock_info.oidc_token_gen_id = self.token_gen_id
bin_payload = Binary(bson.encode(dict(jwt=token)))
bin_payload = Binary(bson.encode({"jwt": token}))
cmd = SON(
[
("saslContinue", 1),

View File

@ -60,7 +60,7 @@ _WRITE_CONCERN_ERROR = 64
_COMMANDS = ("insert", "update", "delete")
class _Run(object):
class _Run:
"""Represents a batch of write operations."""
def __init__(self, op_type):
@ -136,7 +136,7 @@ def _raise_bulk_write_error(full_result: Any) -> NoReturn:
raise BulkWriteError(full_result)
class _Bulk(object):
class _Bulk:
"""The private guts of the bulk write API."""
def __init__(self, collection, ordered, bypass_document_validation, comment=None, let=None):
@ -509,5 +509,6 @@ class _Bulk(object):
if not write_concern.acknowledged:
with client._socket_for_writes(session) as sock_info:
self.execute_no_results(sock_info, generator, write_concern)
return None
else:
return self.execute_command(generator, write_concern, session)

View File

@ -156,7 +156,8 @@ class ChangeStream(Generic[_DocumentType]):
@property
def _client(self):
"""The client against which the aggregation commands for
this ChangeStream will be run."""
this ChangeStream will be run.
"""
raise NotImplementedError
def _change_stream_options(self):
@ -221,7 +222,7 @@ class ChangeStream(Generic[_DocumentType]):
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
"response : %r" % (result,)
"response : {!r}".format(result)
)
def _run_aggregation_cmd(self, session, explicit_session):
@ -473,6 +474,6 @@ class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]):
"""
def _change_stream_options(self):
options = super(ClusterChangeStream, self)._change_stream_options()
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options

View File

@ -167,7 +167,7 @@ def _parse_pool_options(username, password, database, options):
)
class ClientOptions(object):
class ClientOptions:
"""Read only configuration options for a MongoClient.
Should not be instantiated directly by application developers. Access

View File

@ -169,7 +169,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.write_concern import WriteConcern
class SessionOptions(object):
class SessionOptions:
"""Options for a new :class:`ClientSession`.
:Parameters:
@ -203,8 +203,9 @@ class SessionOptions(object):
if not isinstance(default_transaction_options, TransactionOptions):
raise TypeError(
"default_transaction_options must be an instance of "
"pymongo.client_session.TransactionOptions, not: %r"
% (default_transaction_options,)
"pymongo.client_session.TransactionOptions, not: {!r}".format(
default_transaction_options
)
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
@ -232,7 +233,7 @@ class SessionOptions(object):
return self._snapshot
class TransactionOptions(object):
class TransactionOptions:
"""Options for :meth:`ClientSession.start_transaction`.
:Parameters:
@ -275,25 +276,25 @@ class TransactionOptions(object):
if not isinstance(read_concern, ReadConcern):
raise TypeError(
"read_concern must be an instance of "
"pymongo.read_concern.ReadConcern, not: %r" % (read_concern,)
"pymongo.read_concern.ReadConcern, not: {!r}".format(read_concern)
)
if write_concern is not None:
if not isinstance(write_concern, WriteConcern):
raise TypeError(
"write_concern must be an instance of "
"pymongo.write_concern.WriteConcern, not: %r" % (write_concern,)
"pymongo.write_concern.WriteConcern, not: {!r}".format(write_concern)
)
if not write_concern.acknowledged:
raise ConfigurationError(
"transactions do not support unacknowledged write concern"
": %r" % (write_concern,)
": {!r}".format(write_concern)
)
if read_preference is not None:
if not isinstance(read_preference, _ServerMode):
raise TypeError(
"%r is not valid for read_preference. See "
"{!r} is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options." % (read_preference,)
"options.".format(read_preference)
)
if max_commit_time_ms is not None:
if not isinstance(max_commit_time_ms, int):
@ -340,12 +341,12 @@ def _validate_session_write_concern(session, write_concern):
else:
raise ConfigurationError(
"Explicit sessions are incompatible with "
"unacknowledged write concern: %r" % (write_concern,)
"unacknowledged write concern: {!r}".format(write_concern)
)
return session
class _TransactionContext(object):
class _TransactionContext:
"""Internal transaction context manager for start_transaction."""
def __init__(self, session):
@ -362,7 +363,7 @@ class _TransactionContext(object):
self.__session.abort_transaction()
class _TxnState(object):
class _TxnState:
NONE = 1
STARTING = 2
IN_PROGRESS = 3
@ -371,7 +372,7 @@ class _TxnState(object):
ABORTED = 6
class _Transaction(object):
class _Transaction:
"""Internal class to hold transaction information in a ClientSession."""
def __init__(self, opts, client):
@ -973,7 +974,7 @@ class ClientSession:
if read_preference != ReadPreference.PRIMARY:
raise InvalidOperation(
"read preference in a transaction must be primary, not: "
"%r" % (read_preference,)
"{!r}".format(read_preference)
)
if self._transaction.state == _TxnState.STARTING:
@ -1023,7 +1024,7 @@ class _EmptyServerSession:
self.started_retryable_write = True
class _ServerSession(object):
class _ServerSession:
def __init__(self, generation):
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.
self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)}
@ -1062,7 +1063,7 @@ class _ServerSessionPool(collections.deque):
"""
def __init__(self, *args, **kwargs):
super(_ServerSessionPool, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.generation = 0
def reset(self):

View File

@ -21,7 +21,7 @@ from typing import Any, Dict, Mapping, Optional, Union
from pymongo import common
class CollationStrength(object):
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
@ -43,7 +43,7 @@ class CollationStrength(object):
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate(object):
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
@ -62,7 +62,7 @@ class CollationAlternate(object):
"""
class CollationMaxVariable(object):
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
@ -75,7 +75,7 @@ class CollationMaxVariable(object):
"""Spaces alone are ignored."""
class CollationCaseFirst(object):
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
@ -91,7 +91,7 @@ class CollationCaseFirst(object):
"""Default for locale or collation strength."""
class Collation(object):
class Collation:
"""Collation
:Parameters:
@ -163,7 +163,7 @@ class Collation(object):
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: Dict[str, Any] = {"locale": locale}
@ -201,7 +201,7 @@ class Collation(object):
def __repr__(self):
document = self.document
return "Collation(%s)" % (", ".join("%s=%r" % (key, document[key]) for key in document),)
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):

View File

@ -88,7 +88,7 @@ _WriteOp = Union[
]
class ReturnDocument(object):
class ReturnDocument:
"""An enum used with
:meth:`~pymongo.collection.Collection.find_one_and_replace` and
:meth:`~pymongo.collection.Collection.find_one_and_update`.
@ -201,7 +201,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
.. seealso:: The MongoDB documentation on `collections <https://dochub.mongodb.org/core/collections>`_.
"""
super(Collection, self).__init__(
super().__init__(
codec_options or database.codec_options,
read_preference or database.read_preference,
write_concern or database.write_concern,
@ -212,7 +212,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
if "$" in name and not (name.startswith("oplog.$main") or name.startswith("$cmd")):
if "$" in name and not (name.startswith(("oplog.$main", "$cmd"))):
raise InvalidName("collection names must not contain '$': %r" % name)
if name[0] == "." or name[-1] == ".":
raise InvalidName("collection names must not start or end with '.': %r" % name)
@ -222,7 +222,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self.__database: Database[_DocumentType] = database
self.__name = name
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
self.__full_name = f"{self.__database.name}.{self.__name}"
self.__write_response_codec_options = self.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
)
@ -344,17 +344,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
- `name`: the name of the collection to get
"""
if name.startswith("_"):
full_name = "%s.%s" % (self.__name, name)
full_name = f"{self.__name}.{name}"
raise AttributeError(
"Collection has no attribute %r. To access the %s"
" collection, use database['%s']." % (name, full_name, full_name)
"Collection has no attribute {!r}. To access the {}"
" collection, use database['{}'].".format(name, full_name, full_name)
)
return self.__getitem__(name)
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
return Collection(
self.__database,
"%s.%s" % (self.__name, name),
f"{self.__name}.{name}",
False,
self.codec_options,
self.read_preference,
@ -363,7 +363,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
)
def __repr__(self):
return "Collection(%r, %r)" % (self.__database, self.__name)
return f"Collection({self.__database!r}, {self.__name!r})"
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collection):
@ -541,7 +541,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
try:
request._add_to_bulk(blk)
except AttributeError:
raise TypeError("%r is not a valid request" % (request,))
raise TypeError(f"{request!r} is not a valid request")
write_concern = self._write_concern_for(session)
bulk_api_result = blk.execute(write_concern, session)
@ -579,6 +579,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if not isinstance(doc, RawBSONDocument):
return doc.get("_id")
return None
def insert_one(
self,
@ -719,7 +720,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern = self._write_concern_for(session)
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment)
blk.ops = [doc for doc in gen()]
blk.ops = list(gen())
blk.execute(write_concern, session=session)
return InsertManyResult(inserted_ids, write_concern.acknowledged)
@ -1924,7 +1925,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
for index in indexes:
if not isinstance(index, IndexModel):
raise TypeError(
"%r is not an instance of pymongo.operations.IndexModel" % (index,)
f"{index!r} is not an instance of pymongo.operations.IndexModel"
)
document = index.document
names.append(document["name"])
@ -2442,7 +2443,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
with self.__database.client._tmp_session(session, close=False) as s:
return self._aggregate(
_CollectionAggregationCommand,
@ -2687,7 +2687,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "$" in new_name and not new_name.startswith("oplog.$main"):
raise InvalidName("collection names must not contain '$'")
new_name = "%s.%s" % (self.__database.name, new_name)
new_name = f"{self.__database.name}.{new_name}"
cmd = SON([("renameCollection", self.__full_name), ("to", new_name)])
cmd.update(kwargs)
if comment is not None:
@ -2794,7 +2794,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
**kwargs,
):
"""Internal findAndModify helper."""
common.validate_is_mapping("filter", filter)
if not isinstance(return_document, bool):
raise ValueError(

View File

@ -132,13 +132,15 @@ class CommandCursor(Generic[_DocumentType]):
def _has_next(self):
"""Returns `True` if the cursor has documents remaining from the
previous batch."""
previous batch.
"""
return len(self.__data) > 0
@property
def _post_batch_resume_token(self):
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore."""
changeStream aggregate or getMore.
"""
return self.__postbatchresumetoken
def _maybe_pin_connection(self, sock_info):
@ -328,7 +330,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
assert not cursor_info.get("firstBatch")
super(RawBatchCommandCursor, self).__init__(
super().__init__(
collection,
cursor_info,
address,

View File

@ -157,7 +157,7 @@ def clean_node(node: str) -> Tuple[str, int]:
def raise_config_error(key: str, dummy: Any) -> NoReturn:
"""Raise ConfigurationError with the given key name."""
raise ConfigurationError("Unknown option %s" % (key,))
raise ConfigurationError(f"Unknown option {key}")
# Mapping of URI uuid representation options to valid subtypes.
@ -174,14 +174,14 @@ def validate_boolean(option: str, value: Any) -> bool:
"""Validates that 'value' is True or False."""
if isinstance(value, bool):
return value
raise TypeError("%s must be True or False" % (option,))
raise TypeError(f"{option} must be True or False")
def validate_boolean_or_string(option: str, value: Any) -> bool:
"""Validates that value is True, False, 'true', or 'false'."""
if isinstance(value, str):
if value not in ("true", "false"):
raise ValueError("The value of %s must be 'true' or 'false'" % (option,))
raise ValueError(f"The value of {option} must be 'true' or 'false'")
return value == "true"
return validate_boolean(option, value)
@ -194,15 +194,15 @@ def validate_integer(option: str, value: Any) -> int:
try:
return int(value)
except ValueError:
raise ValueError("The value of %s must be an integer" % (option,))
raise TypeError("Wrong type for %s, value must be an integer" % (option,))
raise ValueError(f"The value of {option} must be an integer")
raise TypeError(f"Wrong type for {option}, value must be an integer")
def validate_positive_integer(option: str, value: Any) -> int:
"""Validate that 'value' is a positive integer, which does not include 0."""
val = validate_integer(option, value)
if val <= 0:
raise ValueError("The value of %s must be a positive integer" % (option,))
raise ValueError(f"The value of {option} must be a positive integer")
return val
@ -210,7 +210,7 @@ def validate_non_negative_integer(option: str, value: Any) -> int:
"""Validate that 'value' is a positive integer or 0."""
val = validate_integer(option, value)
if val < 0:
raise ValueError("The value of %s must be a non negative integer" % (option,))
raise ValueError(f"The value of {option} must be a non negative integer")
return val
@ -221,7 +221,7 @@ def validate_readable(option: str, value: Any) -> Optional[str]:
# First make sure its a string py3.3 open(True, 'r') succeeds
# Used in ssl cert checking due to poor ssl module error reporting
value = validate_string(option, value)
open(value, "r").close()
open(value).close()
return value
@ -243,7 +243,7 @@ def validate_string(option: str, value: Any) -> str:
"""Validates that 'value' is an instance of `str`."""
if isinstance(value, str):
return value
raise TypeError("Wrong type for %s, value must be an instance of str" % (option,))
raise TypeError(f"Wrong type for {option}, value must be an instance of str")
def validate_string_or_none(option: str, value: Any) -> Optional[str]:
@ -262,7 +262,7 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
return int(value)
except ValueError:
return value
raise TypeError("Wrong type for %s, value must be an integer or a string" % (option,))
raise TypeError(f"Wrong type for {option}, value must be an integer or a string")
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
@ -275,16 +275,14 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in
except ValueError:
return value
return validate_non_negative_integer(option, val)
raise TypeError(
"Wrong type for %s, value must be an non negative integer or a string" % (option,)
)
raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string")
def validate_positive_float(option: str, value: Any) -> float:
"""Validates that 'value' is a float, or can be converted to one, and is
positive.
"""
errmsg = "%s must be an integer or float" % (option,)
errmsg = f"{option} must be an integer or float"
try:
value = float(value)
except ValueError:
@ -295,7 +293,7 @@ def validate_positive_float(option: str, value: Any) -> float:
# float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at
# one billion - this is a reasonable approximation for infinity
if not 0 < value < 1e9:
raise ValueError("%s must be greater than 0 and less than one billion" % (option,))
raise ValueError(f"{option} must be greater than 0 and less than one billion")
return value
@ -324,7 +322,7 @@ def validate_timeout_or_zero(option: str, value: Any) -> float:
config error.
"""
if value is None:
raise ConfigurationError("%s cannot be None" % (option,))
raise ConfigurationError(f"{option} cannot be None")
if value == 0 or value == "0":
return 0
return validate_positive_float(option, value) / 1000.0
@ -360,7 +358,7 @@ def validate_max_staleness(option: str, value: Any) -> int:
def validate_read_preference(dummy: Any, value: Any) -> _ServerMode:
"""Validate a read preference."""
if not isinstance(value, _ServerMode):
raise TypeError("%r is not a read preference." % (value,))
raise TypeError(f"{value!r} is not a read preference.")
return value
@ -372,14 +370,14 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
mode.
"""
if value not in _MONGOS_MODES:
raise ValueError("%s is not a valid read preference" % (value,))
raise ValueError(f"{value} is not a valid read preference")
return value
def validate_auth_mechanism(option: str, value: Any) -> str:
"""Validate the authMechanism URI option."""
if value not in MECHANISMS:
raise ValueError("%s must be in %s" % (option, tuple(MECHANISMS)))
raise ValueError(f"{option} must be in {tuple(MECHANISMS)}")
return value
@ -389,9 +387,9 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int:
return _UUID_REPRESENTATIONS[value]
except KeyError:
raise ValueError(
"%s is an invalid UUID representation. "
"{} is an invalid UUID representation. "
"Must be one of "
"%s" % (value, tuple(_UUID_REPRESENTATIONS))
"{}".format(value, tuple(_UUID_REPRESENTATIONS))
)
@ -412,7 +410,7 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]
tags[unquote_plus(key)] = unquote_plus(val)
tag_sets.append(tags)
except Exception:
raise ValueError("%r not a valid value for %s" % (tag_set, name))
raise ValueError(f"{tag_set!r} not a valid value for {name}")
return tag_sets
@ -472,13 +470,13 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni
raise ValueError(
"auth mechanism properties must be "
"key:value pairs like SERVICE_NAME:"
"mongodb, not %s." % (opt,)
"mongodb, not {}.".format(opt)
)
if key not in _MECHANISM_PROPS:
raise ValueError(
"%s is not a supported auth "
"{} is not a supported auth "
"mechanism property. Must be one of "
"%s." % (key, tuple(_MECHANISM_PROPS))
"{}.".format(key, tuple(_MECHANISM_PROPS))
)
if key == "CANONICALIZE_HOST_NAME":
props[key] = validate_boolean_or_string(key, val)
@ -502,9 +500,9 @@ def validate_document_class(
is_mapping = issubclass(value.__origin__, abc.MutableMapping)
if not is_mapping and not issubclass(value, RawBSONDocument):
raise TypeError(
"%s must be dict, bson.son.SON, "
"{} must be dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or a "
"subclass of collections.MutableMapping" % (option,)
"subclass of collections.MutableMapping".format(option)
)
return value
@ -512,14 +510,14 @@ def validate_document_class(
def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
"""Validate the type_registry option."""
if value is not None and not isinstance(value, TypeRegistry):
raise TypeError("%s must be an instance of %s" % (option, TypeRegistry))
raise TypeError(f"{option} must be an instance of {TypeRegistry}")
return value
def validate_list(option: str, value: Any) -> List:
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError("%s must be a list" % (option,))
raise TypeError(f"{option} must be a list")
return value
@ -534,9 +532,9 @@ def validate_list_or_mapping(option: Any, value: Any) -> None:
"""Validates that 'value' is a list or a document."""
if not isinstance(value, (abc.Mapping, list)):
raise TypeError(
"%s must either be a list or an instance of dict, "
"{} must either be a list or an instance of dict, "
"bson.son.SON, or any other type that inherits from "
"collections.Mapping" % (option,)
"collections.Mapping".format(option)
)
@ -544,9 +542,9 @@ def validate_is_mapping(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a document."""
if not isinstance(value, abc.Mapping):
raise TypeError(
"%s must be an instance of dict, bson.son.SON, or "
"{} must be an instance of dict, bson.son.SON, or "
"any other type that inherits from "
"collections.Mapping" % (option,)
"collections.Mapping".format(option)
)
@ -554,10 +552,10 @@ def validate_is_document_type(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a MongoDB document."""
if not isinstance(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError(
"%s must be an instance of dict, bson.son.SON, "
"{} must be an instance of dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or "
"a type that inherits from "
"collections.MutableMapping" % (option,)
"collections.MutableMapping".format(option)
)
@ -568,7 +566,7 @@ def validate_appname_or_none(option: str, value: Any) -> Optional[str]:
validate_string(option, value)
# We need length in bytes, so encode utf8 first.
if len(value.encode("utf-8")) > 128:
raise ValueError("%s must be <= 128 bytes" % (option,))
raise ValueError(f"{option} must be <= 128 bytes")
return value
@ -577,7 +575,7 @@ def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]:
if value is None:
return value
if not isinstance(value, DriverInfo):
raise TypeError("%s must be an instance of DriverInfo" % (option,))
raise TypeError(f"{option} must be an instance of DriverInfo")
return value
@ -586,7 +584,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
if value is None:
return value
if not isinstance(value, ServerApi):
raise TypeError("%s must be an instance of ServerApi" % (option,))
raise TypeError(f"{option} must be an instance of ServerApi")
return value
@ -595,7 +593,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
if value is None:
return value
if not callable(value):
raise ValueError("%s must be a callable" % (option,))
raise ValueError(f"{option} must be a callable")
return value
@ -629,9 +627,9 @@ def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str:
"""Validate the Unicode decode error handler option of CodecOptions."""
if value not in _UNICODE_DECODE_ERROR_HANDLERS:
raise ValueError(
"%s is an invalid Unicode decode error handler. "
"{} is an invalid Unicode decode error handler. "
"Must be one of "
"%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS))
"{}".format(value, tuple(_UNICODE_DECODE_ERROR_HANDLERS))
)
return value
@ -650,7 +648,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
from pymongo.encryption_options import AutoEncryptionOpts
if not isinstance(value, AutoEncryptionOpts):
raise TypeError("%s must be an instance of AutoEncryptionOpts" % (option,))
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
return value
@ -667,7 +665,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo
elif isinstance(value, int):
return DatetimeConversion(value)
raise TypeError("%s must be a str or int representing DatetimeConversion" % (option,))
raise TypeError(f"{option} must be a str or int representing DatetimeConversion")
# Dictionary where keys are the names of public URI options, and values
@ -805,7 +803,7 @@ def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]:
"""Validate optional authentication parameters."""
lower, value = validate(option, value)
if lower not in _AUTH_OPTIONS:
raise ConfigurationError("Unknown authentication option: %s" % (option,))
raise ConfigurationError(f"Unknown authentication option: {option}")
return option, value
@ -866,7 +864,7 @@ def _ecoc_coll_name(encrypted_fields, name):
WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"])
class BaseObject(object):
class BaseObject:
"""A base class that provides attributes and methods common
to multiple pymongo classes.
@ -886,9 +884,9 @@ class BaseObject(object):
if not isinstance(read_preference, _ServerMode):
raise TypeError(
"%r is not valid for read_preference. See "
"{!r} is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options." % (read_preference,)
"options.".format(read_preference)
)
self.__read_preference = read_preference

View File

@ -40,8 +40,8 @@ except ImportError:
from pymongo.hello import HelloCompat
from pymongo.monitoring import _SENSITIVE_COMMANDS
_SUPPORTED_COMPRESSORS = set(["snappy", "zlib", "zstd"])
_NO_COMPRESSION = set([HelloCompat.CMD, HelloCompat.LEGACY_CMD])
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
@ -56,7 +56,7 @@ def validate_compressors(dummy, value):
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn("Unsupported compressor: %s" % (compressor,))
warnings.warn(f"Unsupported compressor: {compressor}")
elif compressor == "snappy" and not _HAVE_SNAPPY:
compressors.remove(compressor)
warnings.warn(
@ -82,13 +82,13 @@ def validate_zlib_compression_level(option, value):
try:
level = int(value)
except Exception:
raise TypeError("%s must be an integer, not %r." % (option, value))
raise TypeError(f"{option} must be an integer, not {value!r}.")
if level < -1 or level > 9:
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
return level
class CompressionSettings(object):
class CompressionSettings:
def __init__(self, compressors, zlib_compression_level):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
@ -102,9 +102,11 @@ class CompressionSettings(object):
return ZlibContext(self.zlib_compression_level)
elif chosen == "zstd":
return ZstdContext()
return None
return None
class SnappyContext(object):
class SnappyContext:
compressor_id = 1
@staticmethod
@ -112,7 +114,7 @@ class SnappyContext(object):
return snappy.compress(data)
class ZlibContext(object):
class ZlibContext:
compressor_id = 2
def __init__(self, level):
@ -122,7 +124,7 @@ class ZlibContext(object):
return zlib.compress(data, self.level)
class ZstdContext(object):
class ZstdContext:
compressor_id = 3
@staticmethod

View File

@ -97,7 +97,7 @@ _QUERY_OPTIONS = {
}
class CursorType(object):
class CursorType:
NON_TAILABLE = 0
"""The standard cursor type."""
@ -126,7 +126,7 @@ class CursorType(object):
"""
class _SocketManager(object):
class _SocketManager:
"""Used with exhaust cursors to ensure the socket is returned."""
def __init__(self, sock, more_to_come):
@ -387,11 +387,11 @@ class Cursor(Generic[_DocumentType]):
"exhaust",
"has_filter",
)
data = dict(
(k, v)
data = {
k: v
for k, v in self.__dict__.items()
if k.startswith("_Cursor__") and k[9:] in values_to_clone
)
}
if deepcopy:
data = self._deepcopy(data)
base.__dict__.update(data)
@ -412,7 +412,7 @@ class Cursor(Generic[_DocumentType]):
self.__killed = True
if self.__id and not already_killed:
cursor_id = self.__id
address = _CursorAddress(self.__address, "%s.%s" % (self.__dbname, self.__collname))
address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}")
else:
# Skip killCursors.
cursor_id = 0
@ -1322,7 +1322,7 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]):
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
super(RawBatchCursor, self).__init__(collection, *args, **kwargs)
super().__init__(collection, *args, **kwargs)
def _unpack_response(
self, response, cursor_id, codec_options, user_fields=None, legacy_response=False

View File

@ -125,7 +125,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
db.__my_collection__
"""
super(Database, self).__init__(
super().__init__(
codec_options or client.codec_options,
read_preference or client.read_preference,
write_concern or client.write_concern,
@ -211,7 +211,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
return hash((self.__client, self.__name))
def __repr__(self):
return "Database(%r, %r)" % (self.__client, self.__name)
return f"Database({self.__client!r}, {self.__name!r})"
def __getattr__(self, name: str) -> Collection[_DocumentType]:
"""Get a collection of this database by name.
@ -223,8 +223,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
"""
if name.startswith("_"):
raise AttributeError(
"Database has no attribute %r. To access the %s"
" collection, use database[%r]." % (name, name, name)
"Database has no attribute {!r}. To access the {}"
" collection, use database[{!r}].".format(name, name, name)
)
return self.__getitem__(name)
@ -415,9 +415,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
{
// key pattern must be {_id: 1}
key: <key pattern>, // required
unique: <bool>, // required, must be true
unique: <bool>, // required, must be `true`
name: <string>, // optional, otherwise automatically generated
v: <int>, // optional, must be 2 if provided
v: <int>, // optional, must be `2` if provided
}
- ``changeStreamPreAndPostImages`` (dict): a document with a boolean field ``enabled`` for
enabling pre- and post-images.
@ -863,7 +863,6 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _list_collections(self, sock_info, session, read_preference, **kwargs):
"""Internal listCollections helper."""
coll = self.get_collection("$cmd", read_preference=read_preference)
cmd = SON([("listCollections", 1), ("cursor", {})])
cmd.update(kwargs)
@ -1128,14 +1127,14 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if "result" in result:
info = result["result"]
if info.find("exception") != -1 or info.find("corrupt") != -1:
raise CollectionInvalid("%s invalid: %s" % (name, info))
raise CollectionInvalid(f"{name} invalid: {info}")
# Sharded results
elif "raw" in result:
for _, res in result["raw"].items():
if "result" in res:
info = res["result"]
if info.find("exception") != -1 or info.find("corrupt") != -1:
raise CollectionInvalid("%s invalid: %s" % (name, info))
raise CollectionInvalid(f"{name} invalid: {info}")
elif not res.get("valid", False):
valid = False
break
@ -1144,7 +1143,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
valid = False
if not valid:
raise CollectionInvalid("%s invalid: %r" % (name, result))
raise CollectionInvalid(f"{name} invalid: {result!r}")
return result
@ -1200,7 +1199,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if dbref.database is not None and dbref.database != self.__name:
raise ValueError(
"trying to dereference a DBRef that points to "
"another database (%r not %r)" % (dbref.database, self.__name)
"another database ({!r} not {!r})".format(dbref.database, self.__name)
)
return self[dbref.collection].find_one(
{"_id": dbref.id}, session=session, comment=comment, **kwargs

View File

@ -31,12 +31,12 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])):
def __new__(
cls, name: str, version: Optional[str] = None, platform: Optional[str] = None
) -> "DriverInfo":
self = super(DriverInfo, cls).__new__(cls, name, version, platform)
self = super().__new__(cls, name, version, platform)
for key, value in self._asdict().items():
if value is not None and not isinstance(value, str):
raise TypeError(
"Wrong type for DriverInfo %s option, value "
"must be an instance of str" % (key,)
"Wrong type for DriverInfo {} option, value "
"must be an instance of str".format(key)
)
return self

View File

@ -177,6 +177,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor:
for doc in cursor:
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
return None
def spawn(self):
"""Spawn mongocryptd.
@ -272,7 +273,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
self.mongocryptd_client = None
class RewrapManyDataKeyResult(object):
class RewrapManyDataKeyResult:
"""Result object returned by a :meth:`~ClientEncryption.rewrap_many_data_key` operation.
.. versionadded:: 4.2
@ -292,11 +293,12 @@ class RewrapManyDataKeyResult(object):
return self._bulk_write_result
class _Encrypter(object):
class _Encrypter:
"""Encrypts and decrypts MongoDB commands.
This class is used to support automatic encryption and decryption of
MongoDB commands."""
MongoDB commands.
"""
def __init__(self, client, opts):
"""Create a _Encrypter for a client.

View File

@ -31,7 +31,7 @@ if TYPE_CHECKING:
from pymongo.mongo_client import MongoClient
class AutoEncryptionOpts(object):
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(

View File

@ -33,7 +33,7 @@ class PyMongoError(Exception):
"""Base class for all PyMongo exceptions."""
def __init__(self, message: str = "", error_labels: Optional[Iterable[str]] = None) -> None:
super(PyMongoError, self).__init__(message)
super().__init__(message)
self._message = message
self._error_labels = set(error_labels or [])
@ -105,7 +105,7 @@ class AutoReconnect(ConnectionFailure):
if errors is not None:
if isinstance(errors, dict):
error_labels = errors.get("errorLabels")
super(AutoReconnect, self).__init__(message, error_labels)
super().__init__(message, error_labels)
self.errors = self.details = errors or []
@ -125,7 +125,7 @@ class NetworkTimeout(AutoReconnect):
def _format_detailed_error(message, details):
if details is not None:
message = "%s, full error: %s" % (message, details)
message = f"{message}, full error: {details}"
return message
@ -148,9 +148,7 @@ class NotPrimaryError(AutoReconnect):
def __init__(
self, message: str = "", errors: Optional[Union[Mapping[str, Any], List]] = None
) -> None:
super(NotPrimaryError, self).__init__(
_format_detailed_error(message, errors), errors=errors
)
super().__init__(_format_detailed_error(message, errors), errors=errors)
class ServerSelectionTimeoutError(AutoReconnect):
@ -191,9 +189,7 @@ class OperationFailure(PyMongoError):
error_labels = None
if details is not None:
error_labels = details.get("errorLabels")
super(OperationFailure, self).__init__(
_format_detailed_error(error, details), error_labels=error_labels
)
super().__init__(_format_detailed_error(error, details), error_labels=error_labels)
self.__code = code
self.__details = details
self.__max_wire_version = max_wire_version
@ -293,7 +289,7 @@ class BulkWriteError(OperationFailure):
details: Mapping[str, Any]
def __init__(self, results: Mapping[str, Any]) -> None:
super(BulkWriteError, self).__init__("batch op errors occurred", 65, results)
super().__init__("batch op errors occurred", 65, results)
def __reduce__(self) -> Tuple[Any, Any]:
return self.__class__, (self.details,)
@ -331,8 +327,6 @@ class InvalidURI(ConfigurationError):
class DocumentTooLarge(InvalidDocument):
"""Raised when an encoded document is too large for the connected server."""
pass
class EncryptionError(PyMongoError):
"""Raised when encryption or decryption fails.
@ -344,7 +338,7 @@ class EncryptionError(PyMongoError):
"""
def __init__(self, cause: Exception) -> None:
super(EncryptionError, self).__init__(str(cause))
super().__init__(str(cause))
self.__cause = cause
@property
@ -369,7 +363,7 @@ class EncryptedCollectionError(EncryptionError):
"""
def __init__(self, cause: Exception, encrypted_fields: Mapping[str, Any]) -> None:
super(EncryptedCollectionError, self).__init__(cause)
super().__init__(cause)
self.__encrypted_fields = encrypted_fields
@property
@ -386,5 +380,3 @@ class EncryptedCollectionError(EncryptionError):
class _OperationCancelled(AutoReconnect):
"""Internal error raised when a socket operation is cancelled."""
pass

View File

@ -74,7 +74,7 @@ _REAUTHENTICATION_REQUIRED_CODE = 391
def _gen_index_name(keys):
"""Generate an index name from the set of fields it is over."""
return "_".join(["%s_%s" % item for item in keys])
return "_".join(["{}_{}".format(*item) for item in keys])
def _index_list(key_or_list, direction=None):
@ -248,12 +248,10 @@ def _fields_list_to_dict(fields, option_name):
if isinstance(fields, (abc.Sequence, abc.Set)):
if not all(isinstance(field, str) for field in fields):
raise TypeError(
"%s must be a list of key names, each an instance of str" % (option_name,)
)
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
return dict.fromkeys(fields, 1)
raise TypeError("%s must be a mapping or list of key names" % (option_name,))
raise TypeError(f"{option_name} must be a mapping or list of key names")
def _handle_exception():
@ -266,7 +264,7 @@ def _handle_exception():
einfo = sys.exc_info()
try:
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
except IOError:
except OSError:
pass
finally:
del einfo

View File

@ -115,7 +115,6 @@ def _convert_exception(exception):
def _convert_write_result(operation, command, result):
"""Convert a legacy write result to write command format."""
# Based on _merge_legacy from bulk.py
affected = result.get("n", 0)
res = {"ok": 1, "n": affected}
@ -240,7 +239,7 @@ def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, commen
return cmd
class _Query(object):
class _Query:
"""A query operation."""
__slots__ = (
@ -310,7 +309,7 @@ class _Query(object):
self._as_command = None
def namespace(self):
return "%s.%s" % (self.db, self.coll)
return f"{self.db}.{self.coll}"
def use_command(self, sock_info):
use_find_cmd = False
@ -421,7 +420,7 @@ class _Query(object):
)
class _GetMore(object):
class _GetMore:
"""A getmore operation."""
__slots__ = (
@ -475,7 +474,7 @@ class _GetMore(object):
self._as_command = None
def namespace(self):
return "%s.%s" % (self.db, self.coll)
return f"{self.db}.{self.coll}"
def use_command(self, sock_info):
use_cmd = False
@ -518,7 +517,6 @@ class _GetMore(object):
def get_message(self, dummy0, sock_info, use_cmd=False):
"""Get a getmore message."""
ns = self.namespace()
ctx = sock_info.compression_context
@ -539,7 +537,7 @@ class _GetMore(object):
class _RawBatchQuery(_Query):
def use_command(self, sock_info):
# Compatibility checks.
super(_RawBatchQuery, self).use_command(sock_info)
super().use_command(sock_info)
if sock_info.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
@ -551,7 +549,7 @@ class _RawBatchQuery(_Query):
class _RawBatchGetMore(_GetMore):
def use_command(self, sock_info):
# Compatibility checks.
super(_RawBatchGetMore, self).use_command(sock_info)
super().use_command(sock_info)
if sock_info.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
@ -578,7 +576,7 @@ class _CursorAddress(tuple):
def __hash__(self):
# Two _CursorAddress instances with different namespaces
# must not hash the same.
return (self + (self.__namespace,)).__hash__()
return ((*self, self.__namespace)).__hash__()
def __eq__(self, other):
if isinstance(other, _CursorAddress):
@ -648,7 +646,7 @@ def _op_msg_no_header(flags, command, identifier, docs, opts):
encoded_size = _pack_int(size)
total_size += size
max_doc_size = max(len(doc) for doc in encoded_docs)
data = [flags_type, encoded, type_one, encoded_size, cstring] + encoded_docs
data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs]
else:
data = [flags_type, encoded]
return b"".join(data), total_size, max_doc_size
@ -795,7 +793,7 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
class _BulkWriteContext(object):
class _BulkWriteContext:
"""A wrapper around SocketInfo for use with write splitting functions."""
__slots__ = (
@ -1033,7 +1031,7 @@ def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> N
else:
# There's nothing intelligent we can say
# about size for update and delete
raise DocumentTooLarge("%r command document too large" % (operation,))
raise DocumentTooLarge(f"{operation!r} command document too large")
# OP_MSG -------------------------------------------------------------
@ -1253,7 +1251,7 @@ def _batched_write_command_impl(namespace, operation, command, docs, opts, ctx,
return to_send, length
class _OpReply(object):
class _OpReply:
"""A MongoDB OP_REPLY response message."""
__slots__ = ("flags", "cursor_id", "number_returned", "documents")
@ -1363,7 +1361,7 @@ class _OpReply(object):
return cls(flags, cursor_id, number_returned, documents)
class _OpMsg(object):
class _OpMsg:
"""A MongoDB OP_MSG response message."""
__slots__ = ("flags", "cursor_id", "number_returned", "payload_document")
@ -1427,12 +1425,12 @@ class _OpMsg(object):
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
if flags != 0:
if flags & cls.CHECKSUM_PRESENT:
raise ProtocolError("Unsupported OP_MSG flag checksumPresent: 0x%x" % (flags,))
raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}")
if flags ^ cls.MORE_TO_COME:
raise ProtocolError("Unsupported OP_MSG flags: 0x%x" % (flags,))
raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}")
if first_payload_type != 0:
raise ProtocolError("Unsupported OP_MSG payload type: 0x%x" % (first_payload_type,))
raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}")
if len(msg) != first_payload_size + 5:
raise ProtocolError("Unsupported OP_MSG reply: >1 section")

View File

@ -805,7 +805,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self.__kill_cursors_queue: List = []
self._event_listeners = options.pool_options._event_listeners
super(MongoClient, self).__init__(
super().__init__(
options.codec_options,
options.read_preference,
options.write_concern,
@ -1509,11 +1509,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if value is dict:
return "document_class=dict"
else:
return "document_class=%s.%s" % (value.__module__, value.__name__)
return f"document_class={value.__module__}.{value.__name__}"
if option in common.TIMEOUT_OPTIONS and value is not None:
return "%s=%s" % (option, int(value * 1000))
return f"{option}={int(value * 1000)}"
return "%s=%r" % (option, value)
return f"{option}={value!r}"
# Host first...
options = [
@ -1536,7 +1536,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return ", ".join(options)
def __repr__(self):
return "MongoClient(%s)" % (self._repr_helper(),)
return f"MongoClient({self._repr_helper()})"
def __getattr__(self, name: str) -> database.Database[_DocumentType]:
"""Get a database by name.
@ -1549,8 +1549,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
if name.startswith("_"):
raise AttributeError(
"MongoClient has no attribute %r. To access the %s"
" database, use client[%r]." % (name, name, name)
"MongoClient has no attribute {!r}. To access the {}"
" database, use client[{!r}].".format(name, name, name)
)
return self.__getitem__(name)
@ -1685,7 +1685,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# This method is run periodically by a background thread.
def _process_periodic_tasks(self):
"""Process any pending kill cursors requests and
maintain connection pool parameters."""
maintain connection pool parameters.
"""
try:
self._process_kill_cursors()
self._topology.update_pool()
@ -1742,7 +1743,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _return_server_session(self, server_session, lock):
"""Internal: return a _ServerSession to the pool."""
if isinstance(server_session, _EmptyServerSession):
return
return None
return self._topology.return_server_session(server_session, lock)
def _ensure_session(self, session=None):
@ -2121,7 +2122,7 @@ def _add_retryable_write_error(exc, max_wire_version):
exc._add_error_label("RetryableWriteError")
class _MongoClientErrorHandler(object):
class _MongoClientErrorHandler:
"""Handle errors raised when executing an operation."""
__slots__ = (

View File

@ -37,7 +37,7 @@ def _sanitize(error):
error.__cause__ = None
class MonitorBase(object):
class MonitorBase:
def __init__(self, topology, name, interval, min_interval):
"""Base class to do periodic work on a background thread.
@ -108,7 +108,7 @@ class Monitor(MonitorBase):
The Topology is weakly referenced. The Pool must be exclusive to this
Monitor.
"""
super(Monitor, self).__init__(
super().__init__(
topology,
"pymongo_server_monitor_thread",
topology_settings.heartbeat_frequency,
@ -290,7 +290,7 @@ class SrvMonitor(MonitorBase):
The Topology is weakly referenced.
"""
super(SrvMonitor, self).__init__(
super().__init__(
topology,
"pymongo_srv_polling_thread",
common.MIN_SRV_RESCAN_INTERVAL,
@ -343,7 +343,7 @@ class _RttMonitor(MonitorBase):
The Topology is weakly referenced.
"""
super(_RttMonitor, self).__init__(
super().__init__(
topology,
"pymongo_server_rtt_thread",
topology_settings.heartbeat_frequency,

View File

@ -211,7 +211,7 @@ _Listeners = namedtuple(
_LISTENERS = _Listeners([], [], [], [], [])
class _EventListener(object):
class _EventListener:
"""Abstract base class for all event listeners."""
@ -486,14 +486,14 @@ def _to_micros(dur):
def _validate_event_listeners(option, listeners):
"""Validate event listeners"""
if not isinstance(listeners, abc.Sequence):
raise TypeError("%s must be a list or tuple" % (option,))
raise TypeError(f"{option} must be a list or tuple")
for listener in listeners:
if not isinstance(listener, _EventListener):
raise TypeError(
"Listeners for %s must be either a "
"Listeners for {} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener." % (option,)
"ConnectionPoolListener.".format(option)
)
return listeners
@ -508,10 +508,10 @@ def register(listener: _EventListener) -> None:
"""
if not isinstance(listener, _EventListener):
raise TypeError(
"Listeners for %s must be either a "
"Listeners for {} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener." % (listener,)
"ConnectionPoolListener.".format(listener)
)
if isinstance(listener, CommandListener):
_LISTENERS.command_listeners.append(listener)
@ -528,19 +528,17 @@ def register(listener: _EventListener) -> None:
# Note - to avoid bugs from forgetting which if these is all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = set(
[
"authenticate",
"saslstart",
"saslcontinue",
"getnonce",
"createuser",
"updateuser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
]
)
_SENSITIVE_COMMANDS: set = {
"authenticate",
"saslstart",
"saslcontinue",
"getnonce",
"createuser",
"updateuser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
}
# The "hello" command is also deemed sensitive when attempting speculative
@ -554,7 +552,7 @@ def _is_speculative_authenticate(command_name, doc):
return False
class _CommandEvent(object):
class _CommandEvent:
"""Base class for command events."""
__slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", "__service_id")
@ -627,10 +625,10 @@ class CommandStartedEvent(_CommandEvent):
service_id: Optional[ObjectId] = None,
) -> None:
if not command:
raise ValueError("%r is not a valid command" % (command,))
raise ValueError(f"{command!r} is not a valid command")
# Command name must be first key.
command_name = next(iter(command))
super(CommandStartedEvent, self).__init__(
super().__init__(
command_name, request_id, connection_id, operation_id, service_id=service_id
)
cmd_name = command_name.lower()
@ -651,7 +649,7 @@ class CommandStartedEvent(_CommandEvent):
return self.__db
def __repr__(self):
return ("<%s %s db: %r, command: %r, operation_id: %s, service_id: %s>") % (
return ("<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}>").format(
self.__class__.__name__,
self.connection_id,
self.database_name,
@ -687,7 +685,7 @@ class CommandSucceededEvent(_CommandEvent):
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
super(CommandSucceededEvent, self).__init__(
super().__init__(
command_name, request_id, connection_id, operation_id, service_id=service_id
)
self.__duration_micros = _to_micros(duration)
@ -708,7 +706,9 @@ class CommandSucceededEvent(_CommandEvent):
return self.__reply
def __repr__(self):
return ("<%s %s command: %r, operation_id: %s, duration_micros: %s, service_id: %s>") % (
return (
"<{} {} command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}>"
).format(
self.__class__.__name__,
self.connection_id,
self.command_name,
@ -744,7 +744,7 @@ class CommandFailedEvent(_CommandEvent):
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
super(CommandFailedEvent, self).__init__(
super().__init__(
command_name, request_id, connection_id, operation_id, service_id=service_id
)
self.__duration_micros = _to_micros(duration)
@ -762,9 +762,9 @@ class CommandFailedEvent(_CommandEvent):
def __repr__(self):
return (
"<%s %s command: %r, operation_id: %s, duration_micros: %s, "
"failure: %r, service_id: %s>"
) % (
"<{} {} command: {!r}, operation_id: {}, duration_micros: {}, "
"failure: {!r}, service_id: {}>"
).format(
self.__class__.__name__,
self.connection_id,
self.command_name,
@ -775,7 +775,7 @@ class CommandFailedEvent(_CommandEvent):
)
class _PoolEvent(object):
class _PoolEvent:
"""Base class for pool events."""
__slots__ = ("__address",)
@ -791,7 +791,7 @@ class _PoolEvent(object):
return self.__address
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.__address)
return f"{self.__class__.__name__}({self.__address!r})"
class PoolCreatedEvent(_PoolEvent):
@ -807,7 +807,7 @@ class PoolCreatedEvent(_PoolEvent):
__slots__ = ("__options",)
def __init__(self, address: _Address, options: Dict[str, Any]) -> None:
super(PoolCreatedEvent, self).__init__(address)
super().__init__(address)
self.__options = options
@property
@ -816,7 +816,7 @@ class PoolCreatedEvent(_PoolEvent):
return self.__options
def __repr__(self):
return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__options)
return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})"
class PoolReadyEvent(_PoolEvent):
@ -846,7 +846,7 @@ class PoolClearedEvent(_PoolEvent):
__slots__ = ("__service_id",)
def __init__(self, address: _Address, service_id: Optional[ObjectId] = None) -> None:
super(PoolClearedEvent, self).__init__(address)
super().__init__(address)
self.__service_id = service_id
@property
@ -860,7 +860,7 @@ class PoolClearedEvent(_PoolEvent):
return self.__service_id
def __repr__(self):
return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__service_id)
return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r})"
class PoolClosedEvent(_PoolEvent):
@ -876,7 +876,7 @@ class PoolClosedEvent(_PoolEvent):
__slots__ = ()
class ConnectionClosedReason(object):
class ConnectionClosedReason:
"""An enum that defines values for `reason` on a
:class:`ConnectionClosedEvent`.
@ -897,7 +897,7 @@ class ConnectionClosedReason(object):
"""The pool was closed, making the connection no longer valid."""
class ConnectionCheckOutFailedReason(object):
class ConnectionCheckOutFailedReason:
"""An enum that defines values for `reason` on a
:class:`ConnectionCheckOutFailedEvent`.
@ -916,7 +916,7 @@ class ConnectionCheckOutFailedReason(object):
"""
class _ConnectionEvent(object):
class _ConnectionEvent:
"""Private base class for connection events."""
__slots__ = ("__address",)
@ -932,7 +932,7 @@ class _ConnectionEvent(object):
return self.__address
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.__address)
return f"{self.__class__.__name__}({self.__address!r})"
class _ConnectionIdEvent(_ConnectionEvent):
@ -950,7 +950,7 @@ class _ConnectionIdEvent(_ConnectionEvent):
return self.__connection_id
def __repr__(self):
return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__connection_id)
return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})"
class ConnectionCreatedEvent(_ConnectionIdEvent):
@ -999,7 +999,7 @@ class ConnectionClosedEvent(_ConnectionIdEvent):
__slots__ = ("__reason",)
def __init__(self, address, connection_id, reason):
super(ConnectionClosedEvent, self).__init__(address, connection_id)
super().__init__(address, connection_id)
self.__reason = reason
@property
@ -1012,7 +1012,7 @@ class ConnectionClosedEvent(_ConnectionIdEvent):
return self.__reason
def __repr__(self):
return "%s(%r, %r, %r)" % (
return "{}({!r}, {!r}, {!r})".format(
self.__class__.__name__,
self.address,
self.connection_id,
@ -1060,7 +1060,7 @@ class ConnectionCheckOutFailedEvent(_ConnectionEvent):
return self.__reason
def __repr__(self):
return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__reason)
return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r})"
class ConnectionCheckedOutEvent(_ConnectionIdEvent):
@ -1091,7 +1091,7 @@ class ConnectionCheckedInEvent(_ConnectionIdEvent):
__slots__ = ()
class _ServerEvent(object):
class _ServerEvent:
"""Base class for server events."""
__slots__ = ("__server_address", "__topology_id")
@ -1111,7 +1111,7 @@ class _ServerEvent(object):
return self.__topology_id
def __repr__(self):
return "<%s %s topology_id: %s>" % (
return "<{} {} topology_id: {}>".format(
self.__class__.__name__,
self.server_address,
self.topology_id,
@ -1130,26 +1130,28 @@ class ServerDescriptionChangedEvent(_ServerEvent):
self,
previous_description: "ServerDescription",
new_description: "ServerDescription",
*args: Any
*args: Any,
) -> None:
super(ServerDescriptionChangedEvent, self).__init__(*args)
super().__init__(*args)
self.__previous_description = previous_description
self.__new_description = new_description
@property
def previous_description(self) -> "ServerDescription":
"""The previous
:class:`~pymongo.server_description.ServerDescription`."""
:class:`~pymongo.server_description.ServerDescription`.
"""
return self.__previous_description
@property
def new_description(self) -> "ServerDescription":
"""The new
:class:`~pymongo.server_description.ServerDescription`."""
:class:`~pymongo.server_description.ServerDescription`.
"""
return self.__new_description
def __repr__(self):
return "<%s %s changed from: %s, to: %s>" % (
return "<{} {} changed from: {}, to: {}>".format(
self.__class__.__name__,
self.server_address,
self.previous_description,
@ -1175,7 +1177,7 @@ class ServerClosedEvent(_ServerEvent):
__slots__ = ()
class TopologyEvent(object):
class TopologyEvent:
"""Base class for topology description events."""
__slots__ = "__topology_id"
@ -1189,7 +1191,7 @@ class TopologyEvent(object):
return self.__topology_id
def __repr__(self):
return "<%s topology_id: %s>" % (self.__class__.__name__, self.topology_id)
return f"<{self.__class__.__name__} topology_id: {self.topology_id}>"
class TopologyDescriptionChangedEvent(TopologyEvent):
@ -1204,26 +1206,28 @@ class TopologyDescriptionChangedEvent(TopologyEvent):
self,
previous_description: "TopologyDescription",
new_description: "TopologyDescription",
*args: Any
*args: Any,
) -> None:
super(TopologyDescriptionChangedEvent, self).__init__(*args)
super().__init__(*args)
self.__previous_description = previous_description
self.__new_description = new_description
@property
def previous_description(self) -> "TopologyDescription":
"""The previous
:class:`~pymongo.topology_description.TopologyDescription`."""
:class:`~pymongo.topology_description.TopologyDescription`.
"""
return self.__previous_description
@property
def new_description(self) -> "TopologyDescription":
"""The new
:class:`~pymongo.topology_description.TopologyDescription`."""
:class:`~pymongo.topology_description.TopologyDescription`.
"""
return self.__new_description
def __repr__(self):
return "<%s topology_id: %s changed from: %s, to: %s>" % (
return "<{} topology_id: {} changed from: {}, to: {}>".format(
self.__class__.__name__,
self.topology_id,
self.previous_description,
@ -1249,7 +1253,7 @@ class TopologyClosedEvent(TopologyEvent):
__slots__ = ()
class _ServerHeartbeatEvent(object):
class _ServerHeartbeatEvent:
"""Base class for server heartbeat events."""
__slots__ = "__connection_id"
@ -1260,11 +1264,12 @@ class _ServerHeartbeatEvent(object):
@property
def connection_id(self) -> _Address:
"""The address (host, port) of the server this heartbeat was sent
to."""
to.
"""
return self.__connection_id
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.connection_id)
return f"<{self.__class__.__name__} {self.connection_id}>"
class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent):
@ -1287,7 +1292,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
def __init__(
self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False
) -> None:
super(ServerHeartbeatSucceededEvent, self).__init__(connection_id)
super().__init__(connection_id)
self.__duration = duration
self.__reply = reply
self.__awaited = awaited
@ -1313,7 +1318,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
return self.__awaited
def __repr__(self):
return "<%s %s duration: %s, awaited: %s, reply: %s>" % (
return "<{} {} duration: {}, awaited: {}, reply: {}>".format(
self.__class__.__name__,
self.connection_id,
self.duration,
@ -1334,7 +1339,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
def __init__(
self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False
) -> None:
super(ServerHeartbeatFailedEvent, self).__init__(connection_id)
super().__init__(connection_id)
self.__duration = duration
self.__reply = reply
self.__awaited = awaited
@ -1360,7 +1365,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
return self.__awaited
def __repr__(self):
return "<%s %s duration: %s, awaited: %s, reply: %r>" % (
return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format(
self.__class__.__name__,
self.connection_id,
self.duration,
@ -1369,7 +1374,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
)
class _EventListeners(object):
class _EventListeners:
"""Configure event listeners for a client instance.
Any event listeners registered globally are included by default.

View File

@ -219,15 +219,15 @@ def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE):
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError("Got response id %r but expected %r" % (response_to, request_id))
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
"Message length (%r) not longer than standard message header size (16)" % (length,)
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
"Message length (%r) is larger than server max "
"message size (%r)" % (length, max_message_size)
"Message length ({!r}) is larger than server max "
"message size ({!r})".format(length, max_message_size)
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
@ -240,7 +240,7 @@ def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE):
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError("Got opcode %r but expected %r" % (op_code, _UNPACK_REPLY.keys()))
raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}")
return unpack_reply(data)
@ -281,7 +281,7 @@ def wait_for_read(sock_info, deadline):
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError,) + ssl_support.BLOCKING_IO_ERRORS
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
def _receive_data_on_socket(sock_info, length, deadline):
@ -299,7 +299,7 @@ def _receive_data_on_socket(sock_info, length, deadline):
chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
except (IOError, OSError) as exc: # noqa: B014
except OSError as exc: # noqa: B014
if _errno_from_exception(exc) == errno.EINTR:
continue
raise

View File

@ -20,7 +20,7 @@ from datetime import datetime as _datetime
from pymongo.lock import _create_lock
class _OCSPCache(object):
class _OCSPCache:
"""A cache for OCSP responses."""
CACHE_KEY_TYPE = namedtuple( # type: ignore

View File

@ -48,7 +48,7 @@ class InsertOne(Generic[_DocumentType]):
bulkobj.add_insert(self._doc)
def __repr__(self):
return "InsertOne(%r)" % (self._doc,)
return f"InsertOne({self._doc!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
@ -59,7 +59,7 @@ class InsertOne(Generic[_DocumentType]):
return not self == other
class DeleteOne(object):
class DeleteOne:
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
@ -104,7 +104,7 @@ class DeleteOne(object):
bulkobj.add_delete(self._filter, 1, collation=self._collation, hint=self._hint)
def __repr__(self):
return "DeleteOne(%r, %r)" % (self._filter, self._collation)
return f"DeleteOne({self._filter!r}, {self._collation!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
@ -115,7 +115,7 @@ class DeleteOne(object):
return not self == other
class DeleteMany(object):
class DeleteMany:
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation", "_hint")
@ -160,7 +160,7 @@ class DeleteMany(object):
bulkobj.add_delete(self._filter, 0, collation=self._collation, hint=self._hint)
def __repr__(self):
return "DeleteMany(%r, %r)" % (self._filter, self._collation)
return f"DeleteMany({self._filter!r}, {self._collation!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
@ -242,7 +242,7 @@ class ReplaceOne(Generic[_DocumentType]):
return not self == other
def __repr__(self):
return "%s(%r, %r, %r, %r, %r)" % (
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
@ -252,7 +252,7 @@ class ReplaceOne(Generic[_DocumentType]):
)
class _UpdateOp(object):
class _UpdateOp:
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
@ -298,7 +298,7 @@ class _UpdateOp(object):
return not self == other
def __repr__(self):
return "%s(%r, %r, %r, %r, %r, %r)" % (
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
@ -352,7 +352,7 @@ class UpdateOne(_UpdateOp):
.. versionchanged:: 3.5
Added the `collation` option.
"""
super(UpdateOne, self).__init__(filter, update, upsert, collation, array_filters, hint)
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
@ -410,7 +410,7 @@ class UpdateMany(_UpdateOp):
.. versionchanged:: 3.5
Added the `collation` option.
"""
super(UpdateMany, self).__init__(filter, update, upsert, collation, array_filters, hint)
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
@ -425,7 +425,7 @@ class UpdateMany(_UpdateOp):
)
class IndexModel(object):
class IndexModel:
"""Represents an index to create."""
__slots__ = ("__document",)

View File

@ -22,7 +22,7 @@ from typing import Any, Optional
from pymongo.lock import _create_lock
class PeriodicExecutor(object):
class PeriodicExecutor:
def __init__(self, interval, min_interval, target, name=None):
""" "Run a target function periodically on a background thread.
@ -51,7 +51,7 @@ class PeriodicExecutor(object):
self._lock = _create_lock()
def __repr__(self):
return "<%s(name=%s) object at 0x%x>" % (self.__class__.__name__, self._name, id(self))
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def open(self) -> None:
"""Start. Multiple calls have no effect.

View File

@ -81,7 +81,6 @@ except ImportError:
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd):
"""Dummy function for platforms that don't provide fcntl."""
pass
_MAX_TCP_KEEPIDLE = 120
@ -134,7 +133,7 @@ else:
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except socket.error:
except OSError:
pass
def _set_keepalive_times(sock):
@ -351,7 +350,7 @@ def _raise_connection_failure(
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = "%s: %s" % (host, error)
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if isinstance(error, socket.timeout):
@ -371,7 +370,7 @@ def _cond_wait(condition, deadline):
return condition.wait(timeout)
class PoolOptions(object):
class PoolOptions:
"""Read only connection pool options for a MongoClient.
Should not be instantiated directly by application developers. Access
@ -456,17 +455,17 @@ class PoolOptions(object):
# }
if driver:
if driver.name:
self.__metadata["driver"]["name"] = "%s|%s" % (
self.__metadata["driver"]["name"] = "{}|{}".format(
_METADATA["driver"]["name"],
driver.name,
)
if driver.version:
self.__metadata["driver"]["version"] = "%s|%s" % (
self.__metadata["driver"]["version"] = "{}|{}".format(
_METADATA["driver"]["version"],
driver.version,
)
if driver.platform:
self.__metadata["platform"] = "%s|%s" % (_METADATA["platform"], driver.platform)
self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform)
env = _metadata_env()
if env:
@ -601,7 +600,7 @@ class PoolOptions(object):
return self.__load_balanced
class _CancellationContext(object):
class _CancellationContext:
def __init__(self):
self._cancelled = False
@ -615,7 +614,7 @@ class _CancellationContext(object):
return self._cancelled
class SocketInfo(object):
class SocketInfo:
"""Store a socket with some metadata.
:Parameters:
@ -1080,7 +1079,7 @@ class SocketInfo(object):
return hash(self.sock)
def __repr__(self):
return "SocketInfo(%s)%s at %s" % (
return "SocketInfo({}){} at {}".format(
repr(self.sock),
self.closed and " CLOSED" or "",
id(self),
@ -1106,7 +1105,7 @@ def _create_connection(address, options):
try:
sock.connect(host)
return sock
except socket.error:
except OSError:
sock.close()
raise
@ -1125,7 +1124,7 @@ def _create_connection(address, options):
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except socket.error:
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
@ -1144,7 +1143,7 @@ def _create_connection(address, options):
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except socket.error as e:
except OSError as e:
err = e
sock.close()
@ -1155,7 +1154,7 @@ def _create_connection(address, options):
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise socket.error("getaddrinfo failed")
raise OSError("getaddrinfo failed")
def _configured_socket(address, options):
@ -1182,7 +1181,7 @@ def _configured_socket(address, options):
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (IOError, OSError, SSLError) as exc: # noqa: B014
except (OSError, SSLError) as exc: # noqa: B014
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
@ -1208,10 +1207,8 @@ class _PoolClosedError(PyMongoError):
closed pool.
"""
pass
class _PoolGeneration(object):
class _PoolGeneration:
def __init__(self):
# Maps service_id to generation.
self._generations = collections.defaultdict(int)
@ -1242,7 +1239,7 @@ class _PoolGeneration(object):
return gen != self.get(service_id)
class PoolState(object):
class PoolState:
PAUSED = 1
READY = 2
CLOSED = 3
@ -1753,10 +1750,9 @@ class Pool:
other_ops = self.active_sockets - self.ncursors - self.ntxns
raise WaitQueueTimeoutError(
"Timeout waiting for connection from the connection pool. "
"maxPoolSize: %s, connections in use by cursors: %s, "
"connections in use by transactions: %s, connections in use "
"by other operations: %s, timeout: %s"
% (
"maxPoolSize: {}, connections in use by cursors: {}, "
"connections in use by transactions: {}, connections in use "
"by other operations: {}, timeout: {}".format(
self.opts.max_pool_size,
self.ncursors,
self.ntxns,
@ -1766,7 +1762,7 @@ class Pool:
)
raise WaitQueueTimeoutError(
"Timed out while checking out a connection from connection pool. "
"maxPoolSize: %s, timeout: %s" % (self.opts.max_pool_size, timeout)
"maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout)
)
def __del__(self):

View File

@ -67,7 +67,7 @@ _VERIFY_MAP = {
_stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_REVERSE_VERIFY_MAP = dict((value, key) for key, value in _VERIFY_MAP.items())
_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()}
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
@ -97,7 +97,7 @@ class _sslConn(_SSL.Connection):
def __init__(self, ctx, sock, suppress_ragged_eofs):
self.socket_checker = _SocketChecker()
self.suppress_ragged_eofs = suppress_ragged_eofs
super(_sslConn, self).__init__(ctx, sock)
super().__init__(ctx, sock)
def _call(self, call, *args, **kwargs):
timeout = self.gettimeout()
@ -122,11 +122,11 @@ class _sslConn(_SSL.Connection):
continue
def do_handshake(self, *args, **kwargs):
return self._call(super(_sslConn, self).do_handshake, *args, **kwargs)
return self._call(super().do_handshake, *args, **kwargs)
def recv(self, *args, **kwargs):
try:
return self._call(super(_sslConn, self).recv, *args, **kwargs)
return self._call(super().recv, *args, **kwargs)
except _SSL.SysCallError as exc:
# Suppress ragged EOFs to match the stdlib.
if self.suppress_ragged_eofs and _ragged_eof(exc):
@ -135,7 +135,7 @@ class _sslConn(_SSL.Connection):
def recv_into(self, *args, **kwargs):
try:
return self._call(super(_sslConn, self).recv_into, *args, **kwargs)
return self._call(super().recv_into, *args, **kwargs)
except _SSL.SysCallError as exc:
# Suppress ragged EOFs to match the stdlib.
if self.suppress_ragged_eofs and _ragged_eof(exc):
@ -148,11 +148,11 @@ class _sslConn(_SSL.Connection):
total_sent = 0
while total_sent < total_length:
try:
sent = self._call(super(_sslConn, self).send, view[total_sent:], flags)
sent = self._call(super().send, view[total_sent:], flags)
# XXX: It's not clear if this can actually happen. PyOpenSSL
# doesn't appear to have any interrupt handling, nor any interrupt
# errors for OpenSSL connections.
except (IOError, OSError) as exc: # noqa: B014
except OSError as exc: # noqa: B014
if _errno_from_exception(exc) == _EINTR:
continue
raise
@ -163,7 +163,7 @@ class _sslConn(_SSL.Connection):
total_sent += sent
class _CallbackData(object):
class _CallbackData:
"""Data class which is passed to the OCSP callback."""
def __init__(self):
@ -172,7 +172,7 @@ class _CallbackData(object):
self.ocsp_response_cache = _OCSPCache()
class SSLContext(object):
class SSLContext:
"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
context.
"""
@ -328,7 +328,8 @@ class SSLContext(object):
def set_default_verify_paths(self):
"""Specify that the platform provided CA certificates are to be used
for verification purposes."""
for verification purposes.
"""
# Note: See PyOpenSSL's docs for limitations, which are similar
# but not that same as CPython's.
self._ctx.set_default_verify_paths()

View File

@ -17,7 +17,7 @@
from typing import Any, Dict, Optional
class ReadConcern(object):
class ReadConcern:
"""ReadConcern
:Parameters:
@ -45,7 +45,8 @@ class ReadConcern(object):
@property
def ok_for_legacy(self) -> bool:
"""Return ``True`` if this read concern is compatible with
old wire protocol versions."""
old wire protocol versions.
"""
return self.level is None or self.level == "local"
@property

View File

@ -46,18 +46,18 @@ def _validate_tag_sets(tag_sets):
return tag_sets
if not isinstance(tag_sets, (list, tuple)):
raise TypeError(("Tag sets %r invalid, must be a sequence") % (tag_sets,))
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
if len(tag_sets) == 0:
raise ValueError(
("Tag sets %r invalid, must be None or contain at least one set of tags") % (tag_sets,)
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
)
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
"Tag set %r invalid, must be an instance of dict, "
"Tag set {!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping" % (tags,)
"collection.Mapping".format(tags)
)
return list(tag_sets)
@ -88,7 +88,7 @@ def _validate_hedge(hedge):
return None
if not isinstance(hedge, dict):
raise TypeError("hedge must be a dictionary, not %r" % (hedge,))
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
return hedge
@ -97,7 +97,7 @@ _Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
class _ServerMode(object):
class _ServerMode:
"""Base class for all read preferences."""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
@ -168,7 +168,8 @@ class _ServerMode(object):
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum."""
no longer be selected for operations, or -1 for no maximum.
"""
return self.__max_staleness
@property
@ -209,7 +210,7 @@ class _ServerMode(object):
return 0 if self.__max_staleness == -1 else 5
def __repr__(self):
return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % (
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
self.name,
self.__tag_sets,
self.__max_staleness,
@ -263,7 +264,7 @@ class Primary(_ServerMode):
__slots__ = ()
def __init__(self) -> None:
super(Primary, self).__init__(_PRIMARY)
super().__init__(_PRIMARY)
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to a Selection."""
@ -314,7 +315,7 @@ class PrimaryPreferred(_ServerMode):
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super(PrimaryPreferred, self).__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
@ -357,7 +358,7 @@ class Secondary(_ServerMode):
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super(Secondary, self).__init__(_SECONDARY, tag_sets, max_staleness, hedge)
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
@ -401,9 +402,7 @@ class SecondaryPreferred(_ServerMode):
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super(SecondaryPreferred, self).__init__(
_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge
)
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
@ -448,7 +447,7 @@ class Nearest(_ServerMode):
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super(Nearest, self).__init__(_NEAREST, tag_sets, max_staleness, hedge)
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
@ -490,7 +489,7 @@ class _AggWritePref:
return self.effective_pref(selection)
def __repr__(self):
return "_AggWritePref(pref=%r)" % (self.pref,)
return f"_AggWritePref(pref={self.pref!r})"
# Proxy other calls to the effective_pref so that _AggWritePref can be
# used in place of an actual read preference.
@ -524,7 +523,7 @@ _MODES = (
)
class ReadPreference(object):
class ReadPreference:
"""An enum that defines some commonly used read preference modes.
Apps can also create a custom read preference, for example::
@ -591,7 +590,7 @@ def read_pref_mode_from_name(name: str) -> int:
return _MONGOS_MODES.index(name)
class MovingAverage(object):
class MovingAverage:
"""Tracks an exponentially-weighted moving average."""
average: Optional[float]

View File

@ -15,7 +15,7 @@
"""Represent a response from the server."""
class Response(object):
class Response:
__slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs")
def __init__(self, data, address, request_id, duration, from_command, docs):
@ -86,9 +86,7 @@ class PinnedResponse(Response):
- `more_to_come`: Bool indicating whether cursor is ready to be
exhausted.
"""
super(PinnedResponse, self).__init__(
data, address, request_id, duration, from_command, docs
)
super().__init__(data, address, request_id, duration, from_command, docs)
self._socket_info = socket_info
self._more_to_come = more_to_come
@ -105,5 +103,6 @@ class PinnedResponse(Response):
@property
def more_to_come(self):
"""If true, server is ready to send batches on the socket until the
result set is exhausted or there is an error."""
result set is exhausted or there is an error.
"""
return self._more_to_come

View File

@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, cast
from pymongo.errors import InvalidOperation
class _WriteResult(object):
class _WriteResult:
"""Base class for write result classes."""
__slots__ = ("__acknowledged",)
@ -30,10 +30,10 @@ class _WriteResult(object):
"""Raise an exception on property access if unacknowledged."""
if not self.__acknowledged:
raise InvalidOperation(
"A value for %s is not available when "
"A value for {} is not available when "
"the write is unacknowledged. Check the "
"acknowledged attribute to avoid this "
"error." % (property_name,)
"error.".format(property_name)
)
@property
@ -63,7 +63,7 @@ class InsertOneResult(_WriteResult):
def __init__(self, inserted_id: Any, acknowledged: bool) -> None:
self.__inserted_id = inserted_id
super(InsertOneResult, self).__init__(acknowledged)
super().__init__(acknowledged)
@property
def inserted_id(self) -> Any:
@ -78,7 +78,7 @@ class InsertManyResult(_WriteResult):
def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None:
self.__inserted_ids = inserted_ids
super(InsertManyResult, self).__init__(acknowledged)
super().__init__(acknowledged)
@property
def inserted_ids(self) -> List:
@ -102,7 +102,7 @@ class UpdateResult(_WriteResult):
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
self.__raw_result = raw_result
super(UpdateResult, self).__init__(acknowledged)
super().__init__(acknowledged)
@property
def raw_result(self) -> Dict[str, Any]:
@ -134,13 +134,14 @@ class UpdateResult(_WriteResult):
class DeleteResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.delete_one`
and :meth:`~pymongo.collection.Collection.delete_many`"""
and :meth:`~pymongo.collection.Collection.delete_many`
"""
__slots__ = ("__raw_result",)
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
self.__raw_result = raw_result
super(DeleteResult, self).__init__(acknowledged)
super().__init__(acknowledged)
@property
def raw_result(self) -> Dict[str, Any]:
@ -169,7 +170,7 @@ class BulkWriteResult(_WriteResult):
:exc:`~pymongo.errors.InvalidOperation`.
"""
self.__bulk_api_result = bulk_api_result
super(BulkWriteResult, self).__init__(acknowledged)
super().__init__(acknowledged)
@property
def bulk_api_result(self) -> Dict[str, Any]:
@ -211,7 +212,5 @@ class BulkWriteResult(_WriteResult):
"""A map of operation index to the _id of the upserted document."""
self._raise_if_unacknowledged("upserted_ids")
if self.__bulk_api_result:
return dict(
(upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"]
)
return {upsert["index"]: upsert["_id"] for upsert in self.bulk_api_result["upserted"]}
return None

View File

@ -71,7 +71,7 @@ else:
return data
if prohibit_unassigned_code_points:
prohibited = _PROHIBITED + (stringprep.in_table_a1,)
prohibited = (*_PROHIBITED, stringprep.in_table_a1)
else:
prohibited = _PROHIBITED
@ -98,12 +98,12 @@ else:
raise ValueError("SASLprep: failed bidirectional check")
# RFC3454, Section 6, #2. If a string contains any RandALCat
# character, it MUST NOT contain any LCat character.
prohibited = prohibited + (stringprep.in_table_d2,)
prohibited = (*prohibited, stringprep.in_table_d2)
else:
# RFC3454, Section 6, #3. Following the logic of #3, if
# the first character is not a RandALCat, no other character
# can be either.
prohibited = prohibited + (in_table_d1,)
prohibited = (*prohibited, in_table_d1)
# RFC3454 section 2, step 3 and 4 - Prohibit and check bidi
for char in data:

View File

@ -25,7 +25,7 @@ from pymongo.response import PinnedResponse, Response
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
class Server(object):
class Server:
def __init__(
self, server_description, pool, monitor, topology_id=None, listeners=None, events=None
):
@ -245,4 +245,4 @@ class Server(object):
return request_id, data, 0
def __repr__(self):
return "<%s %r>" % (self.__class__.__name__, self._description)
return f"<{self.__class__.__name__} {self._description!r}>"

View File

@ -95,7 +95,7 @@ class ServerApiVersion:
"""Server API version "1"."""
class ServerApi(object):
class ServerApi:
"""MongoDB Stable API."""
def __init__(self, version, strict=None, deprecation_errors=None):
@ -113,16 +113,16 @@ class ServerApi(object):
.. versionadded:: 3.12
"""
if version != ServerApiVersion.V1:
raise ValueError("Unknown ServerApi version: %s" % (version,))
raise ValueError(f"Unknown ServerApi version: {version}")
if strict is not None and not isinstance(strict, bool):
raise TypeError(
"Wrong type for ServerApi strict, value must be an instance "
"of bool, not %s" % (type(strict),)
"of bool, not {}".format(type(strict))
)
if deprecation_errors is not None and not isinstance(deprecation_errors, bool):
raise TypeError(
"Wrong type for ServerApi deprecation_errors, value must be "
"an instance of bool, not %s" % (type(deprecation_errors),)
"an instance of bool, not {}".format(type(deprecation_errors))
)
self._version = version
self._strict = strict

View File

@ -25,7 +25,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _Address
class ServerDescription(object):
class ServerDescription:
"""Immutable representation of one server.
:Parameters:
@ -287,8 +287,8 @@ class ServerDescription(object):
def __repr__(self):
errmsg = ""
if self.error:
errmsg = ", error=%r" % (self.error,)
return "<%s %s server_type: %s, rtt: %s%s>" % (
errmsg = f", error={self.error!r}"
return "<{} {} server_type: {}, rtt: {}{}>".format(
self.__class__.__name__,
self.address,
self.server_type_name,

View File

@ -17,7 +17,7 @@
from pymongo.server_type import SERVER_TYPE
class Selection(object):
class Selection:
"""Input or output of a server selector function."""
@classmethod
@ -51,6 +51,7 @@ class Selection(object):
secondaries = secondary_server_selector(self)
if secondaries.server_descriptions:
return max(secondaries.server_descriptions, key=lambda sd: sd.last_write_date)
return None
@property
def primary_selection(self):

View File

@ -26,7 +26,7 @@ from pymongo.server_description import ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
class TopologySettings(object):
class TopologySettings:
def __init__(
self,
seeds=None,
@ -156,4 +156,4 @@ class TopologySettings(object):
def get_server_descriptions(self):
"""Initial dict of (address, ServerDescription) for all seeds."""
return dict([(address, ServerDescription(address)) for address in self.seeds])
return {address: ServerDescription(address) for address in self.seeds}

View File

@ -33,7 +33,7 @@ def _errno_from_exception(exc):
return None
class SocketChecker(object):
class SocketChecker:
def __init__(self) -> None:
self._poller: Optional[select.poll]
if _HAVE_POLL:
@ -78,7 +78,7 @@ class SocketChecker(object):
# ready: subsets of the first three arguments. Return
# True if any of the lists are not empty.
return any(res)
except (_SelectError, IOError) as exc: # type: ignore
except (_SelectError, OSError) as exc: # type: ignore
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
continue
raise

View File

@ -51,7 +51,7 @@ _INVALID_HOST_MSG = (
)
class _SrvResolver(object):
class _SrvResolver:
def __init__(self, fqdn, connect_timeout, srv_service_name, srv_max_hosts=0):
self.__fqdn = fqdn
self.__srv = srv_service_name
@ -110,9 +110,9 @@ class _SrvResolver(object):
try:
nlist = node[0].split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__plist != nlist:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes

View File

@ -71,7 +71,7 @@ if HAVE_SSL:
try:
ctx.load_cert_chain(certfile, None, passphrase)
except _ssl.SSLError as exc:
raise ConfigurationError("Private key doesn't match certificate: %s" % (exc,))
raise ConfigurationError(f"Private key doesn't match certificate: {exc}")
if crlfile is not None:
if _ssl.IS_PYOPENSSL:
raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL")

View File

@ -75,7 +75,7 @@ def process_events_queue(queue_ref):
return True # Continue PeriodicExecutor.
class Topology(object):
class Topology:
"""Monitor a topology of one or more servers."""
def __init__(self, topology_settings):
@ -236,8 +236,7 @@ class Topology(object):
# No suitable servers.
if timeout == 0 or now > end_time:
raise ServerSelectionTimeoutError(
"%s, Timeout: %ss, Topology Description: %r"
% (self._error_message(selector), timeout, self.description)
f"{self._error_message(selector)}, Timeout: {timeout}s, Topology Description: {self.description!r}"
)
self._ensure_opened()
@ -431,7 +430,7 @@ class Topology(object):
):
return set()
return set([sd.address for sd in selector(self._new_selection())])
return {sd.address for sd in selector(self._new_selection())}
def get_secondaries(self):
"""Return set of secondary addresses."""
@ -499,7 +498,8 @@ class Topology(object):
def close(self):
"""Clear pools and terminate monitors. Topology does not reopen on
demand. Any further operations will raise
:exc:`~.errors.InvalidOperation`."""
:exc:`~.errors.InvalidOperation`.
"""
with self._lock:
for server in self._servers.values():
server.close()
@ -807,14 +807,14 @@ class Topology(object):
else:
return "No %s available for writes" % server_plural
else:
return 'No %s match selector "%s"' % (server_plural, selector)
return f'No {server_plural} match selector "{selector}"'
else:
addresses = list(self._description.server_descriptions())
servers = list(self._description.server_descriptions().values())
if not servers:
if is_replica_set:
# We removed all servers because of the wrong setName?
return 'No %s available for replica set name "%s"' % (
return 'No {} available for replica set name "{}"'.format(
server_plural,
self._settings.replica_set_name,
)
@ -844,7 +844,7 @@ class Topology(object):
msg = ""
if not self._opened:
msg = "CLOSED "
return "<%s %s%r>" % (self.__class__.__name__, msg, self._description)
return f"<{self.__class__.__name__} {msg}{self._description!r}>"
def eq_props(self):
"""The properties to use for MongoClient/Topology equality checks."""
@ -860,7 +860,7 @@ class Topology(object):
return hash(self.eq_props())
class _ErrorContext(object):
class _ErrorContext:
"""An error with context for SDAM error handling."""
def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id):

View File

@ -47,7 +47,7 @@ SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
class TopologyDescription(object):
class TopologyDescription:
def __init__(
self,
topology_type: int,
@ -171,7 +171,7 @@ class TopologyDescription(object):
topology_type = self._topology_type
# The default ServerDescription's type is Unknown.
sds = dict((address, ServerDescription(address)) for address in self._server_descriptions)
sds = {address: ServerDescription(address) for address in self._server_descriptions}
return TopologyDescription(
topology_type,
@ -184,7 +184,8 @@ class TopologyDescription(object):
def server_descriptions(self) -> Dict[_Address, ServerDescription]:
"""Dict of (address,
:class:`~pymongo.server_description.ServerDescription`)."""
:class:`~pymongo.server_description.ServerDescription`).
"""
return self._server_descriptions.copy()
@property
@ -346,7 +347,7 @@ class TopologyDescription(object):
def __repr__(self):
# Sort the servers by address.
servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
return "<%s id: %s, topology_type: %s, servers: %r>" % (
return "<{} id: {}, topology_type: {}, servers: {!r}>".format(
self.__class__.__name__,
self._topology_settings._topology_id,
self.topology_type_name,
@ -400,8 +401,9 @@ def updated_topology_description(
if set_name is not None and set_name != server_description.replica_set_name:
error = ConfigurationError(
"client is configured to connect to a replica set named "
"'%s' but this node belongs to a set named '%s'"
% (set_name, server_description.replica_set_name)
"'{}' but this node belongs to a set named '{}'".format(
set_name, server_description.replica_set_name
)
)
sds[address] = server_description.to_unknown(error=error)
# Single type never changes.

View File

@ -29,7 +29,8 @@ _Pipeline = Sequence[Mapping[str, Any]]
def strip_optional(elem):
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
while inside a list comprehension."""
while inside a list comprehension.
"""
assert elem is not None
return elem

View File

@ -134,7 +134,7 @@ def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Addr
host, port = host.split(":", 1)
if isinstance(port, str):
if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
raise ValueError("Port must be an integer between 0 and 65535: %r" % (port,))
raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
port = int(port)
# Normalize hostname to lowercase, since DNS is case-insensitive:
@ -155,7 +155,8 @@ _IMPLICIT_TLSINSECURE_OPTS = {
def _parse_options(opts, delim):
"""Helper method for split_options which creates the options dict.
Also handles the creation of a list for the URI tag_sets/
readpreferencetags portion, and the use of a unicode options string."""
readpreferencetags portion, and the use of a unicode options string.
"""
options = _CaseInsensitiveDictionary()
for uriopt in opts.split(delim):
key, value = uriopt.split("=")
@ -163,7 +164,7 @@ def _parse_options(opts, delim):
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn("Duplicate URI option '%s'." % (key,))
warnings.warn(f"Duplicate URI option '{key}'.")
if key.lower() == "authmechanismproperties":
val = value
else:
@ -475,9 +476,7 @@ def parse_uri(
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
else:
raise InvalidURI(
"Invalid URI scheme: URI must begin with '%s' or '%s'" % (SCHEME, SRV_SCHEME)
)
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP.")
@ -525,15 +524,13 @@ def parse_uri(
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
if options.get("directConnection"):
raise ConfigurationError(
"Cannot specify directConnection=true with %s URIs" % (SRV_SCHEME,)
)
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
nodes = split_hosts(hosts, default_port=None)
if len(nodes) != 1:
raise InvalidURI("%s URIs must include one, and only one, hostname" % (SRV_SCHEME,))
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
fqdn, port = nodes[0]
if port is not None:
raise InvalidURI("%s URIs must not include a port number" % (SRV_SCHEME,))
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.

View File

@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, Union
from pymongo.errors import ConfigurationError
class WriteConcern(object):
class WriteConcern:
"""WriteConcern
:Parameters:
@ -113,7 +113,9 @@ class WriteConcern(object):
return self.__acknowledged
def __repr__(self):
return "WriteConcern(%s)" % (", ".join("%s=%s" % kvt for kvt in self.__document.items()),)
return "WriteConcern({})".format(
", ".join("{}={}".format(*kvt) for kvt in self.__document.items())
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, WriteConcern):

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test suite for pymongo, bson, and gridfs.
"""
"""Test suite for pymongo, bson, and gridfs."""
import base64
import gc
@ -92,7 +91,7 @@ CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certifica
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
TLS_OPTIONS: Dict = dict(tls=True)
TLS_OPTIONS: Dict = {"tls": True}
if CLIENT_PEM:
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
if CA_PEM:
@ -149,7 +148,7 @@ def is_server_resolvable():
try:
socket.gethostbyname("server")
return True
except socket.error:
except OSError:
return False
finally:
socket.setdefaulttimeout(socket_timeout)
@ -165,7 +164,7 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
return authdb.command(cmd)
class client_knobs(object):
class client_knobs:
def __init__(
self,
heartbeat_frequency=None,
@ -234,10 +233,9 @@ class client_knobs(object):
def __del__(self):
if self._enabled:
msg = (
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, "
"MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, "
"EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s"
% (
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
common.HEARTBEAT_FREQUENCY,
common.MIN_HEARTBEAT_INTERVAL,
common.KILL_CURSOR_FREQUENCY,
@ -250,10 +248,10 @@ class client_knobs(object):
def _all_users(db):
return set(u["user"] for u in db.command("usersInfo").get("users", []))
return {u["user"] for u in db.command("usersInfo").get("users", [])}
class ClientContext(object):
class ClientContext:
client: MongoClient
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
@ -339,14 +337,14 @@ class ClientContext(object):
except pymongo.errors.OperationFailure as exc:
# SERVER-32063
self.connection_attempts.append(
"connected client %r, but legacy hello failed: %s" % (client, exc)
f"connected client {client!r}, but legacy hello failed: {exc}"
)
else:
self.connection_attempts.append("successfully connected client %r" % (client,))
self.connection_attempts.append(f"successfully connected client {client!r}")
# If connected, then return client with default timeout
return pymongo.MongoClient(host, port, **kwargs)
except pymongo.errors.ConnectionFailure as exc:
self.connection_attempts.append("failed to connect client %r: %s" % (client, exc))
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
return None
finally:
client.close()
@ -447,7 +445,7 @@ class ClientContext(object):
nodes.extend([partition_node(node.lower()) for node in hello.get("arbiters", [])])
self.nodes = set(nodes)
else:
self.nodes = set([(host, port)])
self.nodes = {(host, port)}
self.w = len(hello.get("hosts", [])) or 1
self.version = Version.from_client(self.client)
@ -587,7 +585,7 @@ class ClientContext(object):
for info in socket.getaddrinfo(self.host, self.port):
if info[0] == socket.AF_INET6:
return True
except socket.error:
except OSError:
pass
return False
@ -599,7 +597,7 @@ class ClientContext(object):
self.init()
# Always raise SkipTest if we can't connect to MongoDB
if not self.connected:
raise SkipTest("Cannot connect to MongoDB on %s" % (self.pair,))
raise SkipTest(f"Cannot connect to MongoDB on {self.pair}")
if condition():
return f(*args, **kwargs)
raise SkipTest(msg)
@ -625,7 +623,7 @@ class ClientContext(object):
"""Run a test only if we can connect to MongoDB."""
return self._require(
lambda: True, # _require checks if we're connected
"Cannot connect to MongoDB on %s" % (self.pair,),
f"Cannot connect to MongoDB on {self.pair}",
func=func,
)
@ -633,14 +631,15 @@ class ClientContext(object):
"""Run a test only if we are connected to Atlas Data Lake."""
return self._require(
lambda: self.is_data_lake,
"Not connected to Atlas Data Lake on %s" % (self.pair,),
f"Not connected to Atlas Data Lake on {self.pair}",
func=func,
)
def require_no_mmap(self, func):
"""Run a test only if the server is not using the MMAPv1 storage
engine. Only works for standalone and replica sets; tests are
run regardless of storage engine on sharded clusters."""
run regardless of storage engine on sharded clusters.
"""
def is_not_mmap():
if self.is_mongos:
@ -734,7 +733,8 @@ class ClientContext(object):
def require_multiple_mongoses(self, func):
"""Run a test only if the client is connected to a sharded cluster
that has 2 mongos nodes."""
that has 2 mongos nodes.
"""
return self._require(
lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func
)
@ -786,7 +786,7 @@ class ClientContext(object):
"load-balanced",
}
if unknown:
raise AssertionError("Unknown topologies: %r" % (unknown,))
raise AssertionError(f"Unknown topologies: {unknown!r}")
if self.load_balancer:
if "load-balanced" in topologies:
return True
@ -812,7 +812,8 @@ class ClientContext(object):
def require_cluster_type(self, topologies=[]): # noqa
"""Run a test only if the client is connected to a cluster that
conforms to one of the specified topologies. Acceptable topologies
are 'single', 'replicaset', and 'sharded'."""
are 'single', 'replicaset', and 'sharded'.
"""
def _is_valid_topology():
return self.is_topology_type(topologies)
@ -827,7 +828,8 @@ class ClientContext(object):
def require_failCommand_fail_point(self, func):
"""Run a test only if the server supports the failCommand fail
point."""
point.
"""
return self._require(
lambda: self.supports_failCommand_fail_point,
"failCommand fail point must be supported",
@ -930,7 +932,7 @@ class ClientContext(object):
)
def mongos_seeds(self):
return ",".join("%s:%s" % address for address in self.mongoses)
return ",".join("{}:{}".format(*address) for address in self.mongoses)
@property
def supports_failCommand_fail_point(self):
@ -1139,7 +1141,7 @@ class MockClientTest(unittest.TestCase):
pass
def setUp(self):
super(MockClientTest, self).setUp()
super().setUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
@ -1147,7 +1149,7 @@ class MockClientTest(unittest.TestCase):
def tearDown(self):
self.client_knobs.disable()
super(MockClientTest, self).tearDown()
super().tearDown()
# Global knobs to speed up the test suite.
@ -1181,9 +1183,9 @@ def print_running_topology(topology):
if running:
print(
"WARNING: found Topology with running threads:\n"
" Threads: %s\n"
" Topology: %s\n"
" Creation traceback:\n%s" % (running, topology, topology._settings._stack)
" Threads: {}\n"
" Topology: {}\n"
" Creation traceback:\n{}".format(running, topology, topology._settings._stack)
)
@ -1215,11 +1217,11 @@ def teardown():
global_knobs.disable()
garbage = []
for g in gc.garbage:
garbage.append("GARBAGE: %r" % (g,))
garbage.append(" gc.get_referents: %r" % (gc.get_referents(g),))
garbage.append(" gc.get_referrers: %r" % (gc.get_referrers(g),))
garbage.append(f"GARBAGE: {g!r}")
garbage.append(f" gc.get_referents: {gc.get_referents(g)!r}")
garbage.append(f" gc.get_referrers: {gc.get_referrers(g)!r}")
if garbage:
assert False, "\n".join(garbage)
raise AssertionError("\n".join(garbage))
c = client_context.client
if c:
if not client_context.is_data_lake:
@ -1237,7 +1239,7 @@ def teardown():
class PymongoTestRunner(unittest.TextTestRunner):
def run(self, test):
setup()
result = super(PymongoTestRunner, self).run(test)
result = super().run(test)
teardown()
return result
@ -1247,7 +1249,7 @@ if HAVE_XML:
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
def run(self, test):
setup()
result = super(PymongoXMLTestRunner, self).run(test)
result = super().run(test)
teardown()
return result
@ -1260,8 +1262,7 @@ def test_cases(suite):
yield suite_or_case
else:
# unittest.TestSuite
for case in test_cases(suite_or_case):
yield case
yield from test_cases(suite_or_case)
# Helper method to workaround https://bugs.python.org/issue21724
@ -1272,7 +1273,7 @@ def clear_warning_registry():
setattr(module, "__warningregistry__", {}) # noqa
class SystemCertsPatcher(object):
class SystemCertsPatcher:
def __init__(self, ca_certs):
if (
ssl.OPENSSL_VERSION.lower().startswith("libressl")

View File

@ -102,7 +102,7 @@ class TestAtlasConnect(unittest.TestCase):
duplicates = [names for names in uri_to_names.values() if len(names) > 1]
self.assertFalse(
duplicates,
"Error: the following env variables have duplicate values: %s" % (duplicates,),
f"Error: the following env variables have duplicate values: {duplicates}",
)

View File

@ -39,7 +39,7 @@ class TestAuthAWS(unittest.TestCase):
if "@" not in self.uri:
self.skipTest("MONGODB_URI already has no credentials")
hosts = ["%s:%s" % addr for addr in parse_uri(self.uri)["nodelist"]]
hosts = ["{}:{}".format(*addr) for addr in parse_uri(self.uri)["nodelist"]]
self.assertTrue(hosts)
with MongoClient(hosts) as client:
with self.assertRaises(OperationFailure):
@ -115,7 +115,7 @@ class TestAuthAWS(unittest.TestCase):
def test_environment_variables_ignored(self):
creds = self.setup_cache()
self.assertIsNotNone(creds)
prev = os.environ.copy()
os.environ.copy()
client = MongoClient(self.uri)
self.addCleanup(client.close)
@ -124,9 +124,11 @@ class TestAuthAWS(unittest.TestCase):
self.assertIsNotNone(auth.get_cached_credentials())
mock_env = dict(
AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar", AWS_SESSION_TOKEN="baz"
)
mock_env = {
"AWS_ACCESS_KEY_ID": "foo",
"AWS_SECRET_ACCESS_KEY": "bar",
"AWS_SESSION_TOKEN": "baz",
}
with patch.dict("os.environ", mock_env):
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo")
@ -147,7 +149,7 @@ class TestAuthAWS(unittest.TestCase):
self.assertIsNotNone(creds)
auth.set_cached_credentials(None)
mock_env = dict(AWS_ACCESS_KEY_ID=creds.username, AWS_SECRET_ACCESS_KEY=creds.password)
mock_env = {"AWS_ACCESS_KEY_ID": creds.username, "AWS_SECRET_ACCESS_KEY": creds.password}
if creds.token:
mock_env["AWS_SESSION_TOKEN"] = creds.token

View File

@ -65,7 +65,7 @@ class TestAuthOIDC(unittest.TestCase):
self.assertEqual(timeout_seconds, 60 * 5)
with open(token_file) as fid:
token = fid.read()
resp = dict(access_token=token)
resp = {"access_token": token}
time.sleep(sleep)
@ -94,7 +94,7 @@ class TestAuthOIDC(unittest.TestCase):
# Validate the timeout.
self.assertEqual(context["timeout_seconds"], 60 * 5)
resp = dict(access_token=token)
resp = {"access_token": token}
if expires_in_seconds is not None:
resp["expires_in_seconds"] = expires_in_seconds
self.refresh_called += 1
@ -115,21 +115,21 @@ class TestAuthOIDC(unittest.TestCase):
def test_connect_callbacks_single_implicit_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_single_explicit_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_principal_user1(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(
self.uri_multiple, username="test_user1", authmechanismproperties=props
)
@ -138,7 +138,7 @@ class TestAuthOIDC(unittest.TestCase):
def test_connect_callbacks_multiple_principal_user2(self):
request_token = self.create_request_cb("test_user2")
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(
self.uri_multiple, username="test_user2", authmechanismproperties=props
)
@ -147,7 +147,7 @@ class TestAuthOIDC(unittest.TestCase):
def test_connect_callbacks_multiple_no_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
with self.assertRaises(OperationFailure):
client.test.test.find_one()
@ -155,13 +155,13 @@ class TestAuthOIDC(unittest.TestCase):
def test_allowed_hosts_blocked(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token, allowed_hosts=[])
props: Dict = {"request_token_callback": request_token, "allowed_hosts": []}
client = MongoClient(self.uri_single, authmechanismproperties=props)
with self.assertRaises(ConfigurationError):
client.test.test.find_one()
client.close()
props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"])
props: Dict = {"request_token_callback": request_token, "allowed_hosts": ["example.com"]}
client = MongoClient(
self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False
)
@ -170,26 +170,26 @@ class TestAuthOIDC(unittest.TestCase):
client.close()
def test_connect_aws_single_principal(self):
props = dict(PROVIDER_NAME="aws")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user1(self):
props = dict(PROVIDER_NAME="aws")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user2(self):
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
props = dict(PROVIDER_NAME="aws")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_allowed_hosts_ignored(self):
props = dict(PROVIDER_NAME="aws", allowed_hosts=[])
props = {"PROVIDER_NAME": "aws", "allowed_hosts": []}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
@ -198,10 +198,10 @@ class TestAuthOIDC(unittest.TestCase):
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(
request_token_callback=request_cb,
refresh_token_callback=refresh_cb,
)
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb,
}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
@ -214,7 +214,7 @@ class TestAuthOIDC(unittest.TestCase):
request_cb = self.create_request_cb(sleep=0.5)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
def run_test():
client = MongoClient(self.uri_single, authMechanismProperties=props)
@ -239,7 +239,7 @@ class TestAuthOIDC(unittest.TestCase):
def request_token_null(a, b):
return None
props: Dict = dict(request_token_callback=request_token_null)
props: Dict = {"request_token_callback": request_token_null}
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
@ -251,9 +251,10 @@ class TestAuthOIDC(unittest.TestCase):
def refresh_token_null(a, b):
return None
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_token_null
)
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_token_null,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
@ -265,9 +266,9 @@ class TestAuthOIDC(unittest.TestCase):
def test_request_callback_invalid_result(self):
def request_token_invalid(a, b):
return dict()
return {}
props: Dict = dict(request_token_callback=request_token_invalid)
props: Dict = {"request_token_callback": request_token_invalid}
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
@ -278,7 +279,7 @@ class TestAuthOIDC(unittest.TestCase):
result["foo"] = "bar"
return result
props: Dict = dict(request_token_callback=request_cb_extra_value)
props: Dict = {"request_token_callback": request_cb_extra_value}
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
@ -288,11 +289,12 @@ class TestAuthOIDC(unittest.TestCase):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_cb_no_token(a, b):
return dict()
return {}
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token
)
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb_no_token,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
@ -310,9 +312,10 @@ class TestAuthOIDC(unittest.TestCase):
result["foo"] = "bar"
return result
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value
)
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb_extra_value,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
@ -329,7 +332,7 @@ class TestAuthOIDC(unittest.TestCase):
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Ensure that a ``find`` operation adds credentials to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
@ -352,7 +355,7 @@ class TestAuthOIDC(unittest.TestCase):
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
request_cb = self.create_request_cb()
props = dict(request_token_callback=request_cb)
props = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that a ``find`` operation adds credentials to the cache.
@ -373,7 +376,7 @@ class TestAuthOIDC(unittest.TestCase):
def test_cache_key_includes_callback(self):
request_cb = self.create_request_cb()
props: Dict = dict(request_token_callback=request_cb)
props: Dict = {"request_token_callback": request_cb}
# Ensure that a ``find`` operation adds a new entry to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
@ -397,10 +400,10 @@ class TestAuthOIDC(unittest.TestCase):
# Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials.
def refresh_cb(a, b):
return dict(access_token="bad")
return {"access_token": "bad"}
# Add a token to the cache that will expire soon.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
@ -421,7 +424,7 @@ class TestAuthOIDC(unittest.TestCase):
def test_cache_is_not_used_in_aws_automatic_workflow(self):
# Create a new client using the AWS device workflow.
# Ensure that a ``find`` operation does not add credentials to the cache.
props = dict(PROVIDER_NAME="aws")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
@ -438,11 +441,11 @@ class TestAuthOIDC(unittest.TestCase):
def request_token(a, b):
with open(token_file) as fid:
token = fid.read()
return dict(access_token=token, expires_in_seconds=1000)
return {"access_token": token, "expires_in_seconds": 1000}
# Create a client with a request callback that returns a valid token
# that will not expire soon.
props: Dict = dict(request_token_callback=request_token)
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Set a fail point for saslStart commands.
@ -483,7 +486,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
@ -536,7 +539,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
@ -563,7 +566,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
@ -594,7 +597,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
@ -622,7 +625,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
@ -647,7 +650,7 @@ class TestAuthOIDC(unittest.TestCase):
def test_reauthenticate_succeeds_get_more_exhaust(self):
# Ensure no mongos
props = dict(PROVIDER_NAME="aws")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_single, authmechanismproperties=props)
hello = client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") != "isdbgrid":
@ -657,7 +660,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
@ -685,7 +688,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
print("start of test")
client = MongoClient(self.uri_single, authmechanismproperties=props)
@ -703,7 +706,7 @@ class TestAuthOIDC(unittest.TestCase):
}
):
# Perform a count operation.
cursor = client.test.command(dict(count="test"))
cursor = client.test.command({"count": "test"})
self.assertGreaterEqual(len(list(cursor)), 1)
@ -720,7 +723,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
@ -750,7 +753,7 @@ class TestAuthOIDC(unittest.TestCase):
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
@ -778,7 +781,7 @@ class TestAuthOIDC(unittest.TestCase):
request_cb = self.create_request_cb(expires_in_seconds=1e6)
refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client1 = MongoClient(self.uri_single, authMechanismProperties=props)
client1.test.test.find_one()
client2 = MongoClient(self.uri_single, authMechanismProperties=props)

View File

@ -27,7 +27,7 @@ class TestCrudV2(SpecRunner):
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super(TestCrudV2, self).allowable_errors(op)
errors = super().allowable_errors(op)
errors += (ValueError,)
return errors
@ -51,4 +51,4 @@ class TestCrudV2(SpecRunner):
"""Allow specs to override a test's setup."""
# PYTHON-1935 Only create the collection if there is data to insert.
if scenario_def["data"]:
super(TestCrudV2, self).setup_scenario(scenario_def)
super().setup_scenario(scenario_def)

View File

@ -112,7 +112,7 @@ operations = [
]
_ops_by_name = dict([(op.name, op) for op in operations])
_ops_by_name = {op.name: op for op in operations}
Upgrade = namedtuple("Upgrade", ["name", "function", "old", "new", "wire_version"])

View File

@ -247,6 +247,7 @@ class TestHandshake(unittest.TestCase):
}
)
)
return None
else:
return request.reply(**primary_response)

View File

@ -46,7 +46,7 @@ class TestMixedVersionSharded(unittest.TestCase):
"ismaster", ismaster=True, msg="isdbgrid", maxWireVersion=upgrade.wire_version
)
self.mongoses_uri = "mongodb://%s,%s" % (
self.mongoses_uri = "mongodb://{},{}".format(
self.mongos_old.address_string,
self.mongos_new.address_string,
)

View File

@ -110,7 +110,7 @@ def generate_mongos_read_mode_tests():
# Skip something like command('foo', read_preference=SECONDARY).
continue
test = create_mongos_read_mode_test(mode, operation)
test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode)
test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode)
test.__name__ = test_name
setattr(TestMongosCommandReadMode, test_name, test)

View File

@ -26,7 +26,7 @@ class TestNetworkDisconnectPrimary(unittest.TestCase):
# Application operation fails against primary. Test that topology
# type changes from ReplicaSetWithPrimary to ReplicaSetNoPrimary.
# http://bit.ly/1B5ttuL
primary, secondary = servers = [MockupDB() for _ in range(2)]
primary, secondary = servers = (MockupDB() for _ in range(2))
for server in servers:
server.run()
self.addCleanup(server.stop)

View File

@ -304,7 +304,7 @@ def operation_test(op):
def create_tests(ops):
for op in ops:
test_name = "test_op_msg_%s" % (op.name,)
test_name = f"test_op_msg_{op.name}"
setattr(TestOpMsg, test_name, operation_test(op))

View File

@ -35,7 +35,7 @@ class OpMsgReadPrefBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(OpMsgReadPrefBase, cls).setUpClass()
super().setUpClass()
@classmethod
def add_test(cls, mode, test_name, test):
@ -50,7 +50,7 @@ class OpMsgReadPrefBase(unittest.TestCase):
class TestOpMsgMongos(OpMsgReadPrefBase):
@classmethod
def setUpClass(cls):
super(TestOpMsgMongos, cls).setUpClass()
super().setUpClass()
auto_ismaster = {
"ismaster": True,
"msg": "isdbgrid", # Mongos.
@ -64,13 +64,13 @@ class TestOpMsgMongos(OpMsgReadPrefBase):
@classmethod
def tearDownClass(cls):
cls.primary.stop()
super(TestOpMsgMongos, cls).tearDownClass()
super().tearDownClass()
class TestOpMsgReplicaSet(OpMsgReadPrefBase):
@classmethod
def setUpClass(cls):
super(TestOpMsgReplicaSet, cls).setUpClass()
super().setUpClass()
cls.primary, cls.secondary = MockupDB(), MockupDB()
for server in cls.primary, cls.secondary:
server.run()
@ -94,7 +94,7 @@ class TestOpMsgReplicaSet(OpMsgReadPrefBase):
def tearDownClass(cls):
for server in cls.primary, cls.secondary:
server.stop()
super(TestOpMsgReplicaSet, cls).tearDownClass()
super().tearDownClass()
@classmethod
def add_test(cls, mode, test_name, test):
@ -118,7 +118,7 @@ class TestOpMsgSingle(OpMsgReadPrefBase):
@classmethod
def setUpClass(cls):
super(TestOpMsgSingle, cls).setUpClass()
super().setUpClass()
auto_ismaster = {
"ismaster": True,
"minWireVersion": 2,
@ -131,7 +131,7 @@ class TestOpMsgSingle(OpMsgReadPrefBase):
@classmethod
def tearDownClass(cls):
cls.primary.stop()
super(TestOpMsgSingle, cls).tearDownClass()
super().tearDownClass()
def create_op_msg_read_mode_test(mode, operation):
@ -181,7 +181,7 @@ def generate_op_msg_read_mode_tests():
for entry in matrix:
mode, operation = entry
test = create_op_msg_read_mode_test(mode, operation)
test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode)
test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode)
test.__name__ = test_name
for cls in TestOpMsgMongos, TestOpMsgReplicaSet, TestOpMsgSingle:
cls.add_test(mode, test_name, test)

View File

@ -26,7 +26,7 @@ from pymongo.server_type import SERVER_TYPE
class TestResetAndRequestCheck(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.ismaster_time = 0.0
self.client = None
self.server = None
@ -143,7 +143,7 @@ def generate_reset_tests():
for entry in matrix:
operation, (test_method, name) = entry
test = create_reset_test(operation, test_method)
test_name = "%s_%s" % (name, operation.name.replace(" ", "_"))
test_name = "{}_{}".format(name, operation.name.replace(" ", "_"))
test.__name__ = test_name
setattr(TestResetAndRequestCheck, test_name, test)

View File

@ -43,7 +43,7 @@ class TestSlaveOkaySharded(unittest.TestCase):
"ismaster", minWireVersion=2, maxWireVersion=6, ismaster=True, msg="isdbgrid"
)
self.mongoses_uri = "mongodb://%s,%s" % (
self.mongoses_uri = "mongodb://{},{}".format(
self.mongos1.address_string,
self.mongos2.address_string,
)
@ -59,7 +59,7 @@ def create_slave_ok_sharded_test(mode, operation):
elif operation.op_type == "must-use-primary":
slave_ok = False
else:
assert False, "unrecognized op_type %r" % operation.op_type
raise AssertionError("unrecognized op_type %r" % operation.op_type)
pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None)
@ -84,7 +84,7 @@ def generate_slave_ok_sharded_tests():
for entry in matrix:
mode, operation = entry
test = create_slave_ok_sharded_test(mode, operation)
test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode)
test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode)
test.__name__ = test_name
setattr(TestSlaveOkaySharded, test_name, test)

View File

@ -78,7 +78,7 @@ def generate_slave_ok_single_tests():
mode, (server_type, ismaster), operation = entry
test = create_slave_ok_single_test(mode, server_type, ismaster, operation)
test_name = "test_%s_%s_with_mode_%s" % (
test_name = "test_{}_{}_with_mode_{}".format(
operation.name.replace(" ", "_"),
server_type,
mode,

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test client for mod_wsgi application, see bug PYTHON-353.
"""
"""Test client for mod_wsgi application, see bug PYTHON-353."""
import _thread as thread
import sys
@ -91,14 +90,14 @@ class URLGetterThread(threading.Thread):
counter = 0
def __init__(self, options, url, nrequests_per_thread):
super(URLGetterThread, self).__init__()
super().__init__()
self.options = options
self.url = url
self.nrequests_per_thread = nrequests_per_thread
self.errors = 0
def run(self):
for i in range(self.nrequests_per_thread):
for _i in range(self.nrequests_per_thread):
try:
get(url)
except Exception as e:
@ -128,9 +127,8 @@ def main(options, mode, url):
if options.verbose:
print(
"Getting %s %s times total in %s threads, "
"%s times per thread"
% (
"Getting {} {} times total in {} threads, "
"{} times per thread".format(
url,
nrequests_per_thread * options.nthreads,
options.nthreads,
@ -154,7 +152,7 @@ def main(options, mode, url):
else:
assert mode == "serial"
if options.verbose:
print("Getting %s %s times in one thread" % (url, options.nrequests))
print(f"Getting {url} {options.nrequests} times in one thread")
for i in range(1, options.nrequests + 1):
try:

View File

@ -40,7 +40,7 @@ else:
def _connect(options):
uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS=%s&tlsCAFile=%s&%s") % (
uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS={}&tlsCAFile={}&{}").format(
TIMEOUT_MS,
CA_FILE,
options,

View File

@ -58,7 +58,7 @@ def tearDownModule():
print(output)
class Timer(object):
class Timer:
def __enter__(self):
self.start = time.monotonic()
return self
@ -68,7 +68,7 @@ class Timer(object):
self.interval = self.end - self.start
class PerformanceTest(object):
class PerformanceTest:
dataset: Any
data_size: Any
do_task: Any
@ -85,7 +85,7 @@ class PerformanceTest(object):
name = self.__class__.__name__
median = self.percentile(50)
bytes_per_sec = self.data_size / median
print("Running %s. MEDIAN=%s" % (self.__class__.__name__, self.percentile(50)))
print(f"Running {self.__class__.__name__}. MEDIAN={self.percentile(50)}")
result_data.append(
{
"info": {
@ -113,6 +113,7 @@ class PerformanceTest(object):
return sorted_results[percentile_index]
else:
self.fail("Test execution failed")
return None
def runTest(self):
results = []
@ -202,7 +203,7 @@ class TestDocument(PerformanceTest):
def setUp(self):
# Location of test data.
with open(
os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)), "r"
os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset))
) as data:
self.document = json.loads(data.read())
@ -210,7 +211,7 @@ class TestDocument(PerformanceTest):
self.client.drop_database("perftest")
def tearDown(self):
super(TestDocument, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
def before(self):
@ -225,7 +226,7 @@ class TestFindOneByID(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "tweet.json"
super(TestFindOneByID, self).setUp()
super().setUp()
documents = [self.document.copy() for _ in range(NUM_DOCS)]
self.corpus = self.client.perftest.corpus
@ -249,7 +250,7 @@ class TestSmallDocInsertOne(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "small_doc.json"
super(TestSmallDocInsertOne, self).setUp()
super().setUp()
self.documents = [self.document.copy() for _ in range(NUM_DOCS)]
@ -264,7 +265,7 @@ class TestLargeDocInsertOne(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "large_doc.json"
super(TestLargeDocInsertOne, self).setUp()
super().setUp()
self.documents = [self.document.copy() for _ in range(10)]
@ -280,7 +281,7 @@ class TestFindManyAndEmptyCursor(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "tweet.json"
super(TestFindManyAndEmptyCursor, self).setUp()
super().setUp()
for _ in range(10):
self.client.perftest.command("insert", "corpus", documents=[self.document] * 1000)
@ -301,7 +302,7 @@ class TestSmallDocBulkInsert(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "small_doc.json"
super(TestSmallDocBulkInsert, self).setUp()
super().setUp()
self.documents = [self.document.copy() for _ in range(NUM_DOCS)]
def before(self):
@ -316,7 +317,7 @@ class TestLargeDocBulkInsert(TestDocument, unittest.TestCase):
def setUp(self):
self.dataset = "large_doc.json"
super(TestLargeDocBulkInsert, self).setUp()
super().setUp()
self.documents = [self.document.copy() for _ in range(10)]
def before(self):
@ -342,7 +343,7 @@ class TestGridFsUpload(PerformanceTest, unittest.TestCase):
self.bucket = GridFSBucket(self.client.perftest)
def tearDown(self):
super(TestGridFsUpload, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
def before(self):
@ -368,7 +369,7 @@ class TestGridFsDownload(PerformanceTest, unittest.TestCase):
self.uploaded_id = self.bucket.upload_from_stream("gridfstest", gfile)
def tearDown(self):
super(TestGridFsDownload, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
def do_task(self):
@ -392,14 +393,14 @@ def mp_map(map_func, files):
def insert_json_file(filename):
assert proc_client is not None
with open(filename, "r") as data:
with open(filename) as data:
coll = proc_client.perftest.corpus
coll.insert_many([json.loads(line) for line in data])
def insert_json_file_with_file_id(filename):
documents = []
with open(filename, "r") as data:
with open(filename) as data:
for line in data:
doc = json.loads(line)
doc["file"] = filename
@ -461,7 +462,7 @@ class TestJsonMultiImport(PerformanceTest, unittest.TestCase):
self.client.perftest.drop_collection("corpus")
def tearDown(self):
super(TestJsonMultiImport, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
@ -482,7 +483,7 @@ class TestJsonMultiExport(PerformanceTest, unittest.TestCase):
mp_map(read_json_file, self.files)
def tearDown(self):
super(TestJsonMultiExport, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
@ -505,7 +506,7 @@ class TestGridFsMultiFileUpload(PerformanceTest, unittest.TestCase):
mp_map(insert_gridfs_file, self.files)
def tearDown(self):
super(TestGridFsMultiFileUpload, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")
@ -529,7 +530,7 @@ class TestGridFsMultiFileDownload(PerformanceTest, unittest.TestCase):
mp_map(read_gridfs_file, self.files)
def tearDown(self):
super(TestGridFsMultiFileDownload, self).tearDown()
super().tearDown()
self.client.drop_database("perftest")

View File

@ -40,7 +40,7 @@ class MockPool(Pool):
@contextlib.contextmanager
def get_socket(self, handler=None):
client = self.client
host_and_port = "%s:%s" % (self.mock_host, self.mock_port)
host_and_port = f"{self.mock_host}:{self.mock_port}"
if host_and_port in client.mock_down_hosts:
raise AutoReconnect("mock error")
@ -54,7 +54,7 @@ class MockPool(Pool):
yield sock_info
class DummyMonitor(object):
class DummyMonitor:
def __init__(self, server_description, topology, pool, topology_settings):
self._server_description = server_description
self.opened = False
@ -99,7 +99,7 @@ class MockClient(MongoClient):
arbiters=None,
down_hosts=None,
*args,
**kwargs
**kwargs,
):
"""A MongoClient connected to the default server, with a mock topology.
@ -144,7 +144,7 @@ class MockClient(MongoClient):
client_options = client_context.default_client_options.copy()
client_options.update(kwargs)
super(MockClient, self).__init__(*args, **client_options)
super().__init__(*args, **client_options)
def kill_host(self, host):
"""Host is like 'a:1'."""

View File

@ -116,7 +116,8 @@ def gen_regexp(gen_length):
# TODO our patterns only consist of one letter.
# this is because of a bug in CPython's regex equality testing,
# which I haven't quite tracked down, so I'm just ignoring it...
pattern = lambda: "".join(gen_list(choose_lifted("a"), gen_length)())
def pattern():
return "".join(gen_list(choose_lifted("a"), gen_length)())
def gen_flags():
flags = 0
@ -230,9 +231,9 @@ def check(predicate, generator):
try:
if not predicate(case):
reduction = reduce(case, predicate)
counter_examples.append("after %s reductions: %r" % reduction)
counter_examples.append("after {} reductions: {!r}".format(*reduction))
except:
counter_examples.append("%r : %s" % (case, traceback.format_exc()))
counter_examples.append(f"{case!r} : {traceback.format_exc()}")
return counter_examples

View File

@ -84,7 +84,7 @@ if __name__ == "__main__":
if len(sys.argv) != 2:
print("unknown or missing options")
print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'")
exit(1)
sys.exit(1)
# Enable logs in this format:
# 2022-03-30 12:40:55,582 INFO <ServerHeartbeatStartedEvent ('localhost', 27017)>

View File

@ -67,7 +67,7 @@ class AutoAuthenticateThread(threading.Thread):
"""
def __init__(self, collection):
super(AutoAuthenticateThread, self).__init__()
super().__init__()
self.collection = collection
self.success = False
@ -89,10 +89,10 @@ class TestGSSAPI(unittest.TestCase):
cls.service_realm_required = (
GSSAPI_SERVICE_REALM is not None and GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL
)
mech_properties = "SERVICE_NAME:%s" % (GSSAPI_SERVICE_NAME,)
mech_properties += ",CANONICALIZE_HOST_NAME:%s" % (GSSAPI_CANONICALIZE,)
mech_properties = f"SERVICE_NAME:{GSSAPI_SERVICE_NAME}"
mech_properties += f",CANONICALIZE_HOST_NAME:{GSSAPI_CANONICALIZE}"
if GSSAPI_SERVICE_REALM is not None:
mech_properties += ",SERVICE_REALM:%s" % (GSSAPI_SERVICE_REALM,)
mech_properties += f",SERVICE_REALM:{GSSAPI_SERVICE_REALM}"
cls.mech_properties = mech_properties
def test_credentials_hashing(self):
@ -111,8 +111,8 @@ class TestGSSAPI(unittest.TestCase):
"GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "B"}}, None
)
self.assertEqual(1, len(set([creds1, creds2])))
self.assertEqual(3, len(set([creds0, creds1, creds2, creds3])))
self.assertEqual(1, len({creds1, creds2}))
self.assertEqual(3, len({creds0, creds1, creds2, creds3}))
@ignore_deprecations
def test_gssapi_simple(self):
@ -160,7 +160,7 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].collection.find_one()
# Log in using URI, with authMechanismProperties.
mech_uri = uri + "&authMechanismProperties=%s" % (self.mech_properties,)
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}"
client = MongoClient(mech_uri)
client[GSSAPI_DB].collection.find_one()
@ -179,7 +179,7 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
uri = uri + "&replicaSet=%s" % (str(set_name),)
uri = uri + f"&replicaSet={str(set_name)}"
client = MongoClient(uri)
client[GSSAPI_DB].list_collection_names()
@ -196,7 +196,7 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
mech_uri = mech_uri + "&replicaSet=%s" % (str(set_name),)
mech_uri = mech_uri + f"&replicaSet={str(set_name)}"
client = MongoClient(mech_uri)
client[GSSAPI_DB].list_collection_names()
@ -336,12 +336,12 @@ class TestSASLPlain(unittest.TestCase):
class TestSCRAMSHA1(IntegrationTest):
@client_context.require_auth
def setUp(self):
super(TestSCRAMSHA1, self).setUp()
super().setUp()
client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"])
def tearDown(self):
client_context.drop_user("pymongo_test", "user")
super(TestSCRAMSHA1, self).tearDown()
super().tearDown()
def test_scram_sha1(self):
host, port = client_context.host, client_context.port
@ -368,16 +368,16 @@ class TestSCRAM(IntegrationTest):
@client_context.require_auth
@client_context.require_version_min(3, 7, 2)
def setUp(self):
super(TestSCRAM, self).setUp()
super().setUp()
self._SENSITIVE_COMMANDS = monitoring._SENSITIVE_COMMANDS
monitoring._SENSITIVE_COMMANDS = set([])
monitoring._SENSITIVE_COMMANDS = set()
self.listener = AllowListEventListener("saslStart")
def tearDown(self):
monitoring._SENSITIVE_COMMANDS = self._SENSITIVE_COMMANDS
client_context.client.testscram.command("dropAllUsersFromDatabase")
client_context.client.drop_database("testscram")
super(TestSCRAM, self).tearDown()
super().tearDown()
def test_scram_skip_empty_exchange(self):
listener = AllowListEventListener("saslStart", "saslContinue")
@ -597,14 +597,14 @@ class TestSCRAM(IntegrationTest):
class TestAuthURIOptions(IntegrationTest):
@client_context.require_auth
def setUp(self):
super(TestAuthURIOptions, self).setUp()
super().setUp()
client_context.create_user("admin", "admin", "pass")
client_context.create_user("pymongo_test", "user", "pass", ["userAdmin", "readWrite"])
def tearDown(self):
client_context.drop_user("pymongo_test", "user")
client_context.drop_user("admin", "admin")
super(TestAuthURIOptions, self).tearDown()
super().tearDown()
def test_uri_options(self):
# Test default to admin

View File

@ -67,7 +67,7 @@ def create_test(test_case):
expected = credential["mechanism_properties"]
if expected is not None:
actual = credentials.mechanism_properties
for key, val in expected.items():
for key, _val in expected.items():
if "SERVICE_NAME" in expected:
self.assertEqual(actual.service_name, expected["SERVICE_NAME"])
elif "CANONICALIZE_HOST_NAME" in expected:
@ -91,7 +91,7 @@ def create_test(test_case):
actual.refresh_token_callback, expected["refresh_token_callback"]
)
else:
self.fail("Unhandled property: %s" % (key,))
self.fail(f"Unhandled property: {key}")
else:
if credential["mechanism"] == "MONGODB-AWS":
self.assertIsNone(credentials.mechanism_properties.aws_session_token)
@ -111,7 +111,7 @@ def create_tests():
continue
test_method = create_test(test_case)
name = str(test_case["description"].lower().replace(" ", "_"))
setattr(TestAuthSpec, "test_%s_%s" % (test_suffix, name), test_method)
setattr(TestAuthSpec, f"test_{test_suffix}_{name}", test_method)
create_tests()

View File

@ -122,15 +122,15 @@ class TestBinary(unittest.TestCase):
def test_repr(self):
one = Binary(b"hello world")
self.assertEqual(repr(one), "Binary(%s, 0)" % (repr(b"hello world"),))
self.assertEqual(repr(one), "Binary({}, 0)".format(repr(b"hello world")))
two = Binary(b"hello world", 2)
self.assertEqual(repr(two), "Binary(%s, 2)" % (repr(b"hello world"),))
self.assertEqual(repr(two), "Binary({}, 2)".format(repr(b"hello world")))
three = Binary(b"\x08\xFF")
self.assertEqual(repr(three), "Binary(%s, 0)" % (repr(b"\x08\xFF"),))
self.assertEqual(repr(three), "Binary({}, 0)".format(repr(b"\x08\xFF")))
four = Binary(b"\x08\xFF", 2)
self.assertEqual(repr(four), "Binary(%s, 2)" % (repr(b"\x08\xFF"),))
self.assertEqual(repr(four), "Binary({}, 2)".format(repr(b"\x08\xFF")))
five = Binary(b"test", 100)
self.assertEqual(repr(five), "Binary(%s, 100)" % (repr(b"test"),))
self.assertEqual(repr(five), "Binary({}, 100)".format(repr(b"test")))
def test_hash(self):
one = Binary(b"hello world")
@ -351,7 +351,7 @@ class TestUuidSpecExplicitCoding(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestUuidSpecExplicitCoding, cls).setUpClass()
super().setUpClass()
cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF")
@staticmethod
@ -452,7 +452,7 @@ class TestUuidSpecImplicitCoding(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestUuidSpecImplicitCoding, cls).setUpClass()
super().setUpClass()
cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF")
@staticmethod

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Copyright 2009-present MongoDB, Inc.
#
@ -370,7 +369,7 @@ class TestBSON(unittest.TestCase):
),
]
for i, data in enumerate(bad_bsons):
msg = "bad_bson[{}]".format(i)
msg = f"bad_bson[{i}]"
with self.assertRaises(InvalidBSON, msg=msg):
decode_all(data)
with self.assertRaises(InvalidBSON, msg=msg):
@ -491,7 +490,7 @@ class TestBSON(unittest.TestCase):
def test_unknown_type(self):
# Repr value differs with major python version
part = "type %r for fieldname 'foo'" % (b"\x14",)
part = "type {!r} for fieldname 'foo'".format(b"\x14")
docs = [
b"\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00",
(b"\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140\x00\x01\x00\x00\x00\x00\x00"),
@ -648,7 +647,7 @@ class TestBSON(unittest.TestCase):
encoded1 = encode({"x": 256})
decoded1 = decode(encoded1)["x"]
self.assertEqual(256, decoded1)
self.assertEqual(type(256), type(decoded1))
self.assertEqual(int, type(decoded1))
encoded2 = encode({"x": Int64(256)})
decoded2 = decode(encoded2)["x"]
@ -925,7 +924,7 @@ class TestBSON(unittest.TestCase):
def test_bson_encode_thread_safe(self):
def target(i):
for j in range(1000):
my_int = type("MyInt_%s_%s" % (i, j), (int,), {})
my_int = type(f"MyInt_{i}_{j}", (int,), {})
bson.encode({"my_int": my_int()})
threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)]
@ -939,7 +938,7 @@ class TestBSON(unittest.TestCase):
self.assertIsNone(t.exc)
def test_raise_invalid_document(self):
class Wrapper(object):
class Wrapper:
def __init__(self, val):
self.val = val

View File

@ -50,12 +50,12 @@ class BulkTestBase(IntegrationTest):
@classmethod
def setUpClass(cls):
super(BulkTestBase, cls).setUpClass()
super().setUpClass()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
def setUp(self):
super(BulkTestBase, self).setUp()
super().setUp()
self.coll.drop()
def assertEqualResponse(self, expected, actual):
@ -93,7 +93,7 @@ class BulkTestBase(IntegrationTest):
self.assertEqual(
actual.get(key),
value,
"%r value of %r does not match expected %r" % (key, actual.get(key), value),
f"{key!r} value of {actual.get(key)!r} does not match expected {value!r}",
)
def assertEqualUpsert(self, expected, actual):
@ -793,10 +793,10 @@ class BulkAuthorizationTestBase(BulkTestBase):
@client_context.require_auth
@client_context.require_no_api_version
def setUpClass(cls):
super(BulkAuthorizationTestBase, cls).setUpClass()
super().setUpClass()
def setUp(self):
super(BulkAuthorizationTestBase, self).setUp()
super().setUp()
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
self.db.command(
"createRole",
@ -902,7 +902,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
InsertOne({"x": 3}), # Never attempted.
]
self.assertRaises(OperationFailure, coll.bulk_write, requests)
self.assertEqual(set([1, 2]), set(self.coll.distinct("x")))
self.assertEqual({1, 2}, set(self.coll.distinct("x")))
class TestBulkWriteConcern(BulkTestBase):
@ -911,7 +911,7 @@ class TestBulkWriteConcern(BulkTestBase):
@classmethod
def setUpClass(cls):
super(TestBulkWriteConcern, cls).setUpClass()
super().setUpClass()
cls.w = client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:

View File

@ -104,7 +104,8 @@ class TestChangeStreamBase(IntegrationTest):
def get_start_at_operation_time(self):
"""Get an operationTime. Advances the operation clock beyond the most
recently returned timestamp."""
recently returned timestamp.
"""
optime = self.client.admin.command("ping")["operationTime"]
return Timestamp(optime.time, optime.inc + 1)
@ -120,7 +121,7 @@ class TestChangeStreamBase(IntegrationTest):
client._close_cursor_now(cursor.cursor_id, address)
class APITestsMixin(object):
class APITestsMixin:
@no_type_check
def test_watch(self):
with self.change_stream(
@ -208,7 +209,7 @@ class APITestsMixin(object):
# Stream still works after a resume.
coll.insert_one({"_id": 3})
wait_until(lambda: stream.try_next() is not None, "get change from try_next")
self.assertEqual(set(listener.started_command_names()), set(["getMore"]))
self.assertEqual(set(listener.started_command_names()), {"getMore"})
self.assertIsNone(stream.try_next())
@no_type_check
@ -249,7 +250,7 @@ class APITestsMixin(object):
coll.insert_many([{"data": i} for i in range(ndocs)])
with self.change_stream(start_at_operation_time=optime) as cs:
for i in range(ndocs):
for _i in range(ndocs):
cs.next()
@no_type_check
@ -443,7 +444,7 @@ class APITestsMixin(object):
self.assertEqual(change["fullDocument"], {"_id": 2})
class ProseSpecTestsMixin(object):
class ProseSpecTestsMixin:
@no_type_check
def _client_with_listener(self, *commands):
listener = AllowListEventListener(*commands)
@ -461,7 +462,8 @@ class ProseSpecTestsMixin(object):
def _get_expected_resume_token_legacy(self, stream, listener, previous_change=None):
"""Predicts what the resume token should currently be for server
versions that don't support postBatchResumeToken. Assumes the stream
has never returned any changes if previous_change is None."""
has never returned any changes if previous_change is None.
"""
if previous_change is None:
agg_cmd = listener.started_events[0]
stage = agg_cmd.command["pipeline"][0]["$changeStream"]
@ -474,7 +476,8 @@ class ProseSpecTestsMixin(object):
versions that support postBatchResumeToken. Assumes the stream has
never returned any changes if previous_change is None. Assumes
listener is a AllowListEventListener that listens for aggregate and
getMore commands."""
getMore commands.
"""
if previous_change is None or stream._cursor._has_next():
token = self._get_expected_resume_token_legacy(stream, listener, previous_change)
if token is not None:
@ -767,14 +770,14 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def setUpClass(cls):
super(TestClusterChangeStream, cls).setUpClass()
super().setUpClass()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
@classmethod
def tearDownClass(cls):
for db in cls.dbs:
cls.client.drop_database(db)
super(TestClusterChangeStream, cls).tearDownClass()
super().tearDownClass()
def change_stream_with_client(self, client, *args, **kwargs):
return client.watch(*args, **kwargs)
@ -828,7 +831,7 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def setUpClass(cls):
super(TestDatabaseChangeStream, cls).setUpClass()
super().setUpClass()
def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].watch(*args, **kwargs)
@ -913,7 +916,7 @@ class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecT
@classmethod
@client_context.require_change_streams
def setUpClass(cls):
super(TestCollectionChangeStream, cls).setUpClass()
super().setUpClass()
def setUp(self):
# Use a new collection for each test.
@ -1044,17 +1047,17 @@ class TestAllLegacyScenarios(IntegrationTest):
@classmethod
@client_context.require_connection
def setUpClass(cls):
super(TestAllLegacyScenarios, cls).setUpClass()
super().setUpClass()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def tearDownClass(cls):
cls.client.close()
super(TestAllLegacyScenarios, cls).tearDownClass()
super().tearDownClass()
def setUp(self):
super(TestAllLegacyScenarios, self).setUp()
super().setUp()
self.listener.reset()
def setUpCluster(self, scenario_dict):
@ -1088,7 +1091,8 @@ class TestAllLegacyScenarios(IntegrationTest):
def assert_list_contents_are_subset(self, superlist, sublist):
"""Check that each element in sublist is a subset of the corresponding
element in superlist."""
element in superlist.
"""
self.assertEqual(len(superlist), len(sublist))
for sup, sub in zip(superlist, sublist):
if isinstance(sub, dict):
@ -1104,7 +1108,7 @@ class TestAllLegacyScenarios(IntegrationTest):
exempt_fields = ["documentKey", "_id", "getMore"]
for key, value in subdict.items():
if key not in superdict:
self.fail("Key %s not found in %s" % (key, superdict))
self.fail(f"Key {key} not found in {superdict}")
if isinstance(value, dict):
self.assert_dict_is_subset(superdict[key], value)
continue

View File

@ -325,7 +325,7 @@ class ClientUnitTest(unittest.TestCase):
self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a"))
# Test appending to driver info.
metadata["driver"]["name"] = "PyMongo|FooDriver"
metadata["driver"]["version"] = "%s|1.2.3" % (_METADATA["driver"]["version"],)
metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"])
client = MongoClient(
"foo",
27017,
@ -335,7 +335,7 @@ class ClientUnitTest(unittest.TestCase):
)
options = client._MongoClient__options
self.assertEqual(options.pool_options.metadata, metadata)
metadata["platform"] = "%s|FooPlatform" % (_METADATA["platform"],)
metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"])
client = MongoClient(
"foo",
27017,
@ -347,7 +347,7 @@ class ClientUnitTest(unittest.TestCase):
self.assertEqual(options.pool_options.metadata, metadata)
def test_kwargs_codec_options(self):
class MyFloatType(object):
class MyFloatType:
def __init__(self, x):
self.__x = x
@ -704,7 +704,7 @@ class TestClient(IntegrationTest):
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
def test_equality(self):
seed = "%s:%s" % list(self.client._topology_settings.seeds)[0]
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = rs_or_single_client(seed, connect=False)
self.addCleanup(c.close)
self.assertEqual(client_context.client, c)
@ -723,7 +723,7 @@ class TestClient(IntegrationTest):
)
def test_hashable(self):
seed = "%s:%s" % list(self.client._topology_settings.seeds)[0]
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = rs_or_single_client(seed, connect=False)
self.addCleanup(c.close)
self.assertIn(c, {client_context.client})
@ -735,7 +735,7 @@ class TestClient(IntegrationTest):
with self.assertRaises(ValueError):
connected(
MongoClient(
"%s:1234567" % (client_context.host,),
f"{client_context.host}:1234567",
connectTimeoutMS=1,
serverSelectionTimeoutMS=10,
)
@ -1002,7 +1002,7 @@ class TestClient(IntegrationTest):
@client_context.require_auth
def test_lazy_auth_raises_operation_failure(self):
lazy_client = rs_or_single_client_noauth(
"mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), connect=False
f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False
)
assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one)
@ -1160,7 +1160,7 @@ class TestClient(IntegrationTest):
raise SkipTest("Need the ipaddress module to test with SSL")
if client_context.auth_enabled:
auth_str = "%s:%s@" % (db_user, db_pwd)
auth_str = f"{db_user}:{db_pwd}@"
else:
auth_str = ""
@ -1533,7 +1533,7 @@ class TestClient(IntegrationTest):
# Continuously reset the pool.
class ResetPoolThread(threading.Thread):
def __init__(self, pool):
super(ResetPoolThread, self).__init__()
super().__init__()
self.running = True
self.pool = pool
@ -1657,7 +1657,7 @@ class TestClient(IntegrationTest):
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
):
assert client.address is not None
expected = "%s:%s: " % client.address
expected = "{}:{}: ".format(*client.address)
with self.assertRaisesRegex(AutoReconnect, expected):
client.pymongo_test.test.find_one({})
@ -1836,7 +1836,7 @@ class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
def setUp(self):
super(TestExhaustCursor, self).setUp()
super().setUp()
if client_context.is_mongos:
raise SkipTest("mongos doesn't support exhaust, SERVER-2627")
@ -2188,23 +2188,33 @@ class TestMongoClientFailover(MockClientTest):
self.assertEqual(7, sd_b.max_wire_version)
def test_network_error_on_query(self):
callback = lambda client: client.db.collection.find_one()
def callback(client):
return client.db.collection.find_one()
self._test_network_error(callback)
def test_network_error_on_insert(self):
callback = lambda client: client.db.collection.insert_one({})
def callback(client):
return client.db.collection.insert_one({})
self._test_network_error(callback)
def test_network_error_on_update(self):
callback = lambda client: client.db.collection.update_one({}, {"$unset": "x"})
def callback(client):
return client.db.collection.update_one({}, {"$unset": "x"})
self._test_network_error(callback)
def test_network_error_on_replace(self):
callback = lambda client: client.db.collection.replace_one({}, {})
def callback(client):
return client.db.collection.replace_one({}, {})
self._test_network_error(callback)
def test_network_error_on_delete(self):
callback = lambda client: client.db.collection.delete_many({})
def callback(client):
return client.db.collection.delete_many({})
self._test_network_error(callback)
@ -2227,7 +2237,7 @@ class TestClientPool(MockClientTest):
wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(c.address, ("a", 1))
self.assertEqual(c.arbiters, set([("c", 3)]))
self.assertEqual(c.arbiters, {("c", 3)})
# Assert that we create 2 and only 2 pooled connections.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)

View File

@ -28,8 +28,9 @@ class TestClientContext(unittest.TestCase):
self.assertTrue(
client_context.connected,
"client context must be connected when "
"PYMONGO_MUST_CONNECT is set. Failed attempts:\n%s"
% (client_context.connection_attempt_info(),),
"PYMONGO_MUST_CONNECT is set. Failed attempts:\n{}".format(
client_context.connection_attempt_info()
),
)
def test_serverless(self):
@ -39,8 +40,9 @@ class TestClientContext(unittest.TestCase):
self.assertTrue(
client_context.connected and client_context.serverless,
"client context must be connected to serverless when "
"TEST_SERVERLESS is set. Failed attempts:\n%s"
% (client_context.connection_attempt_info(),),
"TEST_SERVERLESS is set. Failed attempts:\n{}".format(
client_context.connection_attempt_info()
),
)
def test_enableTestCommands_is_disabled(self):

View File

@ -116,7 +116,7 @@ class TestCMAP(IntegrationTest):
timeout = op.get("timeout", 10000) / 1000.0
wait_until(
lambda: self.listener.event_count(event) >= count,
"find %s %s event(s)" % (count, event),
f"find {count} {event} event(s)",
timeout=timeout,
)
@ -191,11 +191,11 @@ class TestCMAP(IntegrationTest):
"""Check the events of a test."""
actual_events = self.actual_events(ignore)
for actual, expected in zip(actual_events, events):
self.logs.append("Checking event actual: %r vs expected: %r" % (actual, expected))
self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}")
self.check_event(actual, expected)
if len(events) > len(actual_events):
self.fail("missing events: %r" % (events[len(actual_events) :],))
self.fail(f"missing events: {events[len(actual_events) :]!r}")
def check_error(self, actual, expected):
message = expected.pop("message")
@ -260,9 +260,9 @@ class TestCMAP(IntegrationTest):
self.pool = list(client._topology._servers.values())[0].pool
# Map of target names to Thread objects.
self.targets: dict = dict()
self.targets: dict = {}
# Map of label names to Connection objects
self.labels: dict = dict()
self.labels: dict = {}
def cleanup():
for t in self.targets.values():
@ -285,7 +285,7 @@ class TestCMAP(IntegrationTest):
self.check_events(test["events"], test["ignore"])
except Exception:
# Print the events after a test failure.
print("\nFailed test: %r" % (test["description"],))
print("\nFailed test: {!r}".format(test["description"]))
print("Operations:")
for op in self._ops:
print(op)
@ -332,8 +332,8 @@ class TestCMAP(IntegrationTest):
self.assertEqual(pool.opts, pool_opts)
def test_3_uri_connection_pool_options(self):
opts = "&".join(["%s=%s" % (k, v) for k, v in self.POOL_OPTIONS.items()])
uri = "mongodb://%s/?%s" % (client_context.pair, opts)
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
uri = f"mongodb://{client_context.pair}/?{opts}"
client = rs_or_single_client(uri)
self.addCleanup(client.close)
pool_opts = get_pool(client).opts

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Copyright 2009-present MongoDB, Inc.
#
@ -67,7 +66,7 @@ class TestCode(unittest.TestCase):
c = Code("hello world", {"blah": 3})
self.assertEqual(repr(c), "Code('hello world', {'blah': 3})")
c = Code("\x08\xFF")
self.assertEqual(repr(c), "Code(%s, None)" % (repr("\x08\xFF"),))
self.assertEqual(repr(c), "Code({}, None)".format(repr("\x08\xFF")))
def test_equality(self):
b = Code("hello")

View File

@ -96,7 +96,7 @@ class TestCollation(IntegrationTest):
@classmethod
@client_context.require_connection
def setUpClass(cls):
super(TestCollation, cls).setUpClass()
super().setUpClass()
cls.listener = EventListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
@ -110,11 +110,11 @@ class TestCollation(IntegrationTest):
cls.warn_context.__exit__()
cls.warn_context = None
cls.client.close()
super(TestCollation, cls).tearDownClass()
super().tearDownClass()
def tearDown(self):
self.listener.reset()
super(TestCollation, self).tearDown()
super().tearDown()
def last_command_started(self):
return self.listener.started_events[-1].command

View File

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -151,7 +149,7 @@ class TestCollection(IntegrationTest):
@classmethod
def setUpClass(cls):
super(TestCollection, cls).setUpClass()
super().setUpClass()
cls.w = client_context.w # type: ignore
@classmethod
@ -373,7 +371,7 @@ class TestCollection(IntegrationTest):
db.test.insert_one({}) # create collection
def map_indexes(indexes):
return dict([(index["name"], index) for index in indexes])
return {index["name"]: index for index in indexes}
indexes = list(db.test.list_indexes())
self.assertEqual(len(indexes), 1)
@ -485,7 +483,7 @@ class TestCollection(IntegrationTest):
db.test.drop_indexes()
self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)]))
for dummy, info in db.test.index_information().items():
for _dummy, info in db.test.index_information().items():
field, idx_type = info["key"][0]
if field == "geo" and idx_type == "2dsphere":
break
@ -504,7 +502,7 @@ class TestCollection(IntegrationTest):
db.test.drop_indexes()
self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)]))
for dummy, info in db.test.index_information().items():
for _dummy, info in db.test.index_information().items():
field, idx_type = info["key"][0]
if field == "a" and idx_type == "hashed":
break
@ -1638,8 +1636,8 @@ class TestCollection(IntegrationTest):
self.assertTrue("hello" in db.test.find_one(projection=("hello",)))
self.assertTrue("hello" not in db.test.find_one(projection=("foo",)))
self.assertTrue("hello" in db.test.find_one(projection=set(["hello"])))
self.assertTrue("hello" not in db.test.find_one(projection=set(["foo"])))
self.assertTrue("hello" in db.test.find_one(projection={"hello"}))
self.assertTrue("hello" not in db.test.find_one(projection={"foo"}))
self.assertTrue("hello" in db.test.find_one(projection=frozenset(["hello"])))
self.assertTrue("hello" not in db.test.find_one(projection=frozenset(["foo"])))

View File

@ -28,7 +28,7 @@ from pymongo.command_cursor import CommandCursor
from pymongo.operations import IndexModel
class Empty(object):
class Empty:
def __getattr__(self, item):
try:
self.__dict__[item]

Some files were not shown because too many files have changed in this diff Show More