PYTHON-3073 Copy the unit tests from pymongo-stubs into pymongo (#859)

This commit is contained in:
Steven Silvester 2022-02-10 18:54:46 -06:00 committed by GitHub
parent ddb6614442
commit cbc7cc33e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 1 deletions

View File

@ -37,6 +37,7 @@ jobs:
mongodb-version: 4.4
- name: Run tests
run: |
pip install mypy
python setup.py test
mypytest:
@ -59,4 +60,4 @@ jobs:
- name: Run mypy
run: |
mypy --install-types --non-interactive bson gridfs tools pymongo
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --exclude "test/mypy_fails/*.*" test

View File

@ -0,0 +1,6 @@
from pymongo import MongoClient
client = MongoClient()
client.test.test.insert_many(
{"a": 1}
) # error: Dict entry 0 has incompatible type "str": "int"; expected "Mapping[str, Any]": "int"

View File

@ -0,0 +1,6 @@
from pymongo import MongoClient
client = MongoClient()
client.test.test.insert_one(
[{}]
) # error: Argument 1 to "insert_one" of "Collection" has incompatible type "List[Dict[<nothing>, <nothing>]]"; expected "Mapping[str, Any]"

View File

@ -1117,6 +1117,16 @@ class TestCodecOptions(unittest.TestCase):
)
self.round_trip_pickle(i64, pickled_with_3)
def test_bson_encode_decode(self) -> None:
doc = {"_id": ObjectId()}
encoded = bson.encode(doc)
decoded = bson.decode(encoded)
encoded = bson.encode(decoded)
decoded = bson.decode(encoded)
# Documents returned from decode are mutable.
decoded["new_field"] = 1
self.assertTrue(decoded["_id"].generation_time)
if __name__ == "__main__":
unittest.main()

125
test/test_mypy.py Normal file
View File

@ -0,0 +1,125 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that each file in mypy_fails/ actually fails mypy, and test some
sample client code that uses PyMongo typings."""
import os
import sys
import unittest
from typing import Any, Dict, Iterable, List
try:
from mypy import api
except ImportError:
api = None
from bson.son import SON
from pymongo.collection import Collection
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.mongo_client import MongoClient
from pymongo.operations import InsertOne
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails")
def get_tests() -> Iterable[str]:
for dirpath, _, filenames in os.walk(TEST_PATH):
for filename in filenames:
yield os.path.join(dirpath, filename)
class TestMypyFails(unittest.TestCase):
def ensure_mypy_fails(self, filename: str) -> None:
if api is None:
raise unittest.SkipTest("Mypy is not installed")
stdout, stderr, exit_status = api.run([filename])
self.assertTrue(exit_status, msg=stdout)
def test_mypy_failures(self) -> None:
for filename in get_tests():
with self.subTest(filename=filename):
self.ensure_mypy_fails(filename)
class TestPymongo(unittest.TestCase):
client: MongoClient
coll: Collection
@classmethod
def setUpClass(cls) -> None:
cls.client = MongoClient(serverSelectionTimeoutMS=250, directConnection=False)
cls.coll = cls.client.test.test
try:
cls.client.admin.command("ping")
except ServerSelectionTimeoutError as exc:
raise unittest.SkipTest(f"Could not connect to MongoDB: {exc}")
@classmethod
def tearDownClass(cls) -> None:
cls.client.close()
def test_insert_find(self) -> None:
doc = {"my": "doc"}
coll2 = self.client.test.test2
result = self.coll.insert_one(doc)
self.assertEqual(result.inserted_id, doc["_id"])
retreived = self.coll.find_one({"_id": doc["_id"]})
if retreived:
# Documents returned from find are mutable.
retreived["new_field"] = 1
result2 = coll2.insert_one(retreived)
self.assertEqual(result2.inserted_id, result.inserted_id)
def test_cursor_iterable(self) -> None:
def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
return list(iterable)
self.coll.insert_one({})
cursor = self.coll.find()
docs = to_list(cursor)
self.assertTrue(docs)
def test_bulk_write(self) -> None:
self.coll.insert_one({})
requests = [InsertOne({})]
result = self.coll.bulk_write(requests)
self.assertTrue(result.acknowledged)
def test_aggregate_pipeline(self) -> None:
coll3 = self.client.test.test3
coll3.insert_many(
[
{"x": 1, "tags": ["dog", "cat"]},
{"x": 2, "tags": ["cat"]},
{"x": 2, "tags": ["mouse", "cat", "dog"]},
{"x": 3, "tags": []},
]
)
class mydict(Dict[str, Any]):
pass
result = coll3.aggregate(
[
mydict({"$unwind": "$tags"}),
{"$group": {"_id": "$tags", "count": {"$sum": 1}}},
{"$sort": SON([("count", -1), ("_id", -1)])},
]
)
self.assertTrue(len(list(result)))
if __name__ == "__main__":
unittest.main()