PYTHON-3698 Support mypy 1.2 --strict testing (part 1) (#1371)
This commit is contained in:
parent
611303ac00
commit
9b6f2e18cf
9
.github/workflows/test-python.yml
vendored
9
.github/workflows/test-python.yml
vendored
@ -83,16 +83,11 @@ jobs:
|
||||
typing:
|
||||
name: Typing Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.7', '3.11']
|
||||
fail-fast: false
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.11"
|
||||
cache: 'pip'
|
||||
cache-dependency-path: 'pyproject.toml'
|
||||
- name: Install dependencies
|
||||
|
||||
1
mypy.ini
1
mypy.ini
@ -1,4 +1,5 @@
|
||||
[mypy]
|
||||
python_version = 3.7
|
||||
check_untyped_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_incomplete_defs = true
|
||||
|
||||
@ -117,7 +117,7 @@ def has_c() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def timeout(seconds: Optional[float]) -> ContextManager:
|
||||
def timeout(seconds: Optional[float]) -> ContextManager[None]:
|
||||
"""**(Provisional)** Apply the given timeout for a block of operations.
|
||||
|
||||
.. note:: :func:`~pymongo.timeout` is currently provisional. Backwards
|
||||
|
||||
@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import AbstractContextManager
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
|
||||
|
||||
@ -59,7 +60,7 @@ def clamp_remaining(max_timeout: float) -> float:
|
||||
return min(timeout, max_timeout)
|
||||
|
||||
|
||||
class _TimeoutContext:
|
||||
class _TimeoutContext(AbstractContextManager):
|
||||
"""Internal timeout context manager.
|
||||
|
||||
Use :func:`pymongo.timeout` instead::
|
||||
@ -68,11 +69,9 @@ class _TimeoutContext:
|
||||
client.test.test.insert_one({})
|
||||
"""
|
||||
|
||||
__slots__ = ("_timeout", "_tokens")
|
||||
|
||||
def __init__(self, timeout: Optional[float]):
|
||||
self._timeout = timeout
|
||||
self._tokens: Optional[tuple[Token, Token, Token]] = None
|
||||
self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None
|
||||
|
||||
def __enter__(self) -> _TimeoutContext:
|
||||
timeout_token = TIMEOUT.set(self._timeout)
|
||||
@ -99,7 +98,7 @@ def apply(func: F) -> F:
|
||||
"""Apply the client's timeoutMS to this operation."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> F:
|
||||
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if get_timeout() is None:
|
||||
timeout = self._timeout
|
||||
if timeout is not None:
|
||||
@ -110,7 +109,9 @@ def apply(func: F) -> F:
|
||||
return cast(F, csot_wrapper)
|
||||
|
||||
|
||||
def apply_write_concern(cmd: MutableMapping, write_concern: Optional[WriteConcern]) -> None:
|
||||
def apply_write_concern(
|
||||
cmd: MutableMapping[str, Any], write_concern: Optional[WriteConcern]
|
||||
) -> None:
|
||||
"""Apply the given write concern to a command."""
|
||||
if not write_concern or write_concern.is_server_default:
|
||||
return
|
||||
|
||||
@ -23,7 +23,16 @@ import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
@ -193,10 +202,11 @@ def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> dict:
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[str, str], item.split(b"=", 1)) for item in response.split(b",")
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
@ -523,7 +533,7 @@ def _authenticate_default(credentials: MongoCredential, conn: Connection) -> Non
|
||||
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
|
||||
|
||||
|
||||
_AUTH_MAP: Mapping[str, Callable] = {
|
||||
_AUTH_MAP: Mapping[str, Callable[..., None]] = {
|
||||
"GSSAPI": _authenticate_gssapi,
|
||||
"MONGODB-CR": _authenticate_mongo_cr,
|
||||
"MONGODB-X509": _authenticate_x509,
|
||||
@ -547,13 +557,13 @@ class _AuthContext:
|
||||
) -> Optional[_AuthContext]:
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return spec_cls(creds, address)
|
||||
return cast(_AuthContext, spec_cls(creds, address))
|
||||
return None
|
||||
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, hello: Hello) -> None:
|
||||
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
|
||||
self.speculative_authenticate = hello.speculative_authenticate
|
||||
|
||||
def speculate_succeeded(self) -> bool:
|
||||
@ -595,7 +605,7 @@ class _OIDCContext(_AuthContext):
|
||||
return cmd
|
||||
|
||||
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
|
||||
"MONGODB-X509": _X509Context,
|
||||
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
|
||||
from bson import _bson_to_dict
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common
|
||||
@ -122,7 +122,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
|
||||
self._decode_custom = False
|
||||
self._orig_codec_options = target.codec_options
|
||||
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
|
||||
if target.codec_options.type_registry._decoder_map:
|
||||
self._decode_custom = True
|
||||
# Keep the type registry so that we support encoding custom types
|
||||
@ -425,14 +425,14 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
def __enter__(self) -> "ChangeStream":
|
||||
def __enter__(self) -> ChangeStream[_DocumentType]:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class CollectionChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
class CollectionChangeStream(ChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on a single collection.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
@ -448,11 +448,11 @@ class CollectionChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
return _CollectionAggregationCommand
|
||||
|
||||
@property
|
||||
def _client(self) -> MongoClient:
|
||||
def _client(self) -> MongoClient[_DocumentType]:
|
||||
return self._target.database.client
|
||||
|
||||
|
||||
class DatabaseChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
class DatabaseChangeStream(ChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in a database.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
@ -468,11 +468,11 @@ class DatabaseChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
return _DatabaseAggregationCommand
|
||||
|
||||
@property
|
||||
def _client(self) -> MongoClient:
|
||||
def _client(self) -> MongoClient[_DocumentType]:
|
||||
return self._target.client
|
||||
|
||||
|
||||
class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]):
|
||||
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in the cluster.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
|
||||
@ -24,14 +24,14 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
# The maximum amount of time to wait for the intermediate subprocess.
|
||||
_WAIT_TIMEOUT = 10
|
||||
_THIS_FILE = os.path.realpath(__file__)
|
||||
|
||||
|
||||
def _popen_wait(popen: subprocess.Popen, timeout: Optional[float]) -> Optional[int]:
|
||||
def _popen_wait(popen: subprocess.Popen[Any], timeout: Optional[float]) -> Optional[int]:
|
||||
"""Implement wait timeout support for Python 3."""
|
||||
try:
|
||||
return popen.wait(timeout=timeout)
|
||||
@ -40,7 +40,7 @@ def _popen_wait(popen: subprocess.Popen, timeout: Optional[float]) -> Optional[i
|
||||
return None
|
||||
|
||||
|
||||
def _silence_resource_warning(popen: Optional[subprocess.Popen]) -> None:
|
||||
def _silence_resource_warning(popen: Optional[subprocess.Popen[Any]]) -> None:
|
||||
"""Silence Popen's ResourceWarning.
|
||||
|
||||
Note this should only be used if the process was created as a daemon.
|
||||
@ -89,7 +89,7 @@ else:
|
||||
# to be safe to call from any thread. Using Popen instead of fork also
|
||||
# avoids triggering the application's os.register_at_fork() callbacks when
|
||||
# we spawn the mongocryptd daemon process.
|
||||
def _spawn(args: Sequence[str]) -> Optional[subprocess.Popen]:
|
||||
def _spawn(args: Sequence[str]) -> Optional[subprocess.Popen[Any]]:
|
||||
"""Spawn the process and silence stdout/stderr."""
|
||||
try:
|
||||
with open(os.devnull, "r+b") as devnull:
|
||||
|
||||
@ -46,8 +46,14 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
import bson.codec_options
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server import Server
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
def _check_name(name: str) -> None:
|
||||
@ -60,15 +66,6 @@ def _check_name(name: str) -> None:
|
||||
raise InvalidName("database names cannot contain the character %r" % invalid_char)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
import bson.codec_options
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
_CodecDocumentType = TypeVar("_CodecDocumentType", bound=Mapping[str, Any])
|
||||
|
||||
|
||||
@ -162,7 +159,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
) -> "Database[_DocumentType]":
|
||||
) -> Database[_DocumentType]:
|
||||
"""Get a clone of this database changing the specified settings.
|
||||
|
||||
>>> db1.read_preference
|
||||
@ -302,7 +299,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
encrypted_fields = kwargs.get("encryptedFields")
|
||||
if encrypted_fields:
|
||||
return deepcopy(encrypted_fields)
|
||||
return cast(Mapping[str, Any], deepcopy(encrypted_fields))
|
||||
if (
|
||||
self.client.options.auto_encryption_opts
|
||||
and self.client.options.auto_encryption_opts._encrypted_fields_map
|
||||
@ -310,15 +307,18 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
f"{self.name}.{coll_name}"
|
||||
)
|
||||
):
|
||||
return deepcopy(
|
||||
self.client.options.auto_encryption_opts._encrypted_fields_map[
|
||||
f"{self.name}.{coll_name}"
|
||||
]
|
||||
return cast(
|
||||
Mapping[str, Any],
|
||||
deepcopy(
|
||||
self.client.options.auto_encryption_opts._encrypted_fields_map[
|
||||
f"{self.name}.{coll_name}"
|
||||
]
|
||||
),
|
||||
)
|
||||
if ask_db and self.client.options.auto_encryption_opts:
|
||||
options = self[coll_name].options()
|
||||
if options.get("encryptedFields"):
|
||||
return deepcopy(options["encryptedFields"])
|
||||
return cast(Mapping[str, Any], deepcopy(options["encryptedFields"]))
|
||||
return None
|
||||
|
||||
@_csot.apply
|
||||
@ -914,7 +914,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
comment: Optional[Any] = None,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> CommandCursor:
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
"""Issue a MongoDB command and parse the response as a cursor.
|
||||
|
||||
If the response from the server does not include a cursor field, an error will be thrown.
|
||||
@ -1298,7 +1298,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
if background is not None:
|
||||
cmd["background"] = background
|
||||
|
||||
result = cast(dict, self.command(cmd, session=session))
|
||||
result = self.command(cmd, session=session)
|
||||
|
||||
valid = True
|
||||
# Pre 1.9 results
|
||||
|
||||
@ -29,6 +29,7 @@ from typing import (
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -71,7 +72,7 @@ from pymongo.pool import PoolOptions, _configured_socket, _raise_connection_fail
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.typings import _DocumentType
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -83,7 +84,9 @@ _KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to C
|
||||
_MONGOCRYPTD_TIMEOUT_MS = 10000
|
||||
|
||||
|
||||
_DATA_KEY_OPTS: CodecOptions = CodecOptions(document_class=SON, uuid_representation=STANDARD)
|
||||
_DATA_KEY_OPTS: CodecOptions[SON[str, Any]] = CodecOptions(
|
||||
document_class=SON[str, Any], uuid_representation=STANDARD
|
||||
)
|
||||
# Use RawBSONDocument codec options to avoid needlessly decoding
|
||||
# documents from the key vault.
|
||||
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
|
||||
@ -105,9 +108,9 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
client: Optional[MongoClient],
|
||||
key_vault_coll: Collection,
|
||||
mongocryptd_client: Optional[MongoClient],
|
||||
client: Optional[MongoClient[_DocumentTypeArg]],
|
||||
key_vault_coll: Collection[_DocumentTypeArg],
|
||||
mongocryptd_client: Optional[MongoClient[_DocumentTypeArg]],
|
||||
opts: AutoEncryptionOpts,
|
||||
):
|
||||
"""Internal class to perform I/O on behalf of pymongocrypt."""
|
||||
@ -117,10 +120,13 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
self.client_ref = weakref.ref(client)
|
||||
else:
|
||||
self.client_ref = None
|
||||
self.key_vault_coll: Optional[Collection] = key_vault_coll.with_options(
|
||||
codec_options=_KEY_VAULT_OPTS,
|
||||
read_concern=ReadConcern(level="majority"),
|
||||
write_concern=WriteConcern(w="majority"),
|
||||
self.key_vault_coll: Optional[Collection[RawBSONDocument]] = cast(
|
||||
Collection[RawBSONDocument],
|
||||
key_vault_coll.with_options(
|
||||
codec_options=_KEY_VAULT_OPTS,
|
||||
read_concern=ReadConcern(level="majority"),
|
||||
write_concern=WriteConcern(w="majority"),
|
||||
),
|
||||
)
|
||||
self.mongocryptd_client = mongocryptd_client
|
||||
self.opts = opts
|
||||
@ -180,7 +186,9 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
|
||||
def collection_info(self, database: Database, filter: bytes) -> Optional[bytes]:
|
||||
def collection_info(
|
||||
self, database: Database[Mapping[str, Any]], filter: bytes
|
||||
) -> Optional[bytes]:
|
||||
"""Get the collection info for a namespace.
|
||||
|
||||
The returned collection info is passed to libmongocrypt which reads
|
||||
@ -322,7 +330,7 @@ class _Encrypter:
|
||||
MongoDB commands.
|
||||
"""
|
||||
|
||||
def __init__(self, client: MongoClient, opts: AutoEncryptionOpts):
|
||||
def __init__(self, client: MongoClient[_DocumentTypeArg], opts: AutoEncryptionOpts):
|
||||
"""Create a _Encrypter for a client.
|
||||
|
||||
:Parameters:
|
||||
@ -341,7 +349,9 @@ class _Encrypter:
|
||||
self._bypass_auto_encryption = opts._bypass_auto_encryption
|
||||
self._internal_client = None
|
||||
|
||||
def _get_internal_client(encrypter: _Encrypter, mongo_client: MongoClient) -> MongoClient:
|
||||
def _get_internal_client(
|
||||
encrypter: _Encrypter, mongo_client: MongoClient[_DocumentTypeArg]
|
||||
) -> MongoClient[_DocumentTypeArg]:
|
||||
if mongo_client.options.pool_options.max_pool_size is None:
|
||||
# Unlimited pool size, use the same client.
|
||||
return mongo_client
|
||||
@ -365,11 +375,13 @@ class _Encrypter:
|
||||
db, coll = opts._key_vault_namespace.split(".", 1)
|
||||
key_vault_coll = key_vault_client[db][coll]
|
||||
|
||||
mongocryptd_client: MongoClient = MongoClient(
|
||||
mongocryptd_client: MongoClient[Mapping[str, Any]] = MongoClient(
|
||||
opts._mongocryptd_uri, connect=False, serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS
|
||||
)
|
||||
|
||||
io_callbacks = _EncryptionIO(metadata_client, key_vault_coll, mongocryptd_client, opts)
|
||||
io_callbacks = _EncryptionIO(
|
||||
metadata_client, key_vault_coll, mongocryptd_client, opts
|
||||
) # type:ignore[misc]
|
||||
self._auto_encrypter = AutoEncrypter(
|
||||
io_callbacks,
|
||||
MongoCryptOptions(
|
||||
@ -385,8 +397,8 @@ class _Encrypter:
|
||||
self._closed = False
|
||||
|
||||
def encrypt(
|
||||
self, database: str, cmd: Mapping[str, Any], codec_options: CodecOptions
|
||||
) -> MutableMapping[Any, Any]:
|
||||
self, database: str, cmd: Mapping[str, Any], codec_options: CodecOptions[_DocumentTypeArg]
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Encrypt a MongoDB command.
|
||||
|
||||
:Parameters:
|
||||
@ -416,7 +428,7 @@ class _Encrypter:
|
||||
"""
|
||||
self._check_closed()
|
||||
with _wrap_encryption_errors():
|
||||
return self._auto_encrypter.decrypt(response)
|
||||
return cast(bytes, self._auto_encrypter.decrypt(response))
|
||||
|
||||
def _check_closed(self) -> None:
|
||||
if self._closed:
|
||||
@ -482,8 +494,8 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: MongoClient,
|
||||
codec_options: CodecOptions,
|
||||
key_vault_client: MongoClient[_DocumentTypeArg],
|
||||
codec_options: CodecOptions[_DocumentTypeArg],
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Explicit client-side field level encryption.
|
||||
@ -583,17 +595,18 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
self._io_callbacks, MongoCryptOptions(kms_providers, None)
|
||||
)
|
||||
# Use the same key vault collection as the callback.
|
||||
assert self._io_callbacks.key_vault_coll is not None
|
||||
self._key_vault_coll = self._io_callbacks.key_vault_coll
|
||||
|
||||
def create_encrypted_collection(
|
||||
self,
|
||||
database: Database,
|
||||
database: Database[_DocumentTypeArg],
|
||||
name: str,
|
||||
encrypted_fields: Mapping[str, Any],
|
||||
kms_provider: Optional[str] = None,
|
||||
master_key: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Collection[_DocumentType], Mapping[str, Any]]:
|
||||
) -> tuple[Collection[_DocumentTypeArg], Mapping[str, Any]]:
|
||||
"""Create a collection with encryptedFields.
|
||||
|
||||
.. warning::
|
||||
@ -748,11 +761,14 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
"""
|
||||
self._check_closed()
|
||||
with _wrap_encryption_errors():
|
||||
return self._encryption.create_data_key(
|
||||
kms_provider,
|
||||
master_key=master_key,
|
||||
key_alt_names=key_alt_names,
|
||||
key_material=key_material,
|
||||
return cast(
|
||||
Binary,
|
||||
self._encryption.create_data_key(
|
||||
kms_provider,
|
||||
master_key=master_key,
|
||||
key_alt_names=key_alt_names,
|
||||
key_material=key_material,
|
||||
),
|
||||
)
|
||||
|
||||
def _encrypt_helper(
|
||||
@ -831,15 +847,18 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
.. versionchanged:: 4.2
|
||||
Added the `query_type` and `contention_factor` parameters.
|
||||
"""
|
||||
return self._encrypt_helper(
|
||||
value=value,
|
||||
algorithm=algorithm,
|
||||
key_id=key_id,
|
||||
key_alt_name=key_alt_name,
|
||||
query_type=query_type,
|
||||
contention_factor=contention_factor,
|
||||
range_opts=range_opts,
|
||||
is_expression=False,
|
||||
return cast(
|
||||
Binary,
|
||||
self._encrypt_helper(
|
||||
value=value,
|
||||
algorithm=algorithm,
|
||||
key_id=key_id,
|
||||
key_alt_name=key_alt_name,
|
||||
query_type=query_type,
|
||||
contention_factor=contention_factor,
|
||||
range_opts=range_opts,
|
||||
is_expression=False,
|
||||
),
|
||||
)
|
||||
|
||||
def encrypt_expression(
|
||||
@ -878,15 +897,18 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
return self._encrypt_helper(
|
||||
value=expression,
|
||||
algorithm=algorithm,
|
||||
key_id=key_id,
|
||||
key_alt_name=key_alt_name,
|
||||
query_type=query_type,
|
||||
contention_factor=contention_factor,
|
||||
range_opts=range_opts,
|
||||
is_expression=True,
|
||||
return cast(
|
||||
RawBSONDocument,
|
||||
self._encrypt_helper(
|
||||
value=expression,
|
||||
algorithm=algorithm,
|
||||
key_id=key_id,
|
||||
key_alt_name=key_alt_name,
|
||||
query_type=query_type,
|
||||
contention_factor=contention_factor,
|
||||
range_opts=range_opts,
|
||||
is_expression=True,
|
||||
),
|
||||
)
|
||||
|
||||
def decrypt(self, value: Binary) -> Any:
|
||||
@ -1087,7 +1109,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
result = self._key_vault_coll.bulk_write(replacements)
|
||||
return RewrapManyDataKeyResult(result)
|
||||
|
||||
def __enter__(self) -> "ClientEncryption":
|
||||
def __enter__(self) -> ClientEncryption[_DocumentType]:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
|
||||
@ -30,6 +30,7 @@ from pymongo.uri_parser import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.typings import _DocumentTypeArg
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
|
||||
@ -39,7 +40,7 @@ class AutoEncryptionOpts:
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: Optional[MongoClient] = None,
|
||||
key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None,
|
||||
schema_map: Optional[Mapping[str, Any]] = None,
|
||||
bypass_auto_encryption: bool = False,
|
||||
mongocryptd_uri: str = "mongodb://localhost:27020",
|
||||
@ -50,7 +51,7 @@ class AutoEncryptionOpts:
|
||||
crypt_shared_lib_path: Optional[str] = None,
|
||||
crypt_shared_lib_required: bool = False,
|
||||
bypass_query_analysis: bool = False,
|
||||
encrypted_fields_map: Optional[Mapping] = None,
|
||||
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Options to configure automatic client-side field level encryption.
|
||||
|
||||
|
||||
@ -100,11 +100,11 @@ class AutoReconnect(ConnectionFailure):
|
||||
Subclass of :exc:`~pymongo.errors.ConnectionFailure`.
|
||||
"""
|
||||
|
||||
errors: Union[Mapping[str, Any], Sequence]
|
||||
details: Union[Mapping[str, Any], Sequence]
|
||||
errors: Union[Mapping[str, Any], Sequence[Any]]
|
||||
details: Union[Mapping[str, Any], Sequence[Any]]
|
||||
|
||||
def __init__(
|
||||
self, message: str = "", errors: Optional[Union[Mapping[str, Any], Sequence]] = None
|
||||
self, message: str = "", errors: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None
|
||||
) -> None:
|
||||
error_labels = None
|
||||
if errors is not None:
|
||||
@ -128,7 +128,9 @@ class NetworkTimeout(AutoReconnect):
|
||||
return True
|
||||
|
||||
|
||||
def _format_detailed_error(message: str, details: Optional[Union[Mapping[str, Any], list]]) -> str:
|
||||
def _format_detailed_error(
|
||||
message: str, details: Optional[Union[Mapping[str, Any], list[Any]]]
|
||||
) -> str:
|
||||
if details is not None:
|
||||
message = f"{message}, full error: {details}"
|
||||
return message
|
||||
@ -151,7 +153,7 @@ class NotPrimaryError(AutoReconnect):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str = "", errors: Optional[Union[Mapping[str, Any], list]] = None
|
||||
self, message: str = "", errors: Optional[Union[Mapping[str, Any], list[Any]]] = None
|
||||
) -> None:
|
||||
super().__init__(_format_detailed_error(message, errors), errors=errors)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ import weakref
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
# References to instances of _create_lock
|
||||
_forkable_locks: weakref.WeakSet = weakref.WeakSet()
|
||||
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
||||
|
||||
|
||||
def _create_lock() -> threading.Lock:
|
||||
|
||||
@ -843,7 +843,7 @@ class Connection:
|
||||
if creds:
|
||||
self.negotiated_mechs = hello.sasl_supported_mechs
|
||||
if auth_ctx:
|
||||
auth_ctx.parse_response(hello)
|
||||
auth_ctx.parse_response(hello) # type:ignore[arg-type]
|
||||
if auth_ctx.speculate_succeeded():
|
||||
self.auth_ctx = auth_ctx
|
||||
if self.opts.load_balanced:
|
||||
|
||||
@ -83,7 +83,7 @@ class InsertManyResult(_WriteResult):
|
||||
super().__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def inserted_ids(self) -> list:
|
||||
def inserted_ids(self) -> list[Any]:
|
||||
"""A list of _ids of the inserted documents, in the order provided.
|
||||
|
||||
.. note:: If ``False`` is passed for the `ordered` parameter to
|
||||
|
||||
@ -22,7 +22,7 @@ try:
|
||||
except ImportError:
|
||||
HAVE_STRINGPREP = False
|
||||
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> Any:
|
||||
"""SASLprep dummy"""
|
||||
if isinstance(data, str):
|
||||
raise TypeError(
|
||||
@ -51,7 +51,7 @@ else:
|
||||
stringprep.in_table_c9,
|
||||
)
|
||||
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> Any:
|
||||
"""An implementation of RFC4013 SASLprep.
|
||||
|
||||
:Parameters:
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import errno
|
||||
import select
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
|
||||
# https://bugs.jython.org/issue2900
|
||||
@ -28,9 +28,9 @@ _SelectError = getattr(select, "error", OSError)
|
||||
|
||||
def _errno_from_exception(exc: BaseException) -> Optional[int]:
|
||||
if hasattr(exc, "errno"):
|
||||
return exc.errno
|
||||
return cast(int, exc.errno)
|
||||
if exc.args:
|
||||
return exc.args[0]
|
||||
return cast(int, exc.args[0])
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user