PYTHON-3802 add types to database.py (#1295)

This commit is contained in:
Iris 2023-07-12 10:48:33 -07:00 committed by GitHub
parent f81cda0e22
commit 7bcbb0de9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 42 deletions

View File

@ -485,7 +485,7 @@ else:
return CodecOptions(**opts)
DEFAULT_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions()
DEFAULT_CODEC_OPTIONS: "CodecOptions[Dict[str, Any]]" = CodecOptions()
def _parse_codec_options(options: Any) -> CodecOptions:

View File

@ -144,6 +144,7 @@ from typing import (
Any,
Callable,
ContextManager,
Dict,
List,
Mapping,
MutableMapping,
@ -831,19 +832,19 @@ class ClientSession:
self._transaction.state = _TxnState.ABORTED
self._unpin()
def _finish_transaction_with_retry(self, command_name: str) -> List[Any]:
def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]:
"""Run commit or abort with one retry after any retryable error.
:Parameters:
- `command_name`: Either "commitTransaction" or "abortTransaction".
"""
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> List[Any]:
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]:
return self._finish_transaction(sock_info, command_name)
return self._client._retry_internal(True, func, self, None)
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> List[Any]:
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]:
self._transaction.attempt += 1
opts = self._transaction.opts
assert opts

View File

@ -30,9 +30,10 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
from bson.dbref import DBRef
from bson.son import SON
from bson.timestamp import Timestamp
@ -46,8 +47,12 @@ from pymongo.errors import CollectionInvalid, InvalidName, InvalidOperation
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
if TYPE_CHECKING:
from pymongo.pool import SocketInfo
from pymongo.server import Server
def _check_name(name):
def _check_name(name: str) -> None:
"""Check if a database name is valid."""
if not name:
raise InvalidName("database name cannot be the empty string")
@ -212,7 +217,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def __hash__(self) -> int:
return hash((self.__client, self.__name))
def __repr__(self):
def __repr__(self) -> str:
return f"Database({self.__client!r}, {self.__name!r})"
def __getattr__(self, name: str) -> Collection[_DocumentType]:
@ -295,7 +300,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
read_concern,
)
def _get_encrypted_fields(self, kwargs, coll_name, ask_db):
def _get_encrypted_fields(
self, kwargs: Mapping[str, Any], coll_name: str, ask_db: bool
) -> Optional[Mapping[str, Any]]:
encrypted_fields = kwargs.get("encryptedFields")
if encrypted_fields:
return deepcopy(encrypted_fields)
@ -680,20 +687,56 @@ class Database(common.BaseObject, Generic[_DocumentType]):
show_expanded_events=show_expanded_events,
)
@overload
def _command(
self,
sock_info,
command,
value=1,
check=True,
allowable_errors=None,
read_preference=ReadPreference.PRIMARY,
codec_options=DEFAULT_CODEC_OPTIONS,
write_concern=None,
parse_write_concern_error=False,
session=None,
**kwargs,
):
sock_info: SocketInfo,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> Dict[str, Any]:
...
@overload
def _command(
self,
sock_info: SocketInfo,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
codec_options: CodecOptions[_CodecDocumentType] = ...,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> _CodecDocumentType:
...
def _command(
self,
sock_info: SocketInfo,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
codec_options: Union[
CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType]
] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> Union[Dict[str, Any], _CodecDocumentType]:
"""Internal command helper."""
if isinstance(command, str):
command = SON([(command, value)])
@ -713,6 +756,36 @@ class Database(common.BaseObject, Generic[_DocumentType]):
client=self.__client,
)
@overload
def command(
self,
command: Union[str, MutableMapping[str, Any]],
value: Any = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: None = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> Dict[str, Any]:
...
@overload
def command(
self,
command: Union[str, MutableMapping[str, Any]],
value: Any = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: CodecOptions[_CodecDocumentType] = ...,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _CodecDocumentType:
...
@_csot.apply
def command(
self,
@ -725,7 +798,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _CodecDocumentType:
) -> Union[Dict[str, Any], _CodecDocumentType]:
"""Issue a MongoDB command.
Send command `command` to the database and return the
@ -942,35 +1015,34 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _retryable_read_command(
self,
command,
value=1,
check=True,
allowable_errors=None,
read_preference=None,
codec_options=DEFAULT_CODEC_OPTIONS,
session=None,
**kwargs,
):
command: Union[str, MutableMapping[str, Any]],
session: Optional[ClientSession] = None,
) -> Dict[str, Any]:
"""Same as command but used for retryable read commands."""
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
def _cmd(session, server, sock_info, read_preference):
def _cmd(
session: Optional[ClientSession],
server: Server,
sock_info: SocketInfo,
read_preference: _ServerMode,
) -> Dict[str, Any]:
return self._command(
sock_info,
command,
value,
check,
allowable_errors,
read_preference,
codec_options,
read_preference=read_preference,
session=session,
**kwargs,
)
return self.__client._retryable_read(_cmd, read_preference, session)
def _list_collections(self, sock_info, session, read_preference, **kwargs):
def _list_collections(
self,
sock_info: SocketInfo,
session: Optional[ClientSession],
read_preference: _ServerMode,
**kwargs: Any,
) -> CommandCursor:
"""Internal listCollections helper."""
coll = self.get_collection("$cmd", read_preference=read_preference)
cmd = SON([("listCollections", 1), ("cursor", {})])
@ -1024,7 +1096,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
kwargs["comment"] = comment
def _cmd(session, server, sock_info, read_preference):
def _cmd(
session: Optional[ClientSession],
server: Server,
sock_info: SocketInfo,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
return self._list_collections(
sock_info, session, read_preference=read_preference, **kwargs
)
@ -1079,7 +1156,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
return [result["name"] for result in self.list_collections(session=session, **kwargs)]
def _drop_helper(self, name, session=None, comment=None):
def _drop_helper(
self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None
) -> Dict[str, Any]:
command = SON([("drop", name)])
if comment is not None:
command["comment"] = comment