diff --git a/motor/__init__.py b/motor/__init__.py index d7d2b588..3ce8de37 100644 --- a/motor/__init__.py +++ b/motor/__init__.py @@ -21,7 +21,7 @@ from ._version import get_version_string, version, version_tuple # noqa: F401 try: import tornado except ImportError: - tornado = None + tornado = None # type:ignore[assignment] else: # For backwards compatibility with Motor 0.4, export Motor's Tornado classes # at module root. This may change in the future. diff --git a/motor/aiohttp/__init__.py b/motor/aiohttp/__init__.py index bde3b13e..23515949 100644 --- a/motor/aiohttp/__init__.py +++ b/motor/aiohttp/__init__.py @@ -28,6 +28,8 @@ import gridfs from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorGridFSBucket from motor.motor_gridfs import _hash_gridout +# mypy: disable-error-code="no-untyped-def,no-untyped-call" + def get_gridfs_file(bucket, filename, request): """Override to choose a GridFS file to serve at a URL. diff --git a/motor/core.pyi b/motor/core.pyi index 5abe4643..04e342e2 100644 --- a/motor/core.pyi +++ b/motor/core.pyi @@ -22,6 +22,7 @@ from typing import ( Callable, Collection, Coroutine, + Generic, Iterable, Mapping, MutableMapping, @@ -30,6 +31,7 @@ from typing import ( Sequence, TypeVar, Union, + overload, ) import pymongo.common @@ -42,12 +44,13 @@ from pymongo import IndexModel, ReadPreference, WriteConcern from pymongo.change_stream import ChangeStream from pymongo.client_options import ClientOptions from pymongo.client_session import _T, ClientSession, SessionOptions, TransactionOptions -from pymongo.collection import ReturnDocument, _IndexKeyHint, _IndexList, _WriteOp +from pymongo.collection import ReturnDocument, _WriteOp from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor from pymongo.cursor import Cursor, RawBatchCursor, _Hint, _Sort from pymongo.database import Database from pymongo.encryption import ClientEncryption, RewrapManyDataKeyResult from pymongo.encryption_options import RangeOpts +from pymongo.operations import _IndexKeyHint, _IndexList from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.results import ( @@ -67,9 +70,9 @@ from pymongo.typings import ( ) try: - from pymongo import SearchIndexModel + from pymongo.operations import SearchIndexModel except ImportError: - SearchIndexModel = Any + SearchIndexModel = Any # type:ignore[misc,assignment] _WITH_TRANSACTION_RETRY_TIME_LIMIT: int @@ -85,15 +88,15 @@ class AgnosticBase: def __init__(self, delegate: Any) -> None: ... def __repr__(self) -> str: ... -class AgnosticBaseProperties(AgnosticBase): - codec_options: CodecOptions +class AgnosticBaseProperties(AgnosticBase, Generic[_DocumentType]): + codec_options: CodecOptions[_DocumentType] read_preference: _ServerMode read_concern: ReadConcern write_concern: WriteConcern -class AgnosticClient(AgnosticBaseProperties): +class AgnosticClient(AgnosticBaseProperties[_DocumentType]): __motor_class_name__: str - __delegate_class__: type[pymongo.MongoClient] + __delegate_class__: type[pymongo.MongoClient[_DocumentType]] def address(self) -> Optional[tuple[str, int]]: ... def arbiters(self) -> set[tuple[str, int]]: ... @@ -226,9 +229,9 @@ class AgnosticClientSession(AgnosticBase): def __enter__(self) -> None: ... def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... -class AgnosticDatabase(AgnosticBaseProperties): +class AgnosticDatabase(AgnosticBaseProperties[_DocumentType]): __motor_class_name__: str - __delegate_class__: type[Database] + __delegate_class__: type[Database[_DocumentType]] def __hash__(self) -> int: ... def __bool__(self) -> int: ... @@ -243,6 +246,33 @@ class AgnosticDatabase(AgnosticBaseProperties): max_await_time_ms: Optional[int] = None, **kwargs: Any, ) -> AgnosticCommandCursor: ... + @overload + async def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = ..., + check: bool = ..., + allowable_errors: Optional[Sequence[Union[str, int]]] = ..., + read_preference: Optional[_ServerMode] = ..., + codec_options: None = ..., + session: Optional[AgnosticClientSession] = ..., + comment: Optional[Any] = ..., + **kwargs: Any, + ) -> dict[str, Any]: ... + @overload + async def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: CodecOptions[_CodecDocumentType] = ..., + session: Optional[AgnosticClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> _CodecDocumentType: ... + @overload async def command( self, command: Union[str, MutableMapping[str, Any]], @@ -280,7 +310,7 @@ class AgnosticDatabase(AgnosticBaseProperties): comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, ) -> dict[str, Any]: ... - async def get_collection( + def get_collection( self, name: str, codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, @@ -302,6 +332,7 @@ class AgnosticDatabase(AgnosticBaseProperties): comment: Optional[Any] = None, **kwargs: Any, ) -> AgnosticCommandCursor: ... + @property def name(self) -> str: ... async def validate_collection( self, @@ -349,9 +380,9 @@ class AgnosticDatabase(AgnosticBaseProperties): def wrap(self, obj: Any) -> Any: ... def get_io_loop(self) -> Any: ... -class AgnosticCollection(AgnosticBaseProperties): +class AgnosticCollection(AgnosticBaseProperties[_DocumentType]): __motor_class_name__: str - __delegate_class__: type[Collection] + __delegate_class__: type[Collection[_DocumentType]] def __hash__(self) -> int: ... def __bool__(self) -> bool: ... @@ -495,6 +526,7 @@ class AgnosticCollection(AgnosticBaseProperties): session: Optional[AgnosticClientSession] = None, comment: Optional[Any] = None, ) -> InsertOneResult: ... + @property def name(self) -> str: ... async def options( self, session: Optional[AgnosticClientSession] = None, comment: Optional[Any] = None @@ -651,8 +683,10 @@ class AgnosticBaseCursor(AgnosticBase): def next_object(self) -> Any: ... def each(self, callback: Callable) -> None: ... def _each_got_more(self, callback: Callable, future: Any) -> None: ... - def to_list(self, length: int) -> Future[list]: ... - def _to_list(self, length: int, the_list: list, future: Any, get_more_result: Any) -> None: ... + def to_list(self, length: Optional[int]) -> Future[list]: ... + def _to_list( + self, length: Optional[int], the_list: list, future: Any, get_more_result: Any + ) -> None: ... def get_io_loop(self) -> Any: ... def batch_size(self, batch_size: int) -> AgnosticBaseCursor: ... def _buffer_size(self) -> int: ... diff --git a/motor/frameworks/asyncio/__init__.py b/motor/frameworks/asyncio/__init__.py index 1c7431d2..6db97a46 100644 --- a/motor/frameworks/asyncio/__init__.py +++ b/motor/frameworks/asyncio/__init__.py @@ -25,6 +25,8 @@ import warnings from asyncio import get_event_loop # noqa: F401 - For framework interface. from concurrent.futures import ThreadPoolExecutor +# mypy: ignore-errors + try: import contextvars except ImportError: diff --git a/motor/frameworks/tornado/__init__.py b/motor/frameworks/tornado/__init__.py index befb8908..5d399584 100644 --- a/motor/frameworks/tornado/__init__.py +++ b/motor/frameworks/tornado/__init__.py @@ -32,6 +32,7 @@ try: except ImportError: contextvars = None +# mypy: ignore-errors CLASS_PREFIX = "" diff --git a/motor/metaprogramming.py b/motor/metaprogramming.py index c304e516..d291826f 100644 --- a/motor/metaprogramming.py +++ b/motor/metaprogramming.py @@ -19,6 +19,8 @@ from typing import Any, Callable, Dict, TypeVar _class_cache: Dict[Any, Any] = {} +# mypy: ignore-errors + def asynchronize(framework, sync_method: Callable, doc=None, wrap_class=None, unwrap_class=None): """Decorate `sync_method` so it returns a Future. diff --git a/motor/web.py b/motor/web.py index 3626e8cb..37e5dd46 100644 --- a/motor/web.py +++ b/motor/web.py @@ -24,6 +24,8 @@ import tornado.web import motor from motor.motor_gridfs import _hash_gridout +# mypy: disable-error-code="no-untyped-def,no-untyped-call" + # TODO: this class is not a drop-in replacement for StaticFileHandler. # StaticFileHandler provides class method make_static_url, which appends # an checksum of the static file's contents. Templates thus can do diff --git a/test/__init__.py b/test/__init__.py index ab8b35e9..fd343180 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -15,6 +15,7 @@ """Test Motor, an asynchronous driver for MongoDB and Tornado.""" from test.test_environment import CLIENT_PEM, db_user, env # noqa: F401 +from typing import Any from unittest import SkipTest # noqa: F401 try: @@ -34,11 +35,11 @@ except ImportError: class MockRequestHandler: """For testing MotorGridOut.stream_to_handler.""" - def __init__(self): + def __init__(self) -> None: self.n_written = 0 - def write(self, data): + def write(self, data: Any) -> None: self.n_written += len(data) - def flush(self): + def flush(self) -> None: pass diff --git a/test/assert_logs_backport.py b/test/assert_logs_backport.py index c3e12312..b5474d1d 100644 --- a/test/assert_logs_backport.py +++ b/test/assert_logs_backport.py @@ -3,6 +3,8 @@ import collections import logging +# mypy: ignore-errors + _LoggingWatcher = collections.namedtuple("_LoggingWatcher", ["records", "output"]) diff --git a/test/asyncio_tests/__init__.py b/test/asyncio_tests/__init__.py index 62ab8a75..a1ee767f 100644 --- a/test/asyncio_tests/__init__.py +++ b/test/asyncio_tests/__init__.py @@ -28,6 +28,9 @@ from unittest import SkipTest from mockupdb import MockupDB from motor import motor_asyncio +from motor.core import AgnosticClient + +# mypy: ignore-errors class _TestMethodWrapper: @@ -104,7 +107,7 @@ class AsyncIOTestCase(AssertLogsMixin, unittest.TestCase): kwargs.setdefault("tls", env.mongod_started_with_ssl) return kwargs - def asyncio_client(self, uri=None, *args, set_loop=True, **kwargs): + def asyncio_client(self, uri=None, *args, set_loop=True, **kwargs) -> AgnosticClient: """Get an AsyncIOMotorClient. Ignores self.ssl, you must pass 'ssl' argument. diff --git a/test/test_environment.py b/test/test_environment.py index 95025b58..9bdbd1c7 100644 --- a/test/test_environment.py +++ b/test/test_environment.py @@ -24,6 +24,9 @@ from unittest import SkipTest import pymongo.errors +# mypy: ignore-errors + + HAVE_SSL = True try: import ssl diff --git a/test/test_typing.py b/test/test_typing.py index b32a121f..bb8ebc33 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -17,7 +17,7 @@ sample client code that uses Motor typings. """ import unittest from test.asyncio_tests import AsyncIOTestCase, asyncio_test -from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, TypeVar, Union, cast from bson import CodecOptions from bson.raw_bson import RawBSONDocument @@ -65,7 +65,7 @@ def only_type_check(func: FuncT) -> FuncT: return cast(FuncT, inner) -class TestMotor(AsyncIOTestCase): # type:ignore[misc] +class TestMotor(AsyncIOTestCase): cx: AgnosticClient @asyncio_test # type:ignore[misc] @@ -89,6 +89,20 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc] docs = await cursor.to_list(None) self.assertTrue(docs) + @asyncio_test # type:ignore[misc] + async def test_get_collection(self) -> None: + coll = self.db.get_collection("test_collection") + self.assertEqual(coll.name, "test_collection") + + @asyncio_test # type:ignore[misc] + async def test_find_one(self) -> None: + c: AgnosticClient[Movie] = self.asyncio_client() + coll = c[self.db.name]["movies"] + await coll.insert_one(Movie(name="American Graffiti", year=1973)) + result = await coll.find_one({}) + assert result is not None + self.assertEqual(result["year"], 1973) + @only_type_check @asyncio_test # type:ignore[misc] async def test_bulk_write(self) -> None: @@ -123,7 +137,7 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc] InsertOne(Movie(name="American Graffiti", year=1973)), ReplaceOne( {}, - {"name": "American Graffiti", "year": "WRONG_TYPE"}, + {"name": "American Graffiti", "year": "WRONG_TYPE"}, # type:ignore[typeddict-item] ), DeleteOne({}), ] @@ -138,13 +152,13 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc] @asyncio_test # type:ignore[misc] async def test_list_collections(self) -> None: cursor = await self.cx.test.list_collections() - value = await cursor.next() + value: Mapping[str, Any] = await cursor.next() value.items() @asyncio_test # type:ignore[misc] async def test_list_databases(self) -> None: cursor = await self.cx.list_databases() - value = await cursor.next() + value: Mapping[str, Any] = await cursor.next() value.items() @asyncio_test # type:ignore[misc] @@ -193,14 +207,14 @@ class TestMotor(AsyncIOTestCase): # type:ignore[misc] ) -class TestDocumentType(AsyncIOTestCase): # type:ignore[misc] +class TestDocumentType(AsyncIOTestCase): @only_type_check def test_typeddict_explicit_document_type(self) -> None: out = MovieWithId(_id=ObjectId(), name="THX-1138", year=1971) assert out is not None # This should fail because the output is a Movie. assert out["foo"] # type:ignore[typeddict-item] - assert out["_id"] + assert bool(out["_id"]) # This should work the same as the test above, but this time using NotRequired to allow # automatic insertion of the _id field by insert_one. @@ -211,7 +225,7 @@ class TestDocumentType(AsyncIOTestCase): # type:ignore[misc] # This should fail because the output is a Movie. assert out["foo"] # type:ignore[typeddict-item] # pyright gives reportTypedDictNotRequiredAccess for the following: - assert out["_id"] + assert bool(out["_id"]) @only_type_check def test_typeddict_empty_document_type(self) -> None: @@ -223,7 +237,7 @@ class TestDocumentType(AsyncIOTestCase): # type:ignore[misc] assert out["_id"] # type:ignore[typeddict-item] -class TestCommandDocumentType(AsyncIOTestCase): # type:ignore[misc] +class TestCommandDocumentType(AsyncIOTestCase): @only_type_check async def test_default(self) -> None: client: AgnosticClient = AgnosticClient() @@ -249,14 +263,14 @@ class TestCommandDocumentType(AsyncIOTestCase): # type:ignore[misc] async def test_raw_bson_document_type(self) -> None: client: AgnosticClient = AgnosticClient() codec_options = CodecOptions(RawBSONDocument) - result: RawBSONDocument = await client.admin.command( + result = await client.admin.command( "ping", codec_options=codec_options ) # Fix once @overload for command works assert len(result.raw) > 0 @only_type_check async def test_son_document_type(self) -> None: - client = AgnosticClient(document_class=SON[str, Any]) + client: AgnosticClient[SON[str, Any]] = AgnosticClient(document_class=SON[str, Any]) codec_options = CodecOptions(SON[str, Any]) result = await client.admin.command("ping", codec_options=codec_options) result["a"] = 1 diff --git a/test/utils.py b/test/utils.py index b17df736..bf6fe166 100644 --- a/test/utils.py +++ b/test/utils.py @@ -25,6 +25,8 @@ import os import time import warnings +# mypy: ignore-errors + def one(s): """Get one element of a set""" diff --git a/test/version.py b/test/version.py index b3ee2035..64bffad8 100644 --- a/test/version.py +++ b/test/version.py @@ -14,6 +14,8 @@ """Some tools for running tests based on MongoDB server version.""" +# mypy: ignore-errors + class Version(tuple): """Copied from PyMongo's test.version submodule.""" diff --git a/tox.ini b/tox.ini index 51baff44..38d60f2d 100644 --- a/tox.ini +++ b/tox.ini @@ -169,6 +169,6 @@ setenv = # TODO: Remove as part of MOTOR-1163 AIOHTTP_NO_EXTENSIONS=1 commands = - mypy --install-types --non-interactive motor/ --disallow-untyped-defs --follow-imports=skip --exclude 'motor/aiohttp/*' --exclude 'motor/frameworks/*' --exclude 'motor/metaprogramming.py' --exclude 'motor/web.py' - mypy --install-types --non-interactive --follow-imports=skip test/test_typing.py + mypy --install-types --non-interactive motor + mypy --install-types --non-interactive test/test_typing.py pytest test/test_mypy_fails.py