From 7bcbb0de9b120fcec02202bebc904f2adcbc9a92 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 12 Jul 2023 10:48:33 -0700 Subject: [PATCH] PYTHON-3802 add types to database.py (#1295) --- bson/codec_options.py | 2 +- pymongo/client_session.py | 7 +- pymongo/database.py | 155 ++++++++++++++++++++++++++++---------- 3 files changed, 122 insertions(+), 42 deletions(-) diff --git a/bson/codec_options.py b/bson/codec_options.py index 45860fa70..f146898df 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -485,7 +485,7 @@ else: return CodecOptions(**opts) -DEFAULT_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions() +DEFAULT_CODEC_OPTIONS: "CodecOptions[Dict[str, Any]]" = CodecOptions() def _parse_codec_options(options: Any) -> CodecOptions: diff --git a/pymongo/client_session.py b/pymongo/client_session.py index d19631866..a43982e43 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -144,6 +144,7 @@ from typing import ( Any, Callable, ContextManager, + Dict, List, Mapping, MutableMapping, @@ -831,19 +832,19 @@ class ClientSession: self._transaction.state = _TxnState.ABORTED self._unpin() - def _finish_transaction_with_retry(self, command_name: str) -> List[Any]: + def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]: """Run commit or abort with one retry after any retryable error. :Parameters: - `command_name`: Either "commitTransaction" or "abortTransaction". """ - def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> List[Any]: + def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]: return self._finish_transaction(sock_info, command_name) return self._client._retry_internal(True, func, self, None) - def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> List[Any]: + def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts diff --git a/pymongo/database.py b/pymongo/database.py index 7829c28fe..f78b721b2 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -30,9 +30,10 @@ from typing import ( TypeVar, Union, cast, + overload, ) -from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.son import SON from bson.timestamp import Timestamp @@ -46,8 +47,12 @@ from pymongo.errors import CollectionInvalid, InvalidName, InvalidOperation from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +if TYPE_CHECKING: + from pymongo.pool import SocketInfo + from pymongo.server import Server -def _check_name(name): + +def _check_name(name: str) -> None: """Check if a database name is valid.""" if not name: raise InvalidName("database name cannot be the empty string") @@ -212,7 +217,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def __hash__(self) -> int: return hash((self.__client, self.__name)) - def __repr__(self): + def __repr__(self) -> str: return f"Database({self.__client!r}, {self.__name!r})" def __getattr__(self, name: str) -> Collection[_DocumentType]: @@ -295,7 +300,9 @@ class Database(common.BaseObject, Generic[_DocumentType]): read_concern, ) - def _get_encrypted_fields(self, kwargs, coll_name, ask_db): + def _get_encrypted_fields( + self, kwargs: Mapping[str, Any], coll_name: str, ask_db: bool + ) -> Optional[Mapping[str, Any]]: encrypted_fields = kwargs.get("encryptedFields") if encrypted_fields: return deepcopy(encrypted_fields) @@ -680,20 +687,56 @@ class Database(common.BaseObject, Generic[_DocumentType]): show_expanded_events=show_expanded_events, ) + @overload def _command( self, - sock_info, - command, - value=1, - check=True, - allowable_errors=None, - read_preference=ReadPreference.PRIMARY, - codec_options=DEFAULT_CODEC_OPTIONS, - write_concern=None, - parse_write_concern_error=False, - session=None, - **kwargs, - ): + sock_info: SocketInfo, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + ... + + @overload + def _command( + self, + sock_info: SocketInfo, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + codec_options: CodecOptions[_CodecDocumentType] = ..., + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> _CodecDocumentType: + ... + + def _command( + self, + sock_info: SocketInfo, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + codec_options: Union[ + CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType] + ] = DEFAULT_CODEC_OPTIONS, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> Union[Dict[str, Any], _CodecDocumentType]: """Internal command helper.""" if isinstance(command, str): command = SON([(command, value)]) @@ -713,6 +756,36 @@ class Database(common.BaseObject, Generic[_DocumentType]): client=self.__client, ) + @overload + def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: None = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + ... + + @overload + def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: CodecOptions[_CodecDocumentType] = ..., + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> _CodecDocumentType: + ... + @_csot.apply def command( self, @@ -725,7 +798,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> _CodecDocumentType: + ) -> Union[Dict[str, Any], _CodecDocumentType]: """Issue a MongoDB command. Send command `command` to the database and return the @@ -942,35 +1015,34 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _retryable_read_command( self, - command, - value=1, - check=True, - allowable_errors=None, - read_preference=None, - codec_options=DEFAULT_CODEC_OPTIONS, - session=None, - **kwargs, - ): + command: Union[str, MutableMapping[str, Any]], + session: Optional[ClientSession] = None, + ) -> Dict[str, Any]: """Same as command but used for retryable read commands.""" - if read_preference is None: - read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - def _cmd(session, server, sock_info, read_preference): + def _cmd( + session: Optional[ClientSession], + server: Server, + sock_info: SocketInfo, + read_preference: _ServerMode, + ) -> Dict[str, Any]: return self._command( sock_info, command, - value, - check, - allowable_errors, - read_preference, - codec_options, + read_preference=read_preference, session=session, - **kwargs, ) return self.__client._retryable_read(_cmd, read_preference, session) - def _list_collections(self, sock_info, session, read_preference, **kwargs): + def _list_collections( + self, + sock_info: SocketInfo, + session: Optional[ClientSession], + read_preference: _ServerMode, + **kwargs: Any, + ) -> CommandCursor: """Internal listCollections helper.""" coll = self.get_collection("$cmd", read_preference=read_preference) cmd = SON([("listCollections", 1), ("cursor", {})]) @@ -1024,7 +1096,12 @@ class Database(common.BaseObject, Generic[_DocumentType]): if comment is not None: kwargs["comment"] = comment - def _cmd(session, server, sock_info, read_preference): + def _cmd( + session: Optional[ClientSession], + server: Server, + sock_info: SocketInfo, + read_preference: _ServerMode, + ) -> CommandCursor[_DocumentType]: return self._list_collections( sock_info, session, read_preference=read_preference, **kwargs ) @@ -1079,7 +1156,9 @@ class Database(common.BaseObject, Generic[_DocumentType]): return [result["name"] for result in self.list_collections(session=session, **kwargs)] - def _drop_helper(self, name, session=None, comment=None): + def _drop_helper( + self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None + ) -> Dict[str, Any]: command = SON([("drop", name)]) if comment is not None: command["comment"] = comment