MOTOR-1187 Fix Motor Types (#237)
* MOTOR-1187 Fix Motor Types * make searchindexmodel optional * fix tests
This commit is contained in:
parent
c644cdd9a6
commit
f0c9ff3e48
@ -21,7 +21,7 @@ from ._version import get_version_string, version, version_tuple # noqa: F401
|
||||
try:
|
||||
import tornado
|
||||
except ImportError:
|
||||
tornado = None
|
||||
tornado = None # type:ignore[assignment]
|
||||
else:
|
||||
# For backwards compatibility with Motor 0.4, export Motor's Tornado classes
|
||||
# at module root. This may change in the future.
|
||||
|
||||
@ -28,6 +28,8 @@ import gridfs
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorGridFSBucket
|
||||
from motor.motor_gridfs import _hash_gridout
|
||||
|
||||
# mypy: disable-error-code="no-untyped-def,no-untyped-call"
|
||||
|
||||
|
||||
def get_gridfs_file(bucket, filename, request):
|
||||
"""Override to choose a GridFS file to serve at a URL.
|
||||
|
||||
@ -22,6 +22,7 @@ from typing import (
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Generic,
|
||||
Iterable,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
@ -30,6 +31,7 @@ from typing import (
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import pymongo.common
|
||||
@ -42,12 +44,13 @@ from pymongo import IndexModel, ReadPreference, WriteConcern
|
||||
from pymongo.change_stream import ChangeStream
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.client_session import _T, ClientSession, SessionOptions, TransactionOptions
|
||||
from pymongo.collection import ReturnDocument, _IndexKeyHint, _IndexList, _WriteOp
|
||||
from pymongo.collection import ReturnDocument, _WriteOp
|
||||
from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor
|
||||
from pymongo.cursor import Cursor, RawBatchCursor, _Hint, _Sort
|
||||
from pymongo.database import Database
|
||||
from pymongo.encryption import ClientEncryption, RewrapManyDataKeyResult
|
||||
from pymongo.encryption_options import RangeOpts
|
||||
from pymongo.operations import _IndexKeyHint, _IndexList
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.results import (
|
||||
@ -67,9 +70,9 @@ from pymongo.typings import (
|
||||
)
|
||||
|
||||
try:
|
||||
from pymongo import SearchIndexModel
|
||||
from pymongo.operations import SearchIndexModel
|
||||
except ImportError:
|
||||
SearchIndexModel = Any
|
||||
SearchIndexModel = Any # type:ignore[misc,assignment]
|
||||
|
||||
_WITH_TRANSACTION_RETRY_TIME_LIMIT: int
|
||||
|
||||
@ -85,15 +88,15 @@ class AgnosticBase:
|
||||
def __init__(self, delegate: Any) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
class AgnosticBaseProperties(AgnosticBase):
|
||||
codec_options: CodecOptions
|
||||
class AgnosticBaseProperties(AgnosticBase, Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[_DocumentType]
|
||||
read_preference: _ServerMode
|
||||
read_concern: ReadConcern
|
||||
write_concern: WriteConcern
|
||||
|
||||
class AgnosticClient(AgnosticBaseProperties):
|
||||
class AgnosticClient(AgnosticBaseProperties[_DocumentType]):
|
||||
__motor_class_name__: str
|
||||
__delegate_class__: type[pymongo.MongoClient]
|
||||
__delegate_class__: type[pymongo.MongoClient[_DocumentType]]
|
||||
|
||||
def address(self) -> Optional[tuple[str, int]]: ...
|
||||
def arbiters(self) -> set[tuple[str, int]]: ...
|
||||
@ -226,9 +229,9 @@ class AgnosticClientSession(AgnosticBase):
|
||||
def __enter__(self) -> None: ...
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
|
||||
|
||||
class AgnosticDatabase(AgnosticBaseProperties):
|
||||
class AgnosticDatabase(AgnosticBaseProperties[_DocumentType]):
|
||||
__motor_class_name__: str
|
||||
__delegate_class__: type[Database]
|
||||
__delegate_class__: type[Database[_DocumentType]]
|
||||
|
||||
def __hash__(self) -> int: ...
|
||||
def __bool__(self) -> int: ...
|
||||
@ -243,6 +246,33 @@ class AgnosticDatabase(AgnosticBaseProperties):
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgnosticCommandCursor: ...
|
||||
@overload
|
||||
async def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: Any = ...,
|
||||
check: bool = ...,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = ...,
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
codec_options: None = ...,
|
||||
session: Optional[AgnosticClientSession] = ...,
|
||||
comment: Optional[Any] = ...,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]: ...
|
||||
@overload
|
||||
async def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: Any = 1,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: CodecOptions[_CodecDocumentType] = ...,
|
||||
session: Optional[AgnosticClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> _CodecDocumentType: ...
|
||||
@overload
|
||||
async def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -280,7 +310,7 @@ class AgnosticDatabase(AgnosticBaseProperties):
|
||||
comment: Optional[Any] = None,
|
||||
encrypted_fields: Optional[Mapping[str, Any]] = None,
|
||||
) -> dict[str, Any]: ...
|
||||
async def get_collection(
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
|
||||
@ -302,6 +332,7 @@ class AgnosticDatabase(AgnosticBaseProperties):
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgnosticCommandCursor: ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
async def validate_collection(
|
||||
self,
|
||||
@ -349,9 +380,9 @@ class AgnosticDatabase(AgnosticBaseProperties):
|
||||
def wrap(self, obj: Any) -> Any: ...
|
||||
def get_io_loop(self) -> Any: ...
|
||||
|
||||
class AgnosticCollection(AgnosticBaseProperties):
|
||||
class AgnosticCollection(AgnosticBaseProperties[_DocumentType]):
|
||||
__motor_class_name__: str
|
||||
__delegate_class__: type[Collection]
|
||||
__delegate_class__: type[Collection[_DocumentType]]
|
||||
|
||||
def __hash__(self) -> int: ...
|
||||
def __bool__(self) -> bool: ...
|
||||
@ -495,6 +526,7 @@ class AgnosticCollection(AgnosticBaseProperties):
|
||||
session: Optional[AgnosticClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertOneResult: ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
async def options(
|
||||
self, session: Optional[AgnosticClientSession] = None, comment: Optional[Any] = None
|
||||
@ -651,8 +683,10 @@ class AgnosticBaseCursor(AgnosticBase):
|
||||
def next_object(self) -> Any: ...
|
||||
def each(self, callback: Callable) -> None: ...
|
||||
def _each_got_more(self, callback: Callable, future: Any) -> None: ...
|
||||
def to_list(self, length: int) -> Future[list]: ...
|
||||
def _to_list(self, length: int, the_list: list, future: Any, get_more_result: Any) -> None: ...
|
||||
def to_list(self, length: Optional[int]) -> Future[list]: ...
|
||||
def _to_list(
|
||||
self, length: Optional[int], the_list: list, future: Any, get_more_result: Any
|
||||
) -> None: ...
|
||||
def get_io_loop(self) -> Any: ...
|
||||
def batch_size(self, batch_size: int) -> AgnosticBaseCursor: ...
|
||||
def _buffer_size(self) -> int: ...
|
||||
|
||||
@ -25,6 +25,8 @@ import warnings
|
||||
from asyncio import get_event_loop # noqa: F401 - For framework interface.
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
try:
|
||||
import contextvars
|
||||
except ImportError:
|
||||
|
||||
@ -32,6 +32,7 @@ try:
|
||||
except ImportError:
|
||||
contextvars = None
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
CLASS_PREFIX = ""
|
||||
|
||||
|
||||
@ -19,6 +19,8 @@ from typing import Any, Callable, Dict, TypeVar
|
||||
|
||||
_class_cache: Dict[Any, Any] = {}
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
def asynchronize(framework, sync_method: Callable, doc=None, wrap_class=None, unwrap_class=None):
|
||||
"""Decorate `sync_method` so it returns a Future.
|
||||
|
||||
@ -24,6 +24,8 @@ import tornado.web
|
||||
import motor
|
||||
from motor.motor_gridfs import _hash_gridout
|
||||
|
||||
# mypy: disable-error-code="no-untyped-def,no-untyped-call"
|
||||
|
||||
# TODO: this class is not a drop-in replacement for StaticFileHandler.
|
||||
# StaticFileHandler provides class method make_static_url, which appends
|
||||
# an checksum of the static file's contents. Templates thus can do
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Test Motor, an asynchronous driver for MongoDB and Tornado."""
|
||||
|
||||
from test.test_environment import CLIENT_PEM, db_user, env # noqa: F401
|
||||
from typing import Any
|
||||
from unittest import SkipTest # noqa: F401
|
||||
|
||||
try:
|
||||
@ -34,11 +35,11 @@ except ImportError:
|
||||
class MockRequestHandler:
|
||||
"""For testing MotorGridOut.stream_to_handler."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.n_written = 0
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: Any) -> None:
|
||||
self.n_written += len(data)
|
||||
|
||||
def flush(self):
|
||||
def flush(self) -> None:
|
||||
pass
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
import collections
|
||||
import logging
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
_LoggingWatcher = collections.namedtuple("_LoggingWatcher", ["records", "output"])
|
||||
|
||||
|
||||
|
||||
@ -28,6 +28,9 @@ from unittest import SkipTest
|
||||
from mockupdb import MockupDB
|
||||
|
||||
from motor import motor_asyncio
|
||||
from motor.core import AgnosticClient
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
class _TestMethodWrapper:
|
||||
@ -104,7 +107,7 @@ class AsyncIOTestCase(AssertLogsMixin, unittest.TestCase):
|
||||
kwargs.setdefault("tls", env.mongod_started_with_ssl)
|
||||
return kwargs
|
||||
|
||||
def asyncio_client(self, uri=None, *args, set_loop=True, **kwargs):
|
||||
def asyncio_client(self, uri=None, *args, set_loop=True, **kwargs) -> AgnosticClient:
|
||||
"""Get an AsyncIOMotorClient.
|
||||
|
||||
Ignores self.ssl, you must pass 'ssl' argument.
|
||||
|
||||
@ -24,6 +24,9 @@ from unittest import SkipTest
|
||||
|
||||
import pymongo.errors
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
HAVE_SSL = True
|
||||
try:
|
||||
import ssl
|
||||
|
||||
@ -17,7 +17,7 @@ sample client code that uses Motor typings.
|
||||
"""
|
||||
import unittest
|
||||
from test.asyncio_tests import AsyncIOTestCase, asyncio_test
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, TypeVar, Union, cast
|
||||
|
||||
from bson import CodecOptions
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
@ -65,7 +65,7 @@ def only_type_check(func: FuncT) -> FuncT:
|
||||
return cast(FuncT, inner)
|
||||
|
||||
|
||||
class TestMotor(AsyncIOTestCase): # type:ignore[misc]
|
||||
class TestMotor(AsyncIOTestCase):
|
||||
cx: AgnosticClient
|
||||
|
||||
@asyncio_test # type:ignore[misc]
|
||||
@ -89,6 +89,20 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc]
|
||||
docs = await cursor.to_list(None)
|
||||
self.assertTrue(docs)
|
||||
|
||||
@asyncio_test # type:ignore[misc]
|
||||
async def test_get_collection(self) -> None:
|
||||
coll = self.db.get_collection("test_collection")
|
||||
self.assertEqual(coll.name, "test_collection")
|
||||
|
||||
@asyncio_test # type:ignore[misc]
|
||||
async def test_find_one(self) -> None:
|
||||
c: AgnosticClient[Movie] = self.asyncio_client()
|
||||
coll = c[self.db.name]["movies"]
|
||||
await coll.insert_one(Movie(name="American Graffiti", year=1973))
|
||||
result = await coll.find_one({})
|
||||
assert result is not None
|
||||
self.assertEqual(result["year"], 1973)
|
||||
|
||||
@only_type_check
|
||||
@asyncio_test # type:ignore[misc]
|
||||
async def test_bulk_write(self) -> None:
|
||||
@ -123,7 +137,7 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc]
|
||||
InsertOne(Movie(name="American Graffiti", year=1973)),
|
||||
ReplaceOne(
|
||||
{},
|
||||
{"name": "American Graffiti", "year": "WRONG_TYPE"},
|
||||
{"name": "American Graffiti", "year": "WRONG_TYPE"}, # type:ignore[typeddict-item]
|
||||
),
|
||||
DeleteOne({}),
|
||||
]
|
||||
@ -138,13 +152,13 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc]
|
||||
@asyncio_test # type:ignore[misc]
|
||||
async def test_list_collections(self) -> None:
|
||||
cursor = await self.cx.test.list_collections()
|
||||
value = await cursor.next()
|
||||
value: Mapping[str, Any] = await cursor.next()
|
||||
value.items()
|
||||
|
||||
@asyncio_test # type:ignore[misc]
|
||||
async def test_list_databases(self) -> None:
|
||||
cursor = await self.cx.list_databases()
|
||||
value = await cursor.next()
|
||||
value: Mapping[str, Any] = await cursor.next()
|
||||
value.items()
|
||||
|
||||
@asyncio_test # type:ignore[misc]
|
||||
@ -193,14 +207,14 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc]
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentType(AsyncIOTestCase): # type:ignore[misc]
|
||||
class TestDocumentType(AsyncIOTestCase):
|
||||
@only_type_check
|
||||
def test_typeddict_explicit_document_type(self) -> None:
|
||||
out = MovieWithId(_id=ObjectId(), name="THX-1138", year=1971)
|
||||
assert out is not None
|
||||
# This should fail because the output is a Movie.
|
||||
assert out["foo"] # type:ignore[typeddict-item]
|
||||
assert out["_id"]
|
||||
assert bool(out["_id"])
|
||||
|
||||
# This should work the same as the test above, but this time using NotRequired to allow
|
||||
# automatic insertion of the _id field by insert_one.
|
||||
@ -211,7 +225,7 @@ class TestDocumentType(AsyncIOTestCase): # type:ignore[misc]
|
||||
# This should fail because the output is a Movie.
|
||||
assert out["foo"] # type:ignore[typeddict-item]
|
||||
# pyright gives reportTypedDictNotRequiredAccess for the following:
|
||||
assert out["_id"]
|
||||
assert bool(out["_id"])
|
||||
|
||||
@only_type_check
|
||||
def test_typeddict_empty_document_type(self) -> None:
|
||||
@ -223,7 +237,7 @@ class TestDocumentType(AsyncIOTestCase): # type:ignore[misc]
|
||||
assert out["_id"] # type:ignore[typeddict-item]
|
||||
|
||||
|
||||
class TestCommandDocumentType(AsyncIOTestCase): # type:ignore[misc]
|
||||
class TestCommandDocumentType(AsyncIOTestCase):
|
||||
@only_type_check
|
||||
async def test_default(self) -> None:
|
||||
client: AgnosticClient = AgnosticClient()
|
||||
@ -249,14 +263,14 @@ class TestCommandDocumentType(AsyncIOTestCase): # type:ignore[misc]
|
||||
async def test_raw_bson_document_type(self) -> None:
|
||||
client: AgnosticClient = AgnosticClient()
|
||||
codec_options = CodecOptions(RawBSONDocument)
|
||||
result: RawBSONDocument = await client.admin.command(
|
||||
result = await client.admin.command(
|
||||
"ping", codec_options=codec_options
|
||||
) # Fix once @overload for command works
|
||||
assert len(result.raw) > 0
|
||||
|
||||
@only_type_check
|
||||
async def test_son_document_type(self) -> None:
|
||||
client = AgnosticClient(document_class=SON[str, Any])
|
||||
client: AgnosticClient[SON[str, Any]] = AgnosticClient(document_class=SON[str, Any])
|
||||
codec_options = CodecOptions(SON[str, Any])
|
||||
result = await client.admin.command("ping", codec_options=codec_options)
|
||||
result["a"] = 1
|
||||
|
||||
@ -25,6 +25,8 @@ import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
def one(s):
|
||||
"""Get one element of a set"""
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
|
||||
"""Some tools for running tests based on MongoDB server version."""
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
class Version(tuple):
|
||||
"""Copied from PyMongo's test.version submodule."""
|
||||
|
||||
4
tox.ini
4
tox.ini
@ -169,6 +169,6 @@ setenv =
|
||||
# TODO: Remove as part of MOTOR-1163
|
||||
AIOHTTP_NO_EXTENSIONS=1
|
||||
commands =
|
||||
mypy --install-types --non-interactive motor/ --disallow-untyped-defs --follow-imports=skip --exclude 'motor/aiohttp/*' --exclude 'motor/frameworks/*' --exclude 'motor/metaprogramming.py' --exclude 'motor/web.py'
|
||||
mypy --install-types --non-interactive --follow-imports=skip test/test_typing.py
|
||||
mypy --install-types --non-interactive motor
|
||||
mypy --install-types --non-interactive test/test_typing.py
|
||||
pytest test/test_mypy_fails.py
|
||||
|
||||
Loading…
Reference in New Issue
Block a user