mongo-python-driver/test/test_cursor.py
Bernie Hackett 1ac607c447 PYTHON-314, PYTHON-744 - Hint by index name, count with hint.
This change introduces two closely related features. Cursor.hint
now accepts the name of an index as an alternative to passing
the index spec. Cursor.count will now pass the hint, if one was
specified, to the count command. Count with hint is only supported
by MongoDB 2.6 and newer.
2014-09-12 13:50:04 -07:00

1170 lines
41 KiB
Python

# Copyright 2009-2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the cursor module."""
import copy
import itertools
import random
import re
import sys
import unittest
import warnings
sys.path[0:0] = [""]
from nose.plugins.skip import SkipTest
from bson.code import Code
from bson.son import SON
from pymongo import (ASCENDING,
DESCENDING,
ALL,
OFF)
from pymongo.command_cursor import CommandCursor
from pymongo.cursor_manager import CursorManager
from pymongo.database import Database
from pymongo.errors import (InvalidOperation,
OperationFailure,
ExecutionTimeout)
from test import version, skip_restricted_localhost
from test.test_client import get_client
from test.utils import (catch_warnings, is_mongos,
get_command_line, server_started_with_auth)
# unittest/nose call a module-level setUpModule() once before running any
# test in this module; reuse the shared helper from the ``test`` package.
# NOTE(review): judging by its name it skips this module on restricted
# localhost setups -- confirm against test/__init__.py.
setUpModule = skip_restricted_localhost
class TestCursor(unittest.TestCase):
def setUp(self):
self.client = get_client()
self.db = Database(self.client, "pymongo_test")
def tearDown(self):
self.db = None
def test_max_time_ms(self):
if not version.at_least(self.db.connection, (2, 5, 3, -1)):
raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")
db = self.db
db.pymongo_test.drop()
coll = db.pymongo_test
self.assertRaises(TypeError, coll.find().max_time_ms, 'foo')
coll.insert({"amalia": 1})
coll.insert({"amalia": 2})
coll.find().max_time_ms(None)
coll.find().max_time_ms(1L)
cursor = coll.find().max_time_ms(999)
self.assertEqual(999, cursor._Cursor__max_time_ms)
cursor = coll.find().max_time_ms(10).max_time_ms(1000)
self.assertEqual(1000, cursor._Cursor__max_time_ms)
cursor = coll.find().max_time_ms(999)
c2 = cursor.clone()
self.assertEqual(999, c2._Cursor__max_time_ms)
self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec())
self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec())
self.assertTrue(coll.find_one(max_time_ms=1000))
if "enableTestCommands=1" in get_command_line(self.client)["argv"]:
# Cursor parses server timeout error in response to initial query.
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
try:
cursor = coll.find().max_time_ms(1)
try:
cursor.next()
except ExecutionTimeout:
pass
else:
self.fail("ExecutionTimeout not raised")
self.assertRaises(ExecutionTimeout,
coll.find_one, max_time_ms=1)
finally:
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
def test_max_time_ms_getmore(self):
# Test that Cursor handles server timeout error in response to getmore.
if "enableTestCommands=1" not in get_command_line(self.client)["argv"]:
raise SkipTest("Need test commands enabled")
if not version.at_least(self.db.connection, (2, 5, 3, -1)):
raise SkipTest("MaxTimeMS requires MongoDB >= 2.5.3")
coll = self.db.pymongo_test
coll.insert({} for _ in range(200))
cursor = coll.find().max_time_ms(100)
# Send initial query before turning on failpoint.
cursor.next()
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="alwaysOn")
try:
try:
# Iterate up to first getmore.
list(cursor)
except ExecutionTimeout:
pass
else:
self.fail("ExecutionTimeout not raised")
finally:
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",
mode="off")
def test_explain(self):
a = self.db.test.find()
a.explain()
for _ in a:
break
b = a.explain()
# "cursor" pre MongoDB 2.7.6, "executionStats" post
self.assertTrue("cursor" in b or "executionStats" in b)
def test_hint(self):
db = self.db
self.assertRaises(TypeError, db.test.find().hint, 5.5)
db.test.drop()
for i in range(100):
db.test.insert({"num": i, "foo": i})
self.assertRaises(OperationFailure,
db.test.find({"num": 17, "foo": 17})
.hint([("num", ASCENDING)]).explain)
self.assertRaises(OperationFailure,
db.test.find({"num": 17, "foo": 17})
.hint([("foo", ASCENDING)]).explain)
spec = [("num", DESCENDING)]
index = db.test.create_index(spec)
first = db.test.find().next()
self.assertEqual(0, first.get('num'))
first = db.test.find().hint(spec).next()
self.assertEqual(99, first.get('num'))
self.assertRaises(OperationFailure,
db.test.find({"num": 17, "foo": 17})
.hint([("foo", ASCENDING)]).explain)
a = db.test.find({"num": 17})
a.hint(spec)
for _ in a:
break
self.assertRaises(InvalidOperation, a.hint, spec)
def test_hint_by_name(self):
db = self.db
db.test.drop()
for i in range(100):
db.test.insert({'i': i})
db.test.create_index([('i', DESCENDING)], name='fooindex')
first = db.test.find().next()
self.assertEqual(0, first.get('i'))
first = db.test.find().hint('fooindex').next()
self.assertEqual(99, first.get('i'))
def test_limit(self):
db = self.db
self.assertRaises(TypeError, db.test.find().limit, None)
self.assertRaises(TypeError, db.test.find().limit, "hello")
self.assertRaises(TypeError, db.test.find().limit, 5.5)
self.assertTrue(db.test.find().limit(5L))
db.test.drop()
for i in range(100):
db.test.save({"x": i})
count = 0
for _ in db.test.find():
count += 1
self.assertEqual(count, 100)
count = 0
for _ in db.test.find().limit(20):
count += 1
self.assertEqual(count, 20)
count = 0
for _ in db.test.find().limit(99):
count += 1
self.assertEqual(count, 99)
count = 0
for _ in db.test.find().limit(1):
count += 1
self.assertEqual(count, 1)
count = 0
for _ in db.test.find().limit(0):
count += 1
self.assertEqual(count, 100)
count = 0
for _ in db.test.find().limit(0).limit(50).limit(10):
count += 1
self.assertEqual(count, 10)
a = db.test.find()
a.limit(10)
for _ in a:
break
self.assertRaises(InvalidOperation, a.limit, 5)
def test_max(self):
db = self.db
db.test.drop()
db.test.ensure_index([("j", ASCENDING)])
for j in range(10):
db.test.insert({"j": j, "k": j})
cursor = db.test.find().max([("j", 3)])
self.assertEqual(len(list(cursor)), 3)
# Tuple.
cursor = db.test.find().max((("j", 3), ))
self.assertEqual(len(list(cursor)), 3)
# Compound index.
db.test.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
cursor = db.test.find().max([("j", 3), ("k", 3)])
self.assertEqual(len(list(cursor)), 3)
# Wrong order.
cursor = db.test.find().max([("k", 3), ("j", 3)])
self.assertRaises(OperationFailure, list, cursor)
# No such index.
cursor = db.test.find().max([("k", 3)])
self.assertRaises(OperationFailure, list, cursor)
self.assertRaises(TypeError, db.test.find().max, 10)
self.assertRaises(TypeError, db.test.find().max, {"j": 10})
def test_min(self):
db = self.db
db.test.drop()
db.test.ensure_index([("j", ASCENDING)])
for j in range(10):
db.test.insert({"j": j, "k": j})
cursor = db.test.find().min([("j", 3)])
self.assertEqual(len(list(cursor)), 7)
# Tuple.
cursor = db.test.find().min((("j", 3), ))
self.assertEqual(len(list(cursor)), 7)
# Compound index.
db.test.ensure_index([("j", ASCENDING), ("k", ASCENDING)])
cursor = db.test.find().min([("j", 3), ("k", 3)])
self.assertEqual(len(list(cursor)), 7)
# Wrong order.
cursor = db.test.find().min([("k", 3), ("j", 3)])
self.assertRaises(OperationFailure, list, cursor)
# No such index.
cursor = db.test.find().min([("k", 3)])
self.assertRaises(OperationFailure, list, cursor)
self.assertRaises(TypeError, db.test.find().min, 10)
self.assertRaises(TypeError, db.test.find().min, {"j": 10})
def test_batch_size(self):
db = self.db
db.test.drop()
for x in range(200):
db.test.save({"x": x})
self.assertRaises(TypeError, db.test.find().batch_size, None)
self.assertRaises(TypeError, db.test.find().batch_size, "hello")
self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
self.assertRaises(ValueError, db.test.find().batch_size, -1)
self.assertTrue(db.test.find().batch_size(5L))
a = db.test.find()
for _ in a:
break
self.assertRaises(InvalidOperation, a.batch_size, 5)
def cursor_count(cursor, expected_count):
count = 0
for _ in cursor:
count += 1
self.assertEqual(expected_count, count)
cursor_count(db.test.find().batch_size(0), 200)
cursor_count(db.test.find().batch_size(1), 200)
cursor_count(db.test.find().batch_size(2), 200)
cursor_count(db.test.find().batch_size(5), 200)
cursor_count(db.test.find().batch_size(100), 200)
cursor_count(db.test.find().batch_size(500), 200)
cursor_count(db.test.find().batch_size(0).limit(1), 1)
cursor_count(db.test.find().batch_size(1).limit(1), 1)
cursor_count(db.test.find().batch_size(2).limit(1), 1)
cursor_count(db.test.find().batch_size(5).limit(1), 1)
cursor_count(db.test.find().batch_size(100).limit(1), 1)
cursor_count(db.test.find().batch_size(500).limit(1), 1)
cursor_count(db.test.find().batch_size(0).limit(10), 10)
cursor_count(db.test.find().batch_size(1).limit(10), 10)
cursor_count(db.test.find().batch_size(2).limit(10), 10)
cursor_count(db.test.find().batch_size(5).limit(10), 10)
cursor_count(db.test.find().batch_size(100).limit(10), 10)
cursor_count(db.test.find().batch_size(500).limit(10), 10)
def test_limit_and_batch_size(self):
db = self.db
db.test.drop()
for x in range(500):
db.test.save({"x": x})
curs = db.test.find().limit(0).batch_size(10)
curs.next()
self.assertEqual(10, curs._Cursor__retrieved)
curs = db.test.find().limit(-2).batch_size(0)
curs.next()
self.assertEqual(2, curs._Cursor__retrieved)
curs = db.test.find().limit(-4).batch_size(5)
curs.next()
self.assertEqual(4, curs._Cursor__retrieved)
curs = db.test.find().limit(50).batch_size(500)
curs.next()
self.assertEqual(50, curs._Cursor__retrieved)
curs = db.test.find().batch_size(500)
curs.next()
self.assertEqual(500, curs._Cursor__retrieved)
curs = db.test.find().limit(50)
curs.next()
self.assertEqual(50, curs._Cursor__retrieved)
# these two might be shaky, as the default
# is set by the server. as of 2.0.0-rc0, 101
# or 1MB (whichever is smaller) is default
# for queries without ntoreturn
curs = db.test.find()
curs.next()
self.assertEqual(101, curs._Cursor__retrieved)
curs = db.test.find().limit(0).batch_size(0)
curs.next()
self.assertEqual(101, curs._Cursor__retrieved)
def test_skip(self):
db = self.db
self.assertRaises(TypeError, db.test.find().skip, None)
self.assertRaises(TypeError, db.test.find().skip, "hello")
self.assertRaises(TypeError, db.test.find().skip, 5.5)
self.assertRaises(ValueError, db.test.find().skip, -5)
self.assertTrue(db.test.find().skip(5L))
db.drop_collection("test")
for i in range(100):
db.test.save({"x": i})
for i in db.test.find():
self.assertEqual(i["x"], 0)
break
for i in db.test.find().skip(20):
self.assertEqual(i["x"], 20)
break
for i in db.test.find().skip(99):
self.assertEqual(i["x"], 99)
break
for i in db.test.find().skip(1):
self.assertEqual(i["x"], 1)
break
for i in db.test.find().skip(0):
self.assertEqual(i["x"], 0)
break
for i in db.test.find().skip(0).skip(50).skip(10):
self.assertEqual(i["x"], 10)
break
for i in db.test.find().skip(1000):
self.fail()
a = db.test.find()
a.skip(10)
for _ in a:
break
self.assertRaises(InvalidOperation, a.skip, 5)
def test_sort(self):
db = self.db
self.assertRaises(TypeError, db.test.find().sort, 5)
self.assertRaises(ValueError, db.test.find().sort, [])
self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING)
self.assertRaises(TypeError, db.test.find().sort,
[("hello", DESCENDING)], DESCENDING)
db.test.drop()
unsort = range(10)
random.shuffle(unsort)
for i in unsort:
db.test.save({"x": i})
asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)]
self.assertEqual(asc, range(10))
asc = [i["x"] for i in db.test.find().sort("x")]
self.assertEqual(asc, range(10))
asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])]
self.assertEqual(asc, range(10))
expect = range(10)
expect.reverse()
desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)]
self.assertEqual(desc, expect)
desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])]
self.assertEqual(desc, expect)
desc = [i["x"] for i in
db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)]
self.assertEqual(desc, expect)
expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
shuffled = list(expected)
random.shuffle(shuffled)
db.test.drop()
for (a, b) in shuffled:
db.test.save({"a": a, "b": b})
result = [(i["a"], i["b"]) for i in
db.test.find().sort([("b", DESCENDING),
("a", ASCENDING)])]
self.assertEqual(result, expected)
a = db.test.find()
a.sort("x", ASCENDING)
for _ in a:
break
self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING)
def test_count(self):
db = self.db
db.test.drop()
self.assertEqual(0, db.test.find().count())
for i in range(10):
db.test.save({"x": i})
self.assertEqual(10, db.test.find().count())
self.assertTrue(isinstance(db.test.find().count(), int))
self.assertEqual(10, db.test.find().limit(5).count())
self.assertEqual(10, db.test.find().skip(5).count())
self.assertEqual(1, db.test.find({"x": 1}).count())
self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())
a = db.test.find()
b = a.count()
for _ in a:
break
self.assertEqual(b, a.count())
self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())
def test_count_with_hint(self):
collection = self.db.test
collection.drop()
collection.save({'i': 1})
collection.save({'i': 2})
self.assertEqual(2, collection.find().count())
collection.create_index([('i', 1)])
self.assertEqual(1, collection.find({'i': 1}).hint("_id_").count())
self.assertEqual(2, collection.find().hint("_id_").count())
if version.at_least(self.client, (2, 6, 0)):
# Count supports hint
self.assertRaises(OperationFailure,
collection.find({'i': 1}).hint("BAD HINT").count)
else:
# Hint is ignored
self.assertEqual(
1, collection.find({'i': 1}).hint("BAD HINT").count())
# Create a sparse index which should have no entries.
collection.create_index([('x', 1)], sparse=True)
if version.at_least(self.client, (2, 6, 0)):
# Count supports hint
self.assertEqual(0, collection.find({'i': 1}).hint("x_1").count())
else:
# Hint is ignored
self.assertEqual(1, collection.find({'i': 1}).hint("x_1").count())
self.assertEqual(2, collection.find().hint("x_1").count())
def test_where(self):
db = self.db
db.test.drop()
a = db.test.find()
self.assertRaises(TypeError, a.where, 5)
self.assertRaises(TypeError, a.where, None)
self.assertRaises(TypeError, a.where, {})
for i in range(10):
db.test.save({"x": i})
self.assertEqual(3, len(list(db.test.find().where('this.x < 3'))))
self.assertEqual(3,
len(list(db.test.find().where(Code('this.x < 3')))))
self.assertEqual(3, len(list(db.test.find().where(Code('this.x < i',
{"i": 3})))))
self.assertEqual(10, len(list(db.test.find())))
self.assertEqual(3, db.test.find().where('this.x < 3').count())
self.assertEqual(10, db.test.find().count())
self.assertEqual(3, db.test.find().where(u'this.x < 3').count())
self.assertEqual([0, 1, 2],
[a["x"] for a in
db.test.find().where('this.x < 3')])
self.assertEqual([],
[a["x"] for a in
db.test.find({"x": 5}).where('this.x < 3')])
self.assertEqual([5],
[a["x"] for a in
db.test.find({"x": 5}).where('this.x > 3')])
cursor = db.test.find().where('this.x < 3').where('this.x > 7')
self.assertEqual([8, 9], [a["x"] for a in cursor])
a = db.test.find()
b = a.where('this.x > 3')
for _ in a:
break
self.assertRaises(InvalidOperation, a.where, 'this.x < 3')
def test_rewind(self):
self.db.test.save({"x": 1})
self.db.test.save({"x": 2})
self.db.test.save({"x": 3})
cursor = self.db.test.find().limit(2)
count = 0
for _ in cursor:
count += 1
self.assertEqual(2, count)
count = 0
for _ in cursor:
count += 1
self.assertEqual(0, count)
cursor.rewind()
count = 0
for _ in cursor:
count += 1
self.assertEqual(2, count)
cursor.rewind()
count = 0
for _ in cursor:
break
cursor.rewind()
for _ in cursor:
count += 1
self.assertEqual(2, count)
self.assertEqual(cursor, cursor.rewind())
def test_clone(self):
self.db.test.save({"x": 1})
self.db.test.save({"x": 2})
self.db.test.save({"x": 3})
cursor = self.db.test.find().limit(2)
count = 0
for _ in cursor:
count += 1
self.assertEqual(2, count)
count = 0
for _ in cursor:
count += 1
self.assertEqual(0, count)
cursor = cursor.clone()
cursor2 = cursor.clone()
count = 0
for _ in cursor:
count += 1
self.assertEqual(2, count)
for _ in cursor2:
count += 1
self.assertEqual(4, count)
cursor.rewind()
count = 0
for _ in cursor:
break
cursor = cursor.clone()
for _ in cursor:
count += 1
self.assertEqual(2, count)
self.assertNotEqual(cursor, cursor.clone())
class MyClass(dict):
pass
cursor = self.db.test.find(as_class=MyClass)
for e in cursor:
self.assertEqual(type(MyClass()), type(e))
cursor = self.db.test.find(as_class=MyClass)
self.assertEqual(type(MyClass()), type(cursor[0]))
# Just test attributes
cursor = self.db.test.find({"x": re.compile("^hello.*")},
skip=1,
timeout=False,
snapshot=True,
tailable=True,
as_class=MyClass,
slave_okay=True,
await_data=True,
partial=True,
manipulate=False,
compile_re=False,
fields={'_id': False}).limit(2)
cursor.min([('a', 1)]).max([('b', 3)])
cursor.add_option(128)
cursor.comment('hi!')
cursor2 = cursor.clone()
self.assertEqual(cursor._Cursor__skip, cursor2._Cursor__skip)
self.assertEqual(cursor._Cursor__limit, cursor2._Cursor__limit)
self.assertEqual(cursor._Cursor__snapshot, cursor2._Cursor__snapshot)
self.assertEqual(type(cursor._Cursor__as_class),
type(cursor2._Cursor__as_class))
self.assertEqual(cursor._Cursor__slave_okay,
cursor2._Cursor__slave_okay)
self.assertEqual(cursor._Cursor__manipulate,
cursor2._Cursor__manipulate)
self.assertEqual(cursor._Cursor__compile_re,
cursor2._Cursor__compile_re)
self.assertEqual(cursor._Cursor__query_flags,
cursor2._Cursor__query_flags)
self.assertEqual(cursor._Cursor__comment,
cursor2._Cursor__comment)
self.assertEqual(cursor._Cursor__min,
cursor2._Cursor__min)
self.assertEqual(cursor._Cursor__max,
cursor2._Cursor__max)
# Shallow copies can so can mutate
cursor2 = copy.copy(cursor)
cursor2._Cursor__fields['cursor2'] = False
self.assertTrue('cursor2' in cursor._Cursor__fields)
# Deepcopies and shouldn't mutate
cursor3 = copy.deepcopy(cursor)
cursor3._Cursor__fields['cursor3'] = False
self.assertFalse('cursor3' in cursor._Cursor__fields)
cursor4 = cursor.clone()
cursor4._Cursor__fields['cursor4'] = False
self.assertFalse('cursor4' in cursor._Cursor__fields)
# Test memo when deepcopying queries
query = {"hello": "world"}
query["reflexive"] = query
cursor = self.db.test.find(query)
cursor2 = copy.deepcopy(cursor)
self.assertNotEqual(id(cursor._Cursor__spec),
id(cursor2._Cursor__spec))
self.assertEqual(id(cursor2._Cursor__spec['reflexive']),
id(cursor2._Cursor__spec))
self.assertEqual(len(cursor2._Cursor__spec), 2)
# Ensure hints are cloned as the correct type
cursor = self.db.test.find().hint([('z', 1), ("a", 1)])
cursor2 = copy.deepcopy(cursor)
self.assertTrue(isinstance(cursor2._Cursor__hint, SON))
self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint)
def test_deepcopy_cursor_littered_with_regexes(self):
cursor = self.db.test.find({"x": re.compile("^hmmm.*"),
"y": [re.compile("^hmm.*")],
"z": {"a": [re.compile("^hm.*")]},
re.compile("^key.*"): {"a": [re.compile("^hm.*")]}})
cursor2 = copy.deepcopy(cursor)
self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec)
def test_add_remove_option(self):
cursor = self.db.test.find()
self.assertEqual(0, cursor._Cursor__query_options())
cursor.add_option(2)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_options())
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.add_option(32)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_options())
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.add_option(128)
cursor2 = self.db.test.find(tailable=True,
await_data=True).add_option(128)
self.assertEqual(162, cursor2._Cursor__query_options())
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
self.assertEqual(162, cursor._Cursor__query_options())
cursor.add_option(128)
self.assertEqual(162, cursor._Cursor__query_options())
cursor.remove_option(128)
cursor2 = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor2._Cursor__query_options())
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.remove_option(32)
cursor2 = self.db.test.find(tailable=True)
self.assertEqual(2, cursor2._Cursor__query_options())
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
self.assertEqual(2, cursor._Cursor__query_options())
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_options())
# Slave OK
cursor = self.db.test.find(slave_okay=True)
self.assertEqual(4, cursor._Cursor__query_options())
cursor2 = self.db.test.find().add_option(4)
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
self.assertTrue(cursor._Cursor__slave_okay)
cursor.remove_option(4)
self.assertEqual(0, cursor._Cursor__query_options())
self.assertFalse(cursor._Cursor__slave_okay)
# Timeout
cursor = self.db.test.find(timeout=False)
self.assertEqual(16, cursor._Cursor__query_options())
cursor2 = self.db.test.find().add_option(16)
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.remove_option(16)
self.assertEqual(0, cursor._Cursor__query_options())
# Tailable / Await data
cursor = self.db.test.find(tailable=True, await_data=True)
self.assertEqual(34, cursor._Cursor__query_options())
cursor2 = self.db.test.find().add_option(34)
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.remove_option(32)
self.assertEqual(2, cursor._Cursor__query_options())
# Exhaust - which mongos doesn't support
if not is_mongos(self.db.connection):
cursor = self.db.test.find(exhaust=True)
self.assertEqual(64, cursor._Cursor__query_options())
cursor2 = self.db.test.find().add_option(64)
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
self.assertTrue(cursor._Cursor__exhaust)
cursor.remove_option(64)
self.assertEqual(0, cursor._Cursor__query_options())
self.assertFalse(cursor._Cursor__exhaust)
# Partial
cursor = self.db.test.find(partial=True)
self.assertEqual(128, cursor._Cursor__query_options())
cursor2 = self.db.test.find().add_option(128)
self.assertEqual(cursor._Cursor__query_options(),
cursor2._Cursor__query_options())
cursor.remove_option(128)
self.assertEqual(0, cursor._Cursor__query_options())
def test_count_with_fields(self):
self.db.test.drop()
self.db.test.save({"x": 1})
if not version.at_least(self.db.connection, (1, 1, 3, -1)):
for _ in self.db.test.find({}, ["a"]):
self.fail()
self.assertEqual(0, self.db.test.find({}, ["a"]).count())
else:
self.assertEqual(1, self.db.test.find({}, ["a"]).count())
def test_bad_getitem(self):
self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello")
self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5)
self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None)
def test_getitem_slice_index(self):
self.db.drop_collection("test")
for i in range(100):
self.db.test.save({"i": i})
count = itertools.count
self.assertRaises(IndexError, lambda: self.db.test.find()[-1:])
self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2])
for a, b in zip(count(0), self.db.test.find()):
self.assertEqual(a, b['i'])
self.assertEqual(100, len(list(self.db.test.find()[0:])))
for a, b in zip(count(0), self.db.test.find()[0:]):
self.assertEqual(a, b['i'])
self.assertEqual(80, len(list(self.db.test.find()[20:])))
for a, b in zip(count(20), self.db.test.find()[20:]):
self.assertEqual(a, b['i'])
for a, b in zip(count(99), self.db.test.find()[99:]):
self.assertEqual(a, b['i'])
for i in self.db.test.find()[1000:]:
self.fail()
self.assertEqual(5, len(list(self.db.test.find()[20:25])))
self.assertEqual(5, len(list(self.db.test.find()[20L:25L])))
for a, b in zip(count(20), self.db.test.find()[20:25]):
self.assertEqual(a, b['i'])
self.assertEqual(80, len(list(self.db.test.find()[40:45][20:])))
for a, b in zip(count(20), self.db.test.find()[40:45][20:]):
self.assertEqual(a, b['i'])
self.assertEqual(80,
len(list(self.db.test.find()[40:45].limit(0).skip(20))
)
)
for a, b in zip(count(20),
self.db.test.find()[40:45].limit(0).skip(20)):
self.assertEqual(a, b['i'])
self.assertEqual(80,
len(list(self.db.test.find().limit(10).skip(40)[20:]))
)
for a, b in zip(count(20),
self.db.test.find().limit(10).skip(40)[20:]):
self.assertEqual(a, b['i'])
self.assertEqual(1, len(list(self.db.test.find()[:1])))
self.assertEqual(5, len(list(self.db.test.find()[:5])))
self.assertEqual(1, len(list(self.db.test.find()[99:100])))
self.assertEqual(1, len(list(self.db.test.find()[99:1000])))
self.assertEqual(0, len(list(self.db.test.find()[10:10])))
self.assertEqual(0, len(list(self.db.test.find()[:0])))
self.assertEqual(80,
len(list(self.db.test.find()[10:10].limit(0).skip(20))
)
)
self.assertRaises(IndexError, lambda: self.db.test.find()[10:8])
def test_getitem_numeric_index(self):
self.db.drop_collection("test")
for i in range(100):
self.db.test.save({"i": i})
self.assertEqual(0, self.db.test.find()[0]['i'])
self.assertEqual(50, self.db.test.find()[50]['i'])
self.assertEqual(50, self.db.test.find().skip(50)[0]['i'])
self.assertEqual(50, self.db.test.find().skip(49)[1]['i'])
self.assertEqual(50, self.db.test.find()[50L]['i'])
self.assertEqual(99, self.db.test.find()[99]['i'])
self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1)
self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100)
self.assertRaises(IndexError,
lambda x: self.db.test.find().skip(50)[x], 50)
def test_count_with_limit_and_skip(self):
if not version.at_least(self.db.connection, (1, 1, 4, -1)):
raise SkipTest("count with limit / skip requires MongoDB >= 1.1.4")
self.assertRaises(TypeError, self.db.test.find().count, "foo")
def check_len(cursor, length):
self.assertEqual(len(list(cursor)), cursor.count(True))
self.assertEqual(length, cursor.count(True))
self.db.drop_collection("test")
for i in range(100):
self.db.test.save({"i": i})
check_len(self.db.test.find(), 100)
check_len(self.db.test.find().limit(10), 10)
check_len(self.db.test.find().limit(110), 100)
check_len(self.db.test.find().skip(10), 90)
check_len(self.db.test.find().skip(110), 0)
check_len(self.db.test.find().limit(10).skip(10), 10)
check_len(self.db.test.find()[10:20], 10)
check_len(self.db.test.find().limit(10).skip(95), 5)
check_len(self.db.test.find()[95:105], 5)
def test_len(self):
self.assertRaises(TypeError, len, self.db.test.find())
def test_properties(self):
self.assertEqual(self.db.test, self.db.test.find().collection)
def set_coll():
self.db.test.find().collection = "hello"
self.assertRaises(AttributeError, set_coll)
def test_get_more(self):
db = self.db
db.drop_collection("test")
db.test.insert([{'i': i} for i in range(10)])
self.assertEqual(10, len(list(db.test.find().batch_size(5))))
def test_tailable(self):
db = self.db
db.drop_collection("test")
db.create_collection("test", capped=True, size=1000, max=3)
try:
cursor = db.test.find(tailable=True)
db.test.insert({"x": 1})
count = 0
for doc in cursor:
count += 1
self.assertEqual(1, doc["x"])
self.assertEqual(1, count)
db.test.insert({"x": 2})
count = 0
for doc in cursor:
count += 1
self.assertEqual(2, doc["x"])
self.assertEqual(1, count)
db.test.insert({"x": 3})
count = 0
for doc in cursor:
count += 1
self.assertEqual(3, doc["x"])
self.assertEqual(1, count)
# Capped rollover - the collection can never
# have more than 3 documents. Just make sure
# this doesn't raise...
db.test.insert(({"x": i} for i in xrange(4, 7)))
self.assertEqual(0, len(list(cursor)))
# and that the cursor doesn't think it's still alive.
self.assertFalse(cursor.alive)
self.assertEqual(3, db.test.count())
finally:
db.drop_collection("test")
def test_distinct(self):
if not version.at_least(self.db.connection, (1, 1, 3, 1)):
raise SkipTest("distinct with query requires MongoDB >= 1.1.3")
self.db.drop_collection("test")
self.db.test.save({"a": 1})
self.db.test.save({"a": 2})
self.db.test.save({"a": 2})
self.db.test.save({"a": 2})
self.db.test.save({"a": 3})
distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a")
distinct.sort()
self.assertEqual([1, 2], distinct)
self.db.drop_collection("test")
self.db.test.save({"a": {"b": "a"}, "c": 12})
self.db.test.save({"a": {"b": "b"}, "c": 8})
self.db.test.save({"a": {"b": "c"}, "c": 12})
self.db.test.save({"a": {"b": "c"}, "c": 8})
distinct = self.db.test.find({"c": 8}).distinct("a.b")
distinct.sort()
self.assertEqual(["b", "c"], distinct)
def test_max_scan(self):
if not version.at_least(self.db.connection, (1, 5, 1)):
raise SkipTest("maxScan requires MongoDB >= 1.5.1")
self.db.drop_collection("test")
for _ in range(100):
self.db.test.insert({})
self.assertEqual(100, len(list(self.db.test.find())))
self.assertEqual(50, len(list(self.db.test.find(max_scan=50))))
self.assertEqual(50, len(list(self.db.test.find()
.max_scan(90).max_scan(50))))
def test_with_statement(self):
if sys.version_info < (2, 6):
raise SkipTest("With statement requires Python >= 2.6")
self.db.drop_collection("test")
for _ in range(100):
self.db.test.insert({})
c1 = self.db.test.find()
exec """
with self.db.test.find() as c2:
self.assertTrue(c2.alive)
self.assertFalse(c2.alive)
with self.db.test.find() as c2:
self.assertEqual(100, len(list(c2)))
self.assertFalse(c2.alive)
"""
self.assertTrue(c1.alive)
def test_comment(self):
if is_mongos(self.client):
raise SkipTest("profile is not supported by mongos")
if not version.at_least(self.db.connection, (2, 0)):
raise SkipTest("Requires server >= 2.0")
if server_started_with_auth(self.db.connection):
raise SkipTest("SERVER-4754 - This test uses profiling.")
def run_with_profiling(func):
self.db.set_profiling_level(OFF)
self.db.system.profile.drop()
self.db.set_profiling_level(ALL)
func()
self.db.set_profiling_level(OFF)
def find():
list(self.db.test.find().comment('foo'))
op = self.db.system.profile.find({'ns': 'pymongo_test.test',
'op': 'query',
'query.$comment': 'foo'})
self.assertEqual(op.count(), 1)
run_with_profiling(find)
def count():
self.db.test.find().comment('foo').count()
op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
'op': 'command',
'command.count': 'test',
'command.$comment': 'foo'})
self.assertEqual(op.count(), 1)
run_with_profiling(count)
def distinct():
self.db.test.find().comment('foo').distinct('type')
op = self.db.system.profile.find({'ns': 'pymongo_test.$cmd',
'op': 'command',
'command.distinct': 'test',
'command.$comment': 'foo'})
self.assertEqual(op.count(), 1)
run_with_profiling(distinct)
self.db.test.insert([{}, {}])
cursor = self.db.test.find()
cursor.next()
self.assertRaises(InvalidOperation, cursor.comment, 'hello')
self.db.system.profile.drop()
def test_cursor_transfer(self):
# This is just a test, don't try this at home...
self.db.test.remove({})
self.db.test.insert({'_id': i} for i in xrange(200))
class CManager(CursorManager):
def __init__(self, connection):
super(CManager, self).__init__(connection)
def close(self, dummy):
# Do absolutely nothing...
pass
client = self.db.connection
ctx = catch_warnings()
try:
warnings.simplefilter("ignore", DeprecationWarning)
client.set_cursor_manager(CManager)
docs = []
cursor = self.db.test.find().batch_size(10)
docs.append(cursor.next())
cursor.close()
docs.extend(cursor)
self.assertEqual(len(docs), 10)
cmd_cursor = {'id': cursor.cursor_id, 'firstBatch': []}
ccursor = CommandCursor(cursor.collection, cmd_cursor,
cursor.conn_id, retrieved=cursor.retrieved)
docs.extend(ccursor)
self.assertEqual(len(docs), 200)
finally:
client.set_cursor_manager(CursorManager)
ctx.exit()
# Allow running this module directly, e.g. ``python test/test_cursor.py``.
if __name__ == "__main__":
    unittest.main()