Compare commits

...

36 Commits
master ... v2.7

Author SHA1 Message Date
Bernie Hackett
0d4a2ef28a BUMP 2.7.2 2014-07-29 14:36:57 -07:00
Bernie Hackett
9ca8ad7fc9 Changelog for version 2.7.2. 2014-07-25 10:56:39 -07:00
Bernie Hackett
d32016274b PYTHON-736 - Don't close sockets on OperationFailure
This also speeds up returning exhaust sockets to the pool
when the server returns an error and fixes the tests to
run against all MongoDB versions we test against.
2014-07-24 13:22:23 -07:00
A. Jesse Jiryu Davis
9ad421a58a PYTHON-736 Fix exhaust cursor error-handling.
Connection-pool semaphore leak on server error when
creating or iterating an exhaust cursor.
2014-07-23 16:35:01 -07:00
Bernie Hackett
46a7df09bd PYTHON-738 - Clarify versionchanged line for bulk insert. 2014-07-22 16:17:14 -07:00
A. Jesse Jiryu Davis
b293b7735b import style 2014-07-15 11:08:43 -04:00
A. Jesse Jiryu Davis
e1a7bc5058 PYTHON-732 Handle network errors when adding existing credentials to sockets. 2014-07-15 09:34:23 -04:00
Luke Lovett
a1f7a5487f PYTHON-714 Work around localhost exception issues in add_user when connected to MongoDB >= 2.7.1. 2014-07-02 17:11:20 +00:00
Bernie Hackett
952953d3a1 Fix tests under pypy3. 2014-06-21 22:44:33 -07:00
A. Jesse Jiryu Davis
e904f014d9 PYTHON-709 insert _id in document after applying non-copying SONManipulators. 2014-06-19 14:50:41 -04:00
A. Jesse Jiryu Davis
79df8d799a Revert "PYTHON-710, simplify SON's equality operator."
This reverts commit 551e1e3edf.
The change did not work as expected in Jython.
2014-06-19 14:10:24 -04:00
A. Jesse Jiryu Davis
6c68762960 Use modern 'distinct' syntax in tests.
The old syntax is now an error:
https://jira.mongodb.org/browse/SERVER-12642
2014-06-18 20:34:45 -04:00
A. Jesse Jiryu Davis
69f52d0cdf PYTHON-710, simplify SON's equality operator. 2014-06-18 17:50:55 -04:00
A. Jesse Jiryu Davis
686c8fae49 PYTHON-710, SON.to_dict shouldn't change original data. 2014-06-18 17:49:57 -04:00
A. Jesse Jiryu Davis
91a56702cf PYTHON-710 test that SON.to_dict doesn't change data. 2014-06-18 17:44:12 -04:00
A. Jesse Jiryu Davis
fb207af4cf PYTHON-712 ObjectId.is_valid(None) should be False. 2014-06-18 16:19:58 -04:00
behackett
2dc840955a PYTHON-705 - Fix Bulk API legacy upsert _id compatibility
Versions of MongoDB previous to 2.6 only return the upserted
field for an upsert operation if the _id value is an ObjectId.
This patch works around that issue to ensure nUpserted counts
are correct regardless of server version.
2014-06-06 15:37:31 -07:00
behackett
6fbb4c5307 Version -> + 2014-06-06 14:47:41 -07:00
Bernie Hackett
e959aad948 BUMP 2.7.1 2014-05-23 14:45:12 -07:00
Bernie Hackett
6baba92fcf Changelog for 2.7.1 2014-05-23 13:39:00 -07:00
Bernie Hackett
47825d9d39 PYTHON-697 - Fix upsert _id backward compatibility 2014-05-20 11:28:06 -07:00
Bernie Hackett
f739025e0c PYTHON-698 - Try encoding types with broken __getattr__ methods 2014-05-16 14:02:12 -07:00
Bernie Hackett
6991b73734 Fix a few tests for MongoDB 2.7.0 2014-05-12 14:07:25 -07:00
Bernie Hackett
04ff22e3c9 Various fixes for auth tests with old mongos versions. 2014-05-02 15:15:08 -07:00
Bernie Hackett
13cd9bee6f Fix a few tests with really old mongos versions. 2014-05-01 19:20:33 -07:00
Bernie Hackett
2cc560c671 PYTHON-696 - Fix remove_user for old mongos versions. 2014-05-01 15:41:16 -07:00
Bernie Hackett
b97b85f89a PYTHON-696 - Fix user and index creation with old mongos versions. 2014-05-01 14:27:29 -07:00
Jaroslav Semančík
15511b70d8 Added Jaroslav Semančík (girogiro) to contributors 2014-05-01 11:28:49 -07:00
Jaroslav Semančík
e299c044aa Fixed wrong Python object name for UTC 2014-05-01 11:28:48 -07:00
Bernie Hackett
32279986bd PYTHON-667 - Clarify drop_index behavior when an index does not exist. 2014-05-01 10:49:16 -07:00
Bernie Hackett
348bd628aa PYTHON-690 - Various fixes to indexing docstrings. 2014-05-01 09:59:09 -07:00
Bernie Hackett
9d47f1cd3d PYTHON-691 - Fix UserWarning command issues.
Don't raise UserWarning for helpers and internal calls to
commands that do not obey read preference.
2014-05-01 09:21:52 -07:00
Bernie Hackett
d703ebb832 PYTHON-684 - Use unordered bulk for unordered test. 2014-04-29 13:44:51 -07:00
Bernie Hackett
baed02fb11 PYTHON-685 - Fix rare resource leak in _cmessage 2014-04-29 13:14:55 -07:00
Bernie Hackett
f61b0e4f59 PYTHON-684 - Ignore wnote/jnote from legacy servers.
Stop unnecessarily raising OperationFailure in the Bulk API
when a pre-2.6 server returns a result with a wnote or jnote
field.
2014-04-29 11:41:35 -07:00
Bernie Hackett
7d55d77072 Version -> + 2014-04-29 11:40:26 -07:00
31 changed files with 1431 additions and 755 deletions

View File

@ -554,6 +554,7 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
unsigned char check_keys, unsigned char check_keys,
unsigned char uuid_subtype) { unsigned char uuid_subtype) {
struct module_state *state = GETSTATE(self); struct module_state *state = GETSTATE(self);
PyObject* type_marker = NULL;
/* /*
* Don't use PyObject_IsInstance for our custom types. It causes * Don't use PyObject_IsInstance for our custom types. It causes
@ -561,26 +562,32 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
* have a _type_marker attribute, which we can switch on instead. * have a _type_marker attribute, which we can switch on instead.
*/ */
if (PyObject_HasAttrString(value, "_type_marker")) { if (PyObject_HasAttrString(value, "_type_marker")) {
long type; type_marker = PyObject_GetAttrString(value, "_type_marker");
PyObject* type_marker = PyObject_GetAttrString(value, "_type_marker"); if (type_marker == NULL) {
if (type_marker == NULL)
return 0; return 0;
}
}
/*
* Python objects with broken __getattr__ implementations could return
* arbitrary types for a call to PyObject_GetAttrString. For example
* pymongo.database.Database returns a new Collection instance for
* __getattr__ calls with names that don't match an existing attribute
* or method. In some cases "value" could be a subtype of something
* we know how to serialize. Make a best effort to encode these types.
*/
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
type = PyLong_AsLong(type_marker); if (type_marker && PyLong_CheckExact(type_marker)) {
long type = PyLong_AsLong(type_marker);
#else #else
type = PyInt_AsLong(type_marker); if (type_marker && PyInt_CheckExact(type_marker)) {
long type = PyInt_AsLong(type_marker);
#endif #endif
Py_DECREF(type_marker); Py_DECREF(type_marker);
/* /*
* Py(Long|Int)_AsLong returns -1 for error but -1 is a valid value * Py(Long|Int)_AsLong returns -1 for error but -1 is a valid value
* so we call PyErr_Occurred to differentiate. * so we call PyErr_Occurred to differentiate.
*
* One potential reason for an error is the user passing an invalid
* type that overrides __getattr__ (e.g. pymongo.collection.Collection)
*/ */
if (type == -1 && PyErr_Occurred()) { if (type == -1 && PyErr_Occurred()) {
PyErr_Clear();
_set_cannot_encode(value);
return 0; return 0;
} }
switch (type) { switch (type) {
@ -792,6 +799,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
return 1; return 1;
} }
} }
} else {
Py_XDECREF(type_marker);
} }
/* No _type_marker attibute or not one of our types. */ /* No _type_marker attibute or not one of our types. */
@ -1775,7 +1784,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
Py_DECREF(args); Py_DECREF(args);
goto invalid; goto invalid;
} }
utc_type = _get_object(state->UTC, "bson.tz_util", "UTC"); utc_type = _get_object(state->UTC, "bson.tz_util", "utc");
if (!utc_type || PyDict_SetItemString(kwargs, "tzinfo", utc_type) == -1) { if (!utc_type || PyDict_SetItemString(kwargs, "tzinfo", utc_type) == -1) {
Py_DECREF(replace); Py_DECREF(replace);
Py_DECREF(args); Py_DECREF(args);

View File

@ -140,6 +140,9 @@ class ObjectId(object):
.. versionadded:: 2.3 .. versionadded:: 2.3
""" """
if not oid:
return False
try: try:
ObjectId(oid) ObjectId(oid)
return True return True

View File

@ -226,12 +226,12 @@ class SON(dict):
def transform_value(value): def transform_value(value):
if isinstance(value, list): if isinstance(value, list):
return [transform_value(v) for v in value] return [transform_value(v) for v in value]
if isinstance(value, SON): elif isinstance(value, dict):
value = dict(value) return dict([
if isinstance(value, dict): (k, transform_value(v))
for k, v in value.iteritems(): for k, v in value.iteritems()])
value[k] = transform_value(v) else:
return value return value
return transform_value(dict(self)) return transform_value(dict(self))

View File

@ -10,6 +10,7 @@
.. autodata:: pymongo.GEOHAYSTACK .. autodata:: pymongo.GEOHAYSTACK
.. autodata:: pymongo.GEOSPHERE .. autodata:: pymongo.GEOSPHERE
.. autodata:: pymongo.HASHED .. autodata:: pymongo.HASHED
.. autodata:: pymongo.TEXT
.. autoclass:: pymongo.collection.Collection(database, name[, create=False[, **kwargs]]]) .. autoclass:: pymongo.collection.Collection(database, name[, create=False[, **kwargs]]])

View File

@ -1,6 +1,38 @@
Changelog Changelog
========= =========
Changes in Version 2.7.2
------------------------
Version 2.7.2 includes fixes for upsert reporting in the bulk API for MongoDB
versions previous to 2.6, a regression in how son manipulators are applied in
:meth:`~pymongo.collection.Collection.insert`, a few obscure connection pool
semaphore leaks, and a few other minor issues. See the list of issues resolved
for full details.
Issues Resolved
...............
See the `PyMongo 2.7.2 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 2.7.2 release notes in JIRA: https://jira.mongodb.org/browse/PYTHON/fixforversion/14005
Changes in Version 2.7.1
------------------------
Version 2.7.1 fixes a number of issues reported since the release of 2.7,
most importantly a fix for creating indexes and manipulating users through
mongos versions older than 2.4.0.
Issues Resolved
...............
See the `PyMongo 2.7.1 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 2.7.1 release notes in JIRA: https://jira.mongodb.org/browse/PYTHON/fixforversion/13823
Changes in Version 2.7 Changes in Version 2.7
---------------------- ----------------------

View File

@ -69,3 +69,4 @@ The following is a list of people who have contributed to
- Yuchen Ying (yegle) - Yuchen Ying (yegle)
- Kyle Erf (3rf) - Kyle Erf (3rf)
- Luke Lovett (lovett89) - Luke Lovett (lovett89)
- Jaroslav Semančík (girogiro)

View File

@ -30,6 +30,7 @@ from pymongo import ASCENDING
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.cursor import Cursor from pymongo.cursor import Cursor
from pymongo.errors import DuplicateKeyError from pymongo.errors import DuplicateKeyError
from pymongo.read_preferences import ReadPreference
try: try:
_SEEK_SET = os.SEEK_SET _SEEK_SET = os.SEEK_SET
@ -258,7 +259,8 @@ class GridIn(object):
db.error() db.error()
md5 = db.command( md5 = db.command(
"filemd5", self._id, root=self._coll.name)["md5"] "filemd5", self._id, root=self._coll.name,
read_preference=ReadPreference.PRIMARY)["md5"]
self._file["md5"] = md5 self._file["md5"] = md5
self._file["length"] = self._position self._file["length"] = self._position

View File

@ -27,7 +27,7 @@ GEO2D = "2d"
.. note:: Geo-spatial indexing requires server version **>= 1.3.3**. .. note:: Geo-spatial indexing requires server version **>= 1.3.3**.
.. _geospatial index: http://docs.mongodb.org/manual/core/geospatial-indexes/ .. _geospatial index: http://docs.mongodb.org/manual/core/2d/
""" """
GEOHAYSTACK = "geoHaystack" GEOHAYSTACK = "geoHaystack"
@ -37,7 +37,7 @@ GEOHAYSTACK = "geoHaystack"
.. note:: Geo-spatial indexing requires server version **>= 1.5.6**. .. note:: Geo-spatial indexing requires server version **>= 1.5.6**.
.. _haystack index: http://docs.mongodb.org/manual/core/geospatial-indexes/#haystack-indexes .. _haystack index: http://docs.mongodb.org/manual/core/geohaystack/
""" """
GEOSPHERE = "2dsphere" GEOSPHERE = "2dsphere"
@ -47,7 +47,7 @@ GEOSPHERE = "2dsphere"
.. note:: 2dsphere indexing requires server version **>= 2.4.0**. .. note:: 2dsphere indexing requires server version **>= 2.4.0**.
.. _spherical geospatial index: http://docs.mongodb.org/manual/release-notes/2.4/#new-geospatial-indexes-with-geojson-and-improved-spherical-geometry .. _spherical geospatial index: http://docs.mongodb.org/manual/core/2dsphere/
""" """
HASHED = "hashed" HASHED = "hashed"
@ -57,7 +57,17 @@ HASHED = "hashed"
.. note:: hashed indexing requires server version **>= 2.4.0**. .. note:: hashed indexing requires server version **>= 2.4.0**.
.. _hashed index: http://docs.mongodb.org/manual/release-notes/2.4/#new-hashed-index-and-sharding-with-a-hashed-shard-key .. _hashed index: http://docs.mongodb.org/manual/core/index-hashed/
"""
TEXT = "text"
"""Index specifier for a `text index`_.
.. versionadded:: 2.7.1
.. note:: text search requires server version **>= 2.4.0**.
.. _text index: http://docs.mongodb.org/manual/core/index-text/
""" """
OFF = 0 OFF = 0
@ -67,7 +77,7 @@ SLOW_ONLY = 1
ALL = 2 ALL = 2
"""Profile all operations.""" """Profile all operations."""
version_tuple = (2, 7) version_tuple = (2, 7, 2)
def get_version_string(): def get_version_string():
if isinstance(version_tuple[-1], basestring): if isinstance(version_tuple[-1], basestring):

View File

@ -1112,8 +1112,10 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
*/ */
buffer_update_position(buffer, sub_doc_begin); buffer_update_position(buffer, sub_doc_begin);
if (!buffer_write_bytes(buffer, "\x00\x00", 2)) if (!buffer_write_bytes(buffer, "\x00\x00", 2)) {
buffer_free(new_buffer);
goto cmditerfail; goto cmditerfail;
}
result = _send_write_command(client, buffer, result = _send_write_command(client, buffer,
lst_len_loc, cmd_len_loc, &errors); lst_len_loc, cmd_len_loc, &errors);

View File

@ -82,16 +82,6 @@ def _make_error(index, code, errmsg, operation):
def _merge_legacy(run, full_result, result, index): def _merge_legacy(run, full_result, result, index):
"""Merge a result from a legacy opcode into the full results. """Merge a result from a legacy opcode into the full results.
""" """
# MongoDB 2.6 returns {'ok': 0, 'code': 2, ...} if the j write
# concern option is used with --nojournal or w > 1 is used with
# a standalone mongod instance. Raise immediately here for
# consistency when talking to older servers. Since these are
# configuration errors related to write concern the entire batch
# will fail.
note = result.get("jnote", result.get("wnote"))
if note:
raise OperationFailure(note, _BAD_VALUE, result)
affected = result.get('n', 0) affected = result.get('n', 0)
errmsg = result.get("errmsg", result.get("err", "")) errmsg = result.get("errmsg", result.get("err", ""))
@ -111,13 +101,25 @@ def _merge_legacy(run, full_result, result, index):
if run.op_type == _INSERT: if run.op_type == _INSERT:
full_result['nInserted'] += 1 full_result['nInserted'] += 1
elif run.op_type == _UPDATE: elif run.op_type == _UPDATE:
if "upserted" in result: if "upserted" in result:
doc = {u"index": run.index(index), u"_id": result["upserted"]} doc = {u"index": run.index(index), u"_id": result["upserted"]}
full_result["upserted"].append(doc) full_result["upserted"].append(doc)
full_result['nUpserted'] += affected full_result['nUpserted'] += affected
# Versions of MongoDB before 2.6 don't return the _id for an
# upsert if _id is not an ObjectId.
elif result.get("updatedExisting") == False and affected == 1:
op = run.ops[index]
# If _id is in both the update document *and* the query spec
# the update document _id takes precedence.
_id = op['u'].get('_id', op['q'].get('_id'))
doc = {u"index": run.index(index), u"_id": _id}
full_result["upserted"].append(doc)
full_result['nUpserted'] += affected
else: else:
full_result['nMatched'] += affected full_result['nMatched'] += affected
elif run.op_type == _DELETE: elif run.op_type == _DELETE:
full_result['nRemoved'] += affected full_result['nRemoved'] += affected

View File

@ -28,6 +28,7 @@ from pymongo.cursor import Cursor
from pymongo.errors import InvalidName, OperationFailure from pymongo.errors import InvalidName, OperationFailure
from pymongo.helpers import _check_write_command_response from pymongo.helpers import _check_write_command_response
from pymongo.message import _INSERT, _UPDATE, _DELETE from pymongo.message import _INSERT, _UPDATE, _DELETE
from pymongo.read_preferences import ReadPreference
try: try:
@ -125,9 +126,12 @@ class Collection(common.BaseObject):
if options: if options:
if "size" in options: if "size" in options:
options["size"] = float(options["size"]) options["size"] = float(options["size"])
self.__database.command("create", self.__name, **options) self.__database.command("create", self.__name,
read_preference=ReadPreference.PRIMARY,
**options)
else: else:
self.__database.command("create", self.__name) self.__database.command("create", self.__name,
read_preference=ReadPreference.PRIMARY)
def __getattr__(self, name): def __getattr__(self, name):
"""Get a sub-collection of this collection by name. """Get a sub-collection of this collection by name.
@ -352,7 +356,7 @@ class Collection(common.BaseObject):
Support for passing `getLastError` options as keyword Support for passing `getLastError` options as keyword
arguments. arguments.
.. versionchanged:: 1.1 .. versionchanged:: 1.1
Bulk insert works with any iterable Bulk insert works with an iterable sequence of documents.
.. mongodoc:: insert .. mongodoc:: insert
""" """
@ -374,11 +378,14 @@ class Collection(common.BaseObject):
def gen(): def gen():
db = self.__database db = self.__database
for doc in docs: for doc in docs:
# Apply user-configured SON manipulators. This order of
# operations is required for backwards compatibility,
# see PYTHON-709.
doc = db._apply_incoming_manipulators(doc, self)
if '_id' not in doc: if '_id' not in doc:
doc['_id'] = ObjectId() doc['_id'] = ObjectId()
# Apply user-configured SON manipulators. doc = db._apply_incoming_copying_manipulators(doc, self)
doc = db._fix_incoming(doc, self)
ids.append(doc['_id']) ids.append(doc['_id'])
yield doc yield doc
else: else:
@ -550,6 +557,10 @@ class Collection(common.BaseObject):
result['updatedExisting'] = True result['updatedExisting'] = True
else: else:
result['updatedExisting'] = False result['updatedExisting'] = False
# MongoDB >= 2.6.0 returns the upsert _id in an array
# element. Break it out for backward compatibility.
if isinstance(result.get('upserted'), list):
result['upserted'] = result['upserted'][0]['_id']
return result return result
@ -942,32 +953,39 @@ class Collection(common.BaseObject):
Takes either a single key or a list of (key, direction) pairs. Takes either a single key or a list of (key, direction) pairs.
The key(s) must be an instance of :class:`basestring` The key(s) must be an instance of :class:`basestring`
(:class:`str` in python 3), and the direction(s) must be one of (:class:`str` in python 3), and the direction(s) should be one of
(:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`, :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`,
:data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`). :data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`,
:data:`~pymongo.TEXT`).
To create a single key index on the key ``'mike'`` we just use To create a simple ascending index on the key ``'mike'`` we just
a string argument: use a string argument::
>>> my_collection.create_index("mike") >>> my_collection.create_index("mike")
For a compound index on ``'mike'`` descending and ``'eliot'`` For a compound index on ``'mike'`` descending and ``'eliot'``
ascending we need to use a list of tuples: ascending we need to use a list of tuples::
>>> my_collection.create_index([("mike", pymongo.DESCENDING), >>> my_collection.create_index([("mike", pymongo.DESCENDING),
... ("eliot", pymongo.ASCENDING)]) ... ("eliot", pymongo.ASCENDING)])
All optional index creation parameters should be passed as All optional index creation parameters should be passed as
keyword arguments to this method. Valid options include: keyword arguments to this method. For example::
>>> my_collection.create_index([("mike", pymongo.DESCENDING)],
... background=True)
Valid options include:
- `name`: custom name to use for this index - if none is - `name`: custom name to use for this index - if none is
given, a name will be generated given, a name will be generated
- `unique`: should this index guarantee uniqueness? - `unique`: if ``True`` creates a unique constraint on the index
- `dropDups` or `drop_dups`: should we drop duplicates - `dropDups` or `drop_dups`: if ``True`` duplicate values are dropped
- `background`: if this index should be created in the during index creation when creating a unique index
- `background`: if ``True`` this index should be created in the
background background
- `sparse`: if True, omit from the index any documents that lack - `sparse`: if ``True``, omit from the index any documents that lack
the indexed field the indexed field
- `bucketSize` or `bucket_size`: for use with geoHaystack indexes. - `bucketSize` or `bucket_size`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity Number of documents to group together within a certain proximity
@ -1037,9 +1055,11 @@ class Collection(common.BaseObject):
index.update(kwargs) index.update(kwargs)
try: try:
self.__database.command('createIndexes', self.name, indexes=[index]) self.__database.command('createIndexes', self.name,
read_preference=ReadPreference.PRIMARY,
indexes=[index])
except OperationFailure, exc: except OperationFailure, exc:
if exc.code in (59, None): if exc.code in common.COMMAND_NOT_FOUND_CODES:
index["ns"] = self.__full_name index["ns"] = self.__full_name
self.__database.system.indexes.insert(index, manipulate=False, self.__database.system.indexes.insert(index, manipulate=False,
check_keys=False, check_keys=False,
@ -1057,11 +1077,13 @@ class Collection(common.BaseObject):
Takes either a single key or a list of (key, direction) pairs. Takes either a single key or a list of (key, direction) pairs.
The key(s) must be an instance of :class:`basestring` The key(s) must be an instance of :class:`basestring`
(:class:`str` in python 3), and the direction(s) must be one of (:class:`str` in python 3), and the direction(s) should be one of
(:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`, :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`,
:data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`). :data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`,
See :meth:`create_index` for a detailed example. :data:`pymongo.TEXT`).
See :meth:`create_index` for detailed examples.
Unlike :meth:`create_index`, which attempts to create an index Unlike :meth:`create_index`, which attempts to create an index
unconditionally, :meth:`ensure_index` takes advantage of some unconditionally, :meth:`ensure_index` takes advantage of some
@ -1087,12 +1109,12 @@ class Collection(common.BaseObject):
- `name`: custom name to use for this index - if none is - `name`: custom name to use for this index - if none is
given, a name will be generated given, a name will be generated
- `unique`: should this index guarantee uniqueness? - `unique`: if ``True`` creates a unique constraint on the index
- `dropDups` or `drop_dups`: should we drop duplicates - `dropDups` or `drop_dups`: if ``True`` duplicate values are dropped
during index creation when creating a unique index? during index creation when creating a unique index
- `background`: if this index should be created in the - `background`: if ``True`` this index should be created in the
background background
- `sparse`: if True, omit from the index any documents that lack - `sparse`: if ``True``, omit from the index any documents that lack
the indexed field the indexed field
- `bucketSize` or `bucket_size`: for use with geoHaystack indexes. - `bucketSize` or `bucket_size`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity Number of documents to group together within a certain proximity
@ -1159,7 +1181,8 @@ class Collection(common.BaseObject):
"""Drops the specified index on this collection. """Drops the specified index on this collection.
Can be used on non-existant collections or collections with no Can be used on non-existant collections or collections with no
indexes. Raises OperationFailure on an error. `index_or_name` indexes. Raises OperationFailure on an error (e.g. trying to
drop an index that does not exist). `index_or_name`
can be either an index name (as returned by `create_index`), can be either an index name (as returned by `create_index`),
or an index specifier (as passed to `create_index`). An index or an index specifier (as passed to `create_index`). An index
specifier should be a list of (key, direction) pairs. Raises specifier should be a list of (key, direction) pairs. Raises
@ -1183,7 +1206,9 @@ class Collection(common.BaseObject):
self.__database.connection._purge_index(self.__database.name, self.__database.connection._purge_index(self.__database.name,
self.__name, name) self.__name, name)
self.__database.command("dropIndexes", self.__name, index=name, self.__database.command("dropIndexes", self.__name,
read_preference=ReadPreference.PRIMARY,
index=name,
allowable_errors=["ns not found"]) allowable_errors=["ns not found"])
def reindex(self): def reindex(self):
@ -1195,7 +1220,8 @@ class Collection(common.BaseObject):
.. versionadded:: 1.11+ .. versionadded:: 1.11+
""" """
return self.__database.command("reIndex", self.__name) return self.__database.command("reIndex", self.__name,
read_preference=ReadPreference.PRIMARY)
def index_information(self): def index_information(self):
"""Get information on this collection's indexes. """Get information on this collection's indexes.
@ -1416,9 +1442,10 @@ class Collection(common.BaseObject):
raise InvalidName("collection names must not contain '$'") raise InvalidName("collection names must not contain '$'")
new_name = "%s.%s" % (self.__database.name, new_name) new_name = "%s.%s" % (self.__database.name, new_name)
self.__database.connection.admin.command("renameCollection", client = self.__database.connection
self.__full_name, client.admin.command("renameCollection", self.__full_name,
to=new_name, **kwargs) read_preference=ReadPreference.PRIMARY,
to=new_name, **kwargs)
def distinct(self, key): def distinct(self, key):
"""Get a list of distinct values for `key` among all documents """Get a list of distinct values for `key` among all documents
@ -1645,6 +1672,7 @@ class Collection(common.BaseObject):
out = self.__database.command("findAndModify", self.__name, out = self.__database.command("findAndModify", self.__name,
allowable_errors=[no_obj_error], allowable_errors=[no_obj_error],
read_preference=ReadPreference.PRIMARY,
uuid_subtype=self.uuid_subtype, uuid_subtype=self.uuid_subtype,
**kwargs) **kwargs)

View File

@ -47,6 +47,14 @@ MAX_WRITE_BATCH_SIZE = 1000
MIN_SUPPORTED_WIRE_VERSION = 0 MIN_SUPPORTED_WIRE_VERSION = 0
MAX_SUPPORTED_WIRE_VERSION = 2 MAX_SUPPORTED_WIRE_VERSION = 2
# mongod/s 2.6 and above return code 59 when a
# command doesn't exist. mongod versions previous
# to 2.6 and mongos 2.4.x return no error code
# when a command does exist. mongos versions previous
# to 2.4.0 return code 13390 when a command does not
# exist.
COMMAND_NOT_FOUND_CODES = (59, 13390, None)
def raise_config_error(key, dummy): def raise_config_error(key, dummy):
"""Raise ConfigurationError with the given key name.""" """Raise ConfigurationError with the given key name."""

View File

@ -22,8 +22,8 @@ from bson.son import SON
from pymongo import helpers, message, read_preferences from pymongo import helpers, message, read_preferences
from pymongo.read_preferences import ReadPreference, secondary_ok_commands from pymongo.read_preferences import ReadPreference, secondary_ok_commands
from pymongo.errors import (AutoReconnect, from pymongo.errors import (AutoReconnect,
CursorNotFound, InvalidOperation,
InvalidOperation) OperationFailure)
_QUERY_OPTIONS = { _QUERY_OPTIONS = {
"tailable_cursor": 2, "tailable_cursor": 2,
@ -56,6 +56,15 @@ class _SocketManager:
self.pool.maybe_return_socket(self.sock) self.pool.maybe_return_socket(self.sock)
self.sock, self.pool = None, None self.sock, self.pool = None, None
def error(self):
"""Clean up after an error on the managed socket.
"""
if self.sock:
self.sock.close()
# Return the closed socket to avoid a semaphore leak in the pool.
self.close()
# TODO might be cool to be able to do find().include("foo") or # TODO might be cool to be able to do find().include("foo") or
# find().exclude(["bar", "baz"]) or find().slice("a", 1, 2) as an # find().exclude(["bar", "baz"]) or find().slice("a", 1, 2) as an
@ -914,8 +923,14 @@ class Cursor(object):
# due to a socket timeout. # due to a socket timeout.
self.__killed = True self.__killed = True
raise raise
else: # exhaust cursor - no getMore message else:
response = client._exhaust_next(self.__exhaust_mgr.sock) # Exhaust cursor - no getMore message.
try:
response = client._exhaust_next(self.__exhaust_mgr.sock)
except AutoReconnect:
self.__killed = True
self.__exhaust_mgr.error()
raise
try: try:
response = helpers._unpack_response(response, self.__id, response = helpers._unpack_response(response, self.__id,
@ -923,8 +938,10 @@ class Cursor(object):
self.__tz_aware, self.__tz_aware,
self.__uuid_subtype, self.__uuid_subtype,
self.__compile_re) self.__compile_re)
except CursorNotFound: except OperationFailure:
self.__killed = True self.__killed = True
# Make sure exhaust socket is returned immediately, if necessary.
self.__die()
# If this is a tailable cursor the error is likely # If this is a tailable cursor the error is likely
# due to capped collection roll over. Setting # due to capped collection roll over. Setting
# self.__killed to True ensures Cursor.alive will be # self.__killed to True ensures Cursor.alive will be
@ -936,8 +953,11 @@ class Cursor(object):
# Don't send kill cursors to another server after a "not master" # Don't send kill cursors to another server after a "not master"
# error. It's completely pointless. # error. It's completely pointless.
self.__killed = True self.__killed = True
# Make sure exhaust socket is returned immediately, if necessary.
self.__die()
client.disconnect() client.disconnect()
raise raise
self.__id = response["cursor_id"] self.__id = response["cursor_id"]
# starting from doesn't get set on getmore's for tailable cursors # starting from doesn't get set on getmore's for tailable cursors

View File

@ -26,7 +26,9 @@ from pymongo.errors import (CollectionInvalid,
ConfigurationError, ConfigurationError,
InvalidName, InvalidName,
OperationFailure) OperationFailure)
from pymongo import read_preferences as rp from pymongo.read_preferences import (modes,
secondary_ok_commands,
ReadPreference)
def _check_name(name): def _check_name(name):
@ -245,6 +247,16 @@ class Database(common.BaseObject):
return Collection(self, name, **opts) return Collection(self, name, **opts)
def _apply_incoming_manipulators(self, son, collection):
for manipulator in self.__incoming_manipulators:
son = manipulator.transform_incoming(son, collection)
return son
def _apply_incoming_copying_manipulators(self, son, collection):
for manipulator in self.__incoming_copying_manipulators:
son = manipulator.transform_incoming(son, collection)
return son
def _fix_incoming(self, son, collection): def _fix_incoming(self, son, collection):
"""Apply manipulators to an incoming SON object before it gets stored. """Apply manipulators to an incoming SON object before it gets stored.
@ -252,10 +264,8 @@ class Database(common.BaseObject):
- `son`: the son object going into the database - `son`: the son object going into the database
- `collection`: the collection the son object is being saved in - `collection`: the collection the son object is being saved in
""" """
for manipulator in self.__incoming_manipulators: son = self._apply_incoming_manipulators(son, collection)
son = manipulator.transform_incoming(son, collection) son = self._apply_incoming_copying_manipulators(son, collection)
for manipulator in self.__incoming_copying_manipulators:
son = manipulator.transform_incoming(son, collection)
return son return son
def _fix_outgoing(self, son, collection): def _fix_outgoing(self, son, collection):
@ -282,7 +292,7 @@ class Database(common.BaseObject):
command_name = command.keys()[0].lower() command_name = command.keys()[0].lower()
must_use_master = kwargs.pop('_use_master', False) must_use_master = kwargs.pop('_use_master', False)
if command_name not in rp.secondary_ok_commands: if command_name not in secondary_ok_commands:
must_use_master = True must_use_master = True
# Special-case: mapreduce can go to secondaries only if inline # Special-case: mapreduce can go to secondaries only if inline
@ -323,13 +333,13 @@ class Database(common.BaseObject):
command.update(kwargs) command.update(kwargs)
# Warn if must_use_master will override read_preference. # Warn if must_use_master will override read_preference.
if (extra_opts['read_preference'] != rp.ReadPreference.PRIMARY and if (extra_opts['read_preference'] != ReadPreference.PRIMARY and
extra_opts['_must_use_master']): extra_opts['_must_use_master']):
warnings.warn("%s does not support %s read preference " warnings.warn("%s does not support %s read preference "
"and will be routed to the primary instead." % "and will be routed to the primary instead." %
(command_name, (command_name,
rp.modes[extra_opts['read_preference']]), modes[extra_opts['read_preference']]),
UserWarning) UserWarning, stacklevel=3)
cursor = self["$cmd"].find(command, **extra_opts).limit(-1) cursor = self["$cmd"].find(command, **extra_opts).limit(-1)
for doc in cursor: for doc in cursor:
@ -466,7 +476,8 @@ class Database(common.BaseObject):
self.__connection._purge_index(self.__name, name) self.__connection._purge_index(self.__name, name)
self.command("drop", unicode(name), allowable_errors=["ns not found"]) self.command("drop", unicode(name), allowable_errors=["ns not found"],
read_preference=ReadPreference.PRIMARY)
def validate_collection(self, name_or_collection, def validate_collection(self, name_or_collection,
scandata=False, full=False): scandata=False, full=False):
@ -504,7 +515,8 @@ class Database(common.BaseObject):
"%s or Collection" % (basestring.__name__,)) "%s or Collection" % (basestring.__name__,))
result = self.command("validate", unicode(name), result = self.command("validate", unicode(name),
scandata=scandata, full=full) scandata=scandata, full=full,
read_preference=ReadPreference.PRIMARY)
valid = True valid = True
# Pre 1.9 results # Pre 1.9 results
@ -553,7 +565,8 @@ class Database(common.BaseObject):
.. mongodoc:: profiling .. mongodoc:: profiling
""" """
result = self.command("profile", -1) result = self.command("profile", -1,
read_preference=ReadPreference.PRIMARY)
assert result["was"] >= 0 and result["was"] <= 2 assert result["was"] >= 0 and result["was"] <= 2
return result["was"] return result["was"]
@ -593,9 +606,11 @@ class Database(common.BaseObject):
raise TypeError("slow_ms must be an integer") raise TypeError("slow_ms must be an integer")
if slow_ms is not None: if slow_ms is not None:
self.command("profile", level, slowms=slow_ms) self.command("profile", level, slowms=slow_ms,
read_preference=ReadPreference.PRIMARY)
else: else:
self.command("profile", level) self.command("profile", level,
read_preference=ReadPreference.PRIMARY)
def profiling_info(self): def profiling_info(self):
"""Returns a list containing current profiling information. """Returns a list containing current profiling information.
@ -610,7 +625,8 @@ class Database(common.BaseObject):
Return None if the last operation was error-free. Otherwise return the Return None if the last operation was error-free. Otherwise return the
error that occurred. error that occurred.
""" """
error = self.command("getlasterror") error = self.command("getlasterror",
read_preference=ReadPreference.PRIMARY)
error_msg = error.get("err", "") error_msg = error.get("err", "")
if error_msg is None: if error_msg is None:
return None return None
@ -623,7 +639,8 @@ class Database(common.BaseObject):
Returns a SON object with status information. Returns a SON object with status information.
""" """
return self.command("getlasterror") return self.command("getlasterror",
read_preference=ReadPreference.PRIMARY)
def previous_error(self): def previous_error(self):
"""Get the most recent error to have occurred on this database. """Get the most recent error to have occurred on this database.
@ -632,7 +649,8 @@ class Database(common.BaseObject):
`Database.reset_error_history`. Returns None if no such errors have `Database.reset_error_history`. Returns None if no such errors have
occurred. occurred.
""" """
error = self.command("getpreverror") error = self.command("getpreverror",
read_preference=ReadPreference.PRIMARY)
if error.get("err", 0) is None: if error.get("err", 0) is None:
return None return None
return error return error
@ -643,7 +661,8 @@ class Database(common.BaseObject):
Calls to `Database.previous_error` will only return errors that have Calls to `Database.previous_error` will only return errors that have
occurred since the most recent call to this method. occurred since the most recent call to this method.
""" """
self.command("reseterror") self.command("reseterror",
read_preference=ReadPreference.PRIMARY)
def __iter__(self): def __iter__(self):
return self return self
@ -697,7 +716,8 @@ class Database(common.BaseObject):
else: else:
command_name = "updateUser" command_name = "updateUser"
self.command(command_name, name, **opts) self.command(command_name, name,
read_preference=ReadPreference.PRIMARY, **opts)
def _legacy_add_user(self, name, password, read_only, **kwargs): def _legacy_add_user(self, name, password, read_only, **kwargs):
"""Uses v1 system to add users, i.e. saving to system.users. """Uses v1 system to add users, i.e. saving to system.users.
@ -716,6 +736,11 @@ class Database(common.BaseObject):
# See SERVER-4225 for more information. # See SERVER-4225 for more information.
if 'login' in str(exc): if 'login' in str(exc):
pass pass
# First admin user add fails gle from mongos 2.0.x
# and 2.2.x.
elif (exc.details and
'getlasterror' in exc.details.get('note', '')):
pass
else: else:
raise raise
@ -763,21 +788,22 @@ class Database(common.BaseObject):
"read_only and roles together") "read_only and roles together")
try: try:
uinfo = self.command("usersInfo", name) uinfo = self.command("usersInfo", name,
read_preference=ReadPreference.PRIMARY)
self._create_or_update_user(
(not uinfo["users"]), name, password, read_only, **kwargs)
except OperationFailure, exc: except OperationFailure, exc:
# MongoDB >= 2.5.3 requires the use of commands to manage # MongoDB >= 2.5.3 requires the use of commands to manage
# users. "No such command" error didn't return an error # users.
# code (59) before MongoDB 2.4.7 so we assume that an error if exc.code in common.COMMAND_NOT_FOUND_CODES:
# code of None means the userInfo command doesn't exist and
# we should fall back to the legacy add user code.
if exc.code in (59, None):
self._legacy_add_user(name, password, read_only, **kwargs) self._legacy_add_user(name, password, read_only, **kwargs)
return # Unauthorized. MongoDB >= 2.7.1 has a narrow localhost exception,
raise # and we must add a user before sending commands.
elif exc.code == 13:
# Create the user if not found in uinfo, otherwise update one. self._create_or_update_user(
self._create_or_update_user( True, name, password, read_only, **kwargs)
(not uinfo["users"]), name, password, read_only, **kwargs) else:
raise
def remove_user(self, name): def remove_user(self, name):
"""Remove user `name` from this :class:`Database`. """Remove user `name` from this :class:`Database`.
@ -793,10 +819,11 @@ class Database(common.BaseObject):
try: try:
self.command("dropUser", name, self.command("dropUser", name,
read_preference=ReadPreference.PRIMARY,
writeConcern=self._get_wc_override()) writeConcern=self._get_wc_override())
except OperationFailure, exc: except OperationFailure, exc:
# See comment in add_user try / except above. # See comment in add_user try / except above.
if exc.code in (59, None): if exc.code in common.COMMAND_NOT_FOUND_CODES:
self.system.users.remove({"user": name}, self.system.users.remove({"user": name},
**self._get_wc_override()) **self._get_wc_override())
return return
@ -930,7 +957,9 @@ class Database(common.BaseObject):
if not isinstance(code, Code): if not isinstance(code, Code):
code = Code(code) code = Code(code)
result = self.command("$eval", code, args=args) result = self.command("$eval", code,
read_preference=ReadPreference.PRIMARY,
args=args)
return result.get("retval", None) return result.get("retval", None)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):

View File

@ -61,6 +61,9 @@ from pymongo.errors import (AutoReconnect,
InvalidURI, InvalidURI,
OperationFailure) OperationFailure)
from pymongo.member import Member from pymongo.member import Member
from pymongo.read_preferences import ReadPreference
EMPTY = b("") EMPTY = b("")
@ -670,13 +673,16 @@ class MongoClient(common.BaseObject):
'max_write_batch_size', common.MAX_WRITE_BATCH_SIZE) 'max_write_batch_size', common.MAX_WRITE_BATCH_SIZE)
def __simple_command(self, sock_info, dbname, spec): def __simple_command(self, sock_info, dbname, spec):
"""Send a command to the server. """Send a command to the server. May raise AutoReconnect.
""" """
rqst_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec) rqst_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec)
start = time.time() start = time.time()
try: try:
sock_info.sock.sendall(msg) sock_info.sock.sendall(msg)
response = self.__receive_message_on_socket(1, rqst_id, sock_info) response = self.__receive_message_on_socket(1, rqst_id, sock_info)
except socket.error, e:
sock_info.close()
raise AutoReconnect(e)
except: except:
sock_info.close() sock_info.close()
raise raise
@ -913,7 +919,7 @@ class MongoClient(common.BaseObject):
"%s %s" % (host_details, str(why))) "%s %s" % (host_details, str(why)))
try: try:
self.__check_auth(sock_info) self.__check_auth(sock_info)
except OperationFailure: except:
connection_pool.maybe_return_socket(sock_info) connection_pool.maybe_return_socket(sock_info)
raise raise
return sock_info return sock_info
@ -1050,7 +1056,7 @@ class MongoClient(common.BaseObject):
# for some errors. # for some errors.
if "errObjects" in result: if "errObjects" in result:
for errobj in result["errObjects"]: for errobj in result["errObjects"]:
if errobj["err"] == error_msg: if errobj.get("err") == error_msg:
details = errobj details = errobj
break break
@ -1186,27 +1192,35 @@ class MongoClient(common.BaseObject):
sock_info = self.__socket(member) sock_info = self.__socket(member)
exhaust = kwargs.get('exhaust') exhaust = kwargs.get('exhaust')
try: try:
try: if not exhaust and "network_timeout" in kwargs:
if not exhaust and "network_timeout" in kwargs: sock_info.sock.settimeout(kwargs["network_timeout"])
sock_info.sock.settimeout(kwargs["network_timeout"])
response = self.__send_and_receive(message, sock_info)
if not exhaust: response = self.__send_and_receive(message, sock_info)
if "network_timeout" in kwargs:
sock_info.sock.settimeout(self.__net_timeout)
return (None, (response, sock_info, member.pool))
except (ConnectionFailure, socket.error), e:
self.disconnect()
raise AutoReconnect(str(e))
finally:
if not exhaust: if not exhaust:
if "network_timeout" in kwargs:
sock_info.sock.settimeout(self.__net_timeout)
member.pool.maybe_return_socket(sock_info) member.pool.maybe_return_socket(sock_info)
return (None, (response, sock_info, member.pool))
except (ConnectionFailure, socket.error), e:
self.disconnect()
member.pool.maybe_return_socket(sock_info)
raise AutoReconnect(str(e))
except:
member.pool.maybe_return_socket(sock_info)
raise
def _exhaust_next(self, sock_info): def _exhaust_next(self, sock_info):
"""Used with exhaust cursors to get the next batch off the socket. """Used with exhaust cursors to get the next batch off the socket.
Can raise AutoReconnect.
""" """
return self.__receive_message_on_socket(1, None, sock_info) try:
return self.__receive_message_on_socket(1, None, sock_info)
except socket.error, e:
raise AutoReconnect(str(e))
def start_request(self): def start_request(self):
"""Ensure the current thread or greenlet always uses the same socket """Ensure the current thread or greenlet always uses the same socket
@ -1337,13 +1351,15 @@ class MongoClient(common.BaseObject):
def server_info(self): def server_info(self):
"""Get information about the MongoDB server we're connected to. """Get information about the MongoDB server we're connected to.
""" """
return self.admin.command("buildinfo") return self.admin.command("buildinfo",
read_preference=ReadPreference.PRIMARY)
def database_names(self): def database_names(self):
"""Get a list of the names of all databases on the connected server. """Get a list of the names of all databases on the connected server.
""" """
return [db["name"] for db in return [db["name"] for db in
self.admin.command("listDatabases")["databases"]] self.admin.command("listDatabases",
read_preference=ReadPreference.PRIMARY)["databases"]]
def drop_database(self, name_or_database): def drop_database(self, name_or_database):
"""Drop a database. """Drop a database.
@ -1365,7 +1381,8 @@ class MongoClient(common.BaseObject):
"%s or Database" % (basestring.__name__,)) "%s or Database" % (basestring.__name__,))
self._purge_index(name) self._purge_index(name)
self[name].command("dropDatabase") self[name].command("dropDatabase",
read_preference=ReadPreference.PRIMARY)
def copy_database(self, from_name, to_name, def copy_database(self, from_name, to_name,
from_host=None, username=None, password=None): from_host=None, username=None, password=None):
@ -1413,12 +1430,15 @@ class MongoClient(common.BaseObject):
if username is not None: if username is not None:
nonce = self.admin.command("copydbgetnonce", nonce = self.admin.command("copydbgetnonce",
fromhost=from_host)["nonce"] read_preference=ReadPreference.PRIMARY,
fromhost=from_host)["nonce"]
command["username"] = username command["username"] = username
command["nonce"] = nonce command["nonce"] = nonce
command["key"] = auth._auth_key(nonce, username, password) command["key"] = auth._auth_key(nonce, username, password)
return self.admin.command("copydb", **command) return self.admin.command("copydb",
read_preference=ReadPreference.PRIMARY,
**command)
finally: finally:
self.end_request() self.end_request()
@ -1467,7 +1487,8 @@ class MongoClient(common.BaseObject):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
self.admin.command("fsync", **kwargs) self.admin.command("fsync",
read_preference=ReadPreference.PRIMARY, **kwargs)
def unlock(self): def unlock(self):
"""Unlock a previously locked server. """Unlock a previously locked server.

View File

@ -59,6 +59,7 @@ from pymongo.errors import (AutoReconnect,
DuplicateKeyError, DuplicateKeyError,
OperationFailure, OperationFailure,
InvalidOperation) InvalidOperation)
from pymongo.read_preferences import ReadPreference
from pymongo.thread_util import DummyLock from pymongo.thread_util import DummyLock
EMPTY = b("") EMPTY = b("")
@ -1710,8 +1711,13 @@ class MongoReplicaSetClient(common.BaseObject):
def _exhaust_next(self, sock_info): def _exhaust_next(self, sock_info):
"""Used with exhaust cursors to get the next batch off the socket. """Used with exhaust cursors to get the next batch off the socket.
Can raise AutoReconnect.
""" """
return self.__recv_msg(1, None, sock_info) try:
return self.__recv_msg(1, None, sock_info)
except socket.error, e:
raise AutoReconnect(str(e))
def start_request(self): def start_request(self):
"""Ensure the current thread or greenlet always uses the same socket """Ensure the current thread or greenlet always uses the same socket
@ -1832,13 +1838,15 @@ class MongoReplicaSetClient(common.BaseObject):
def server_info(self): def server_info(self):
"""Get information about the MongoDB primary we're connected to. """Get information about the MongoDB primary we're connected to.
""" """
return self.admin.command("buildinfo") return self.admin.command("buildinfo",
read_preference=ReadPreference.PRIMARY)
def database_names(self): def database_names(self):
"""Get a list of the names of all databases on the connected server. """Get a list of the names of all databases on the connected server.
""" """
return [db["name"] for db in return [db["name"] for db in
self.admin.command("listDatabases")["databases"]] self.admin.command("listDatabases",
read_preference=ReadPreference.PRIMARY)["databases"]]
def drop_database(self, name_or_database): def drop_database(self, name_or_database):
"""Drop a database. """Drop a database.
@ -1860,7 +1868,8 @@ class MongoReplicaSetClient(common.BaseObject):
"%s or Database" % (basestring.__name__,)) "%s or Database" % (basestring.__name__,))
self._purge_index(name) self._purge_index(name)
self[name].command("dropDatabase") self[name].command("dropDatabase",
read_preference=ReadPreference.PRIMARY)
def copy_database(self, from_name, to_name, def copy_database(self, from_name, to_name,
from_host=None, username=None, password=None): from_host=None, username=None, password=None):
@ -1906,12 +1915,15 @@ class MongoReplicaSetClient(common.BaseObject):
if username is not None: if username is not None:
nonce = self.admin.command("copydbgetnonce", nonce = self.admin.command("copydbgetnonce",
fromhost=from_host)["nonce"] read_preference=ReadPreference.PRIMARY,
fromhost=from_host)["nonce"]
command["username"] = username command["username"] = username
command["nonce"] = nonce command["nonce"] = nonce
command["key"] = auth._auth_key(nonce, username, password) command["key"] = auth._auth_key(nonce, username, password)
return self.admin.command("copydb", **command) return self.admin.command("copydb",
read_preference=ReadPreference.PRIMARY,
**command)
finally: finally:
self.end_request() self.end_request()

View File

@ -33,7 +33,7 @@ from distutils.errors import CCompilerError
from distutils.errors import DistutilsPlatformError, DistutilsExecError from distutils.errors import DistutilsPlatformError, DistutilsExecError
from distutils.core import Extension from distutils.core import Extension
version = "2.7" version = "2.7.2"
f = open("README.rst") f = open("README.rst")
try: try:

View File

@ -600,6 +600,28 @@ class TestBulk(BulkTestBase):
self.assertEqual(1, self.coll.find({'x': 1}).count()) self.assertEqual(1, self.coll.find({'x': 1}).count())
def test_client_generated_upsert_id(self):
batch = self.coll.initialize_ordered_bulk_op()
batch.find({'_id': 0}).upsert().update_one({'$set': {'a': 0}})
batch.find({'a': 1}).upsert().replace_one({'_id': 1})
if not version.at_least(self.coll.database.connection, (2, 6, 0)):
# This case is only possible in MongoDB versions before 2.6.
batch.find({'_id': 3}).upsert().replace_one({'_id': 2})
else:
# This is just here to make the counts right in all cases.
batch.find({'_id': 2}).upsert().replace_one({'_id': 2})
result = batch.execute()
self.assertEqualResponse(
{'nMatched': 0,
'nModified': 0,
'nUpserted': 3,
'nInserted': 0,
'nRemoved': 0,
'upserted': [{'index': 0, '_id': 0},
{'index': 1, '_id': 1},
{'index': 2, '_id': 2}]},
result)
def test_single_ordered_batch(self): def test_single_ordered_batch(self):
batch = self.coll.initialize_ordered_bulk_op() batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1}) batch.insert({'a': 1})
@ -893,40 +915,71 @@ class TestBulkWriteConcern(BulkTestBase):
ismaster = client.test.command('ismaster') ismaster = client.test.command('ismaster')
self.is_repl = bool(ismaster.get('setName')) self.is_repl = bool(ismaster.get('setName'))
self.w = len(ismaster.get("hosts", [])) self.w = len(ismaster.get("hosts", []))
self.client = client
self.coll = client.pymongo_test.test self.coll = client.pymongo_test.test
self.coll.remove() self.coll.remove()
def test_fsync_and_j(self): def test_fsync_and_j(self):
if not version.at_least(self.client, (1, 8, 2)):
raise SkipTest("Need at least MongoDB 1.8.2")
batch = self.coll.initialize_ordered_bulk_op() batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1}) batch.insert({'a': 1})
self.assertRaises( self.assertRaises(
OperationFailure, OperationFailure,
batch.execute, {'fsync': True, 'j': True}) batch.execute, {'fsync': True, 'j': True})
def test_j_without_journal(self):
client = self.coll.database.connection
if not server_started_with_nojournal(client):
raise SkipTest("Need mongod started with --nojournal")
# Using j=True without journaling is a hard failure.
batch = self.coll.initialize_ordered_bulk_op()
batch.insert({})
self.assertRaises(OperationFailure, batch.execute, {'j': True})
def test_write_concern_failure_ordered(self): def test_write_concern_failure_ordered(self):
if not self.is_repl:
raise SkipTest("Need a replica set to test.")
# Ensure we don't raise on wnote.
batch = self.coll.initialize_ordered_bulk_op()
batch.find({"something": "that does not exist"}).remove()
self.assertTrue(batch.execute({"w": self.w}))
batch = self.coll.initialize_ordered_bulk_op() batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1}) batch.insert({'a': 1})
batch.insert({'a': 2}) batch.insert({'a': 2})
# Using w > 1 with no replication is a hard failure.
if not self.is_repl:
self.assertRaises(OperationFailure,
batch.execute, {'w': 5, 'wtimeout': 1})
# Replication wtimeout is a 'soft' error. # Replication wtimeout is a 'soft' error.
# It shouldn't stop batch processing. # It shouldn't stop batch processing.
try:
batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc:
result = exc.details
self.assertEqual(exc.code, 65)
else: else:
self.fail("Error not raised")
self.assertEqualResponse(
{'nMatched': 0,
'nModified': 0,
'nUpserted': 0,
'nInserted': 2,
'nRemoved': 0,
'upserted': [],
'writeErrors': []},
result)
# When talking to legacy servers there will be a
# write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 0)
failed = result['writeConcernErrors'][0]
self.assertEqual(64, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], basestring))
self.coll.remove()
self.coll.ensure_index('a', unique=True)
# Fail due to write concern support as well
# as duplicate key error on ordered batch.
try:
batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1})
batch.find({'a': 3}).upsert().replace_one({'b': 1})
batch.insert({'a': 1})
batch.insert({'a': 2})
try: try:
batch.execute({'w': self.w + 1, 'wtimeout': 1}) batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc: except BulkWriteError, exc:
@ -938,74 +991,66 @@ class TestBulkWriteConcern(BulkTestBase):
self.assertEqualResponse( self.assertEqualResponse(
{'nMatched': 0, {'nMatched': 0,
'nModified': 0, 'nModified': 0,
'nUpserted': 0, 'nUpserted': 1,
'nInserted': 2, 'nInserted': 1,
'nRemoved': 0, 'nRemoved': 0,
'upserted': [], 'upserted': [{'index': 1, '_id': '...'}],
'writeErrors': []}, 'writeErrors': [
{'index': 2,
'code': 11000,
'errmsg': '...',
'op': {'_id': '...', 'a': 1}}]},
result) result)
# When talking to legacy servers there will be a self.assertEqual(2, len(result['writeConcernErrors']))
# write concern error for each operation. failed = result['writeErrors'][0]
self.assertTrue(len(result['writeConcernErrors']) > 0) self.assertTrue("duplicate" in failed['errmsg'])
finally:
failed = result['writeConcernErrors'][0] self.coll.drop_index([('a', 1)])
self.assertEqual(64, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], basestring))
self.coll.remove()
self.coll.ensure_index('a', unique=True)
# Fail due to write concern support as well
# as duplicate key error on ordered batch.
try:
batch = self.coll.initialize_ordered_bulk_op()
batch.insert({'a': 1})
batch.find({'a': 3}).upsert().replace_one({'b': 1})
batch.insert({'a': 1})
batch.insert({'a': 2})
try:
batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc:
result = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
self.assertEqualResponse(
{'nMatched': 0,
'nModified': 0,
'nUpserted': 1,
'nInserted': 1,
'nRemoved': 0,
'upserted': [{'index': 1, '_id': '...'}],
'writeErrors': [
{'index': 2,
'code': 11000,
'errmsg': '...',
'op': {'_id': '...', 'a': 1}}]},
result)
self.assertEqual(2, len(result['writeConcernErrors']))
failed = result['writeErrors'][0]
self.assertTrue("duplicate" in failed['errmsg'])
finally:
self.coll.drop_index([('a', 1)])
def test_write_concern_failure_unordered(self): def test_write_concern_failure_unordered(self):
if not self.is_repl:
raise SkipTest("Need a replica set to test.")
# Ensure we don't raise on wnote.
batch = self.coll.initialize_unordered_bulk_op()
batch.find({"something": "that does not exist"}).remove()
self.assertTrue(batch.execute({"w": self.w}))
batch = self.coll.initialize_unordered_bulk_op() batch = self.coll.initialize_unordered_bulk_op()
batch.insert({'a': 1}) batch.insert({'a': 1})
batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3, 'b': 1}}) batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3, 'b': 1}})
batch.insert({'a': 2}) batch.insert({'a': 2})
# Using w > 1 with no replication is a hard failure.
if not self.is_repl:
self.assertRaises(OperationFailure,
batch.execute, {'w': 5, 'wtimeout': 1})
# Replication wtimeout is a 'soft' error. # Replication wtimeout is a 'soft' error.
# It shouldn't stop batch processing. # It shouldn't stop batch processing.
try:
batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc:
result = exc.details
self.assertEqual(exc.code, 65)
else: else:
self.fail("Error not raised")
self.assertEqual(2, result['nInserted'])
self.assertEqual(1, result['nUpserted'])
self.assertEqual(0, len(result['writeErrors']))
# When talking to legacy servers there will be a
# write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 1)
self.coll.remove()
self.coll.ensure_index('a', unique=True)
# Fail due to write concern support as well
# as duplicate key error on unordered batch.
try:
batch = self.coll.initialize_unordered_bulk_op()
batch.insert({'a': 1})
batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3,
'b': 1}})
batch.insert({'a': 1})
batch.insert({'a': 2})
try: try:
batch.execute({'w': self.w + 1, 'wtimeout': 1}) batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc: except BulkWriteError, exc:
@ -1016,54 +1061,27 @@ class TestBulkWriteConcern(BulkTestBase):
self.assertEqual(2, result['nInserted']) self.assertEqual(2, result['nInserted'])
self.assertEqual(1, result['nUpserted']) self.assertEqual(1, result['nUpserted'])
self.assertEqual(0, len(result['writeErrors'])) self.assertEqual(1, len(result['writeErrors']))
# When talking to legacy servers there will be a # When talking to legacy servers there will be a
# write concern error for each operation. # write concern error for each operation.
self.assertTrue(len(result['writeConcernErrors']) > 1) self.assertTrue(len(result['writeConcernErrors']) > 1)
self.coll.remove() failed = result['writeErrors'][0]
self.coll.ensure_index('a', unique=True) self.assertEqual(2, failed['index'])
self.assertEqual(11000, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], basestring))
self.assertEqual(1, failed['op']['a'])
# Fail due to write concern support as well failed = result['writeConcernErrors'][0]
# as duplicate key error on unordered batch. self.assertEqual(64, failed['code'])
try: self.assertTrue(isinstance(failed['errmsg'], basestring))
batch = self.coll.initialize_unordered_bulk_op()
batch.insert({'a': 1})
batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3,
'b': 1}})
batch.insert({'a': 1})
batch.insert({'a': 2})
try:
batch.execute({'w': self.w + 1, 'wtimeout': 1})
except BulkWriteError, exc:
result = exc.details
self.assertEqual(exc.code, 65)
else:
self.fail("Error not raised")
self.assertEqual(2, result['nInserted']) upserts = result['upserted']
self.assertEqual(1, result['nUpserted']) self.assertEqual(1, len(upserts))
self.assertEqual(1, len(result['writeErrors'])) self.assertEqual(1, upserts[0]['index'])
# When talking to legacy servers there will be a self.assertTrue(upserts[0].get('_id'))
# write concern error for each operation. finally:
self.assertTrue(len(result['writeConcernErrors']) > 1) self.coll.drop_index([('a', 1)])
failed = result['writeErrors'][0]
self.assertEqual(2, failed['index'])
self.assertEqual(11000, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], basestring))
self.assertEqual(1, failed['op']['a'])
failed = result['writeConcernErrors'][0]
self.assertEqual(64, failed['code'])
self.assertTrue(isinstance(failed['errmsg'], basestring))
upserts = result['upserted']
self.assertEqual(1, len(upserts))
self.assertEqual(1, upserts[0]['index'])
self.assertTrue(upserts[0].get('_id'))
finally:
self.coll.drop_index([('a', 1)])
class TestBulkNoResults(BulkTestBase): class TestBulkNoResults(BulkTestBase):

View File

@ -22,6 +22,7 @@ import sys
import time import time
import thread import thread
import unittest import unittest
import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -33,7 +34,7 @@ from bson.tz_util import utc
from pymongo.mongo_client import MongoClient from pymongo.mongo_client import MongoClient
from pymongo.database import Database from pymongo.database import Database
from pymongo.pool import SocketInfo from pymongo.pool import SocketInfo
from pymongo import thread_util, common from pymongo import auth, thread_util, common
from pymongo.errors import (AutoReconnect, from pymongo.errors import (AutoReconnect,
ConfigurationError, ConfigurationError,
ConnectionFailure, ConnectionFailure,
@ -43,6 +44,7 @@ from pymongo.errors import (AutoReconnect,
from test import version, host, port, pair from test import version, host, port, pair
from test.pymongo_mocks import MockClient from test.pymongo_mocks import MockClient
from test.utils import (assertRaisesExactly, from test.utils import (assertRaisesExactly,
catch_warnings,
delay, delay,
is_mongos, is_mongos,
remove_all_users, remove_all_users,
@ -50,9 +52,11 @@ from test.utils import (assertRaisesExactly,
server_started_with_auth, server_started_with_auth,
TestRequestMixin, TestRequestMixin,
_TestLazyConnectMixin, _TestLazyConnectMixin,
_TestExhaustCursorMixin,
lazy_client_trial, lazy_client_trial,
NTHREADS, NTHREADS,
get_pool) get_pool,
one)
def get_client(*args, **kwargs): def get_client(*args, **kwargs):
@ -224,6 +228,9 @@ class TestClient(unittest.TestCase, TestRequestMixin):
# from a master in a master-slave pair. # from a master in a master-slave pair.
if server_is_master_with_slave(c): if server_is_master_with_slave(c):
raise SkipTest("SERVER-2329") raise SkipTest("SERVER-2329")
if (not version.at_least(c, (2, 6, 0)) and
is_mongos(c) and server_started_with_auth(c)):
raise SkipTest("Need mongos >= 2.6.0 to test with authentication")
# We test copy twice; once starting in a request and once not. In # We test copy twice; once starting in a request and once not. In
# either case the copy should succeed (because it starts a request # either case the copy should succeed (because it starts a request
# internally) and should leave us in the same state as before the copy. # internally) and should leave us in the same state as before the copy.
@ -260,8 +267,7 @@ class TestClient(unittest.TestCase, TestRequestMixin):
self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"]) self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"])
# See SERVER-6427 for mongos # See SERVER-6427 for mongos
if (version.at_least(c, (1, 3, 3, 1)) and if not is_mongos(c) and server_started_with_auth(c):
not is_mongos(c) and server_started_with_auth(c)):
c.drop_database("pymongo_test1") c.drop_database("pymongo_test1")
@ -315,11 +321,16 @@ class TestClient(unittest.TestCase, TestRequestMixin):
def test_from_uri(self): def test_from_uri(self):
c = MongoClient(host, port) c = MongoClient(host, port)
self.assertEqual(c, MongoClient("mongodb://%s:%d" % (host, port))) ctx = catch_warnings()
self.assertTrue(MongoClient( try:
"mongodb://%s:%d" % (host, port), slave_okay=True).slave_okay) warnings.simplefilter("ignore", DeprecationWarning)
self.assertTrue(MongoClient( self.assertEqual(c, MongoClient("mongodb://%s:%d" % (host, port)))
"mongodb://%s:%d/?slaveok=true;w=2" % (host, port)).slave_okay) self.assertTrue(MongoClient(
"mongodb://%s:%d" % (host, port), slave_okay=True).slave_okay)
self.assertTrue(MongoClient(
"mongodb://%s:%d/?slaveok=true;w=2" % (host, port)).slave_okay)
finally:
ctx.exit()
def test_get_default_database(self): def test_get_default_database(self):
c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False) c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False)
@ -990,6 +1001,42 @@ with client.start_request() as request:
client = get_client(_connect=False) client = get_client(_connect=False)
client.pymongo_test.test.remove(w=0) client.pymongo_test.test.remove(w=0)
def test_auth_network_error(self):
# Make sure there's no semaphore leak if we get a network error
# when authenticating a new socket with cached credentials.
auth_client = get_client()
if not server_started_with_auth(auth_client):
raise SkipTest('Authentication is not enabled on server')
auth_client.admin.add_user('admin', 'password')
auth_client.admin.authenticate('admin', 'password')
try:
# Get a client with one socket so we detect if it's leaked.
c = get_client(max_pool_size=1, waitQueueTimeoutMS=1)
# Simulate an authenticate() call on a different socket.
credentials = auth._build_credentials_tuple(
'MONGODB-CR', 'admin',
unicode('admin'), unicode('password'),
{})
c._cache_credentials('test', credentials, connect=False)
# Cause a network error on the actual socket.
pool = get_pool(c)
socket_info = one(pool.sockets)
socket_info.sock.close()
# In __check_auth, the client authenticates its socket with the
# new credential, but gets a socket.error. Should be reraised as
# AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one)
# No semaphore leak, the pool is allowed to make a new socket.
c.test.collection.find_one()
finally:
remove_all_users(auth_client.admin)
class TestClientLazyConnect(unittest.TestCase, _TestLazyConnectMixin): class TestClientLazyConnect(unittest.TestCase, _TestLazyConnectMixin):
def _get_client(self, **kwargs): def _get_client(self, **kwargs):
@ -1101,5 +1148,10 @@ class TestMongoClientFailover(unittest.TestCase):
c.db.collection.find_one() c.db.collection.find_one()
class TestExhaustCursor(_TestExhaustCursorMixin, unittest.TestCase):
def _get_client(self, **kwargs):
return get_client(**kwargs)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -36,7 +36,7 @@ from bson.objectid import ObjectId
from bson.py3compat import b from bson.py3compat import b
from bson.son import SON, RE_TYPE from bson.son import SON, RE_TYPE
from pymongo import (ASCENDING, DESCENDING, GEO2D, from pymongo import (ASCENDING, DESCENDING, GEO2D,
GEOHAYSTACK, GEOSPHERE, HASHED) GEOHAYSTACK, GEOSPHERE, HASHED, TEXT)
from pymongo import message as message_module from pymongo import message as message_module
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
@ -51,8 +51,8 @@ from pymongo.errors import (DocumentTooLarge,
OperationFailure, OperationFailure,
WTimeoutError) WTimeoutError)
from test.test_client import get_client from test.test_client import get_client
from test.utils import (is_mongos, joinall, enable_text_search, get_pool, from test.utils import (catch_warnings, enable_text_search,
oid_generated_on_client) get_pool, is_mongos, joinall, oid_generated_on_client)
from test import (qcheck, from test import (qcheck,
version) version)
@ -226,19 +226,21 @@ class TestCollection(unittest.TestCase):
def test_deprecated_ttl_index_kwarg(self): def test_deprecated_ttl_index_kwarg(self):
db = self.db db = self.db
# In Python 2.6+ we could use the catch_warnings context ctx = catch_warnings()
# manager to test this warning nicely. As we can't do that
# we must test raising errors before the ignore filter is applied.
warnings.simplefilter("error", DeprecationWarning)
try: try:
warnings.simplefilter("error", DeprecationWarning)
self.assertRaises(DeprecationWarning, lambda: self.assertRaises(DeprecationWarning, lambda:
db.test.ensure_index("goodbye", ttl=10)) db.test.ensure_index("goodbye", ttl=10))
finally: finally:
warnings.resetwarnings() ctx.exit()
warnings.simplefilter("ignore")
self.assertEqual("goodbye_1", ctx = catch_warnings()
db.test.ensure_index("goodbye", ttl=10)) try:
warnings.simplefilter("ignore", DeprecationWarning)
self.assertEqual("goodbye_1",
db.test.ensure_index("goodbye", ttl=10))
finally:
ctx.exit()
self.assertEqual(None, db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye"))
def test_ensure_unique_index_threaded(self): def test_ensure_unique_index_threaded(self):
@ -417,7 +419,7 @@ class TestCollection(unittest.TestCase):
db = self.db db = self.db
db.test.drop_indexes() db.test.drop_indexes()
self.assertEqual("t_text", db.test.create_index([("t", "text")])) self.assertEqual("t_text", db.test.create_index([("t", TEXT)]))
index_info = db.test.index_information()["t_text"] index_info = db.test.index_information()["t_text"]
self.assertTrue("weights" in index_info) self.assertTrue("weights" in index_info)
@ -601,13 +603,20 @@ class TestCollection(unittest.TestCase):
db = self.db db = self.db
db.drop_collection("test") db.drop_collection("test")
db.test.save({}) db.test.save({})
self.assertEqual(db.test.options(), {}) expected = {}
if version.at_least(db.connection, (2, 7, 0)):
# usePowerOf2Sizes server default
expected["flags"] = 1
self.assertEqual(db.test.options(), expected)
self.assertEqual(db.test.doesnotexist.options(), {}) self.assertEqual(db.test.doesnotexist.options(), {})
db.drop_collection("test") db.drop_collection("test")
if version.at_least(db.connection, (1, 9)): if version.at_least(db.connection, (1, 9)):
db.create_collection("test", capped=True, size=4096) db.create_collection("test", capped=True, size=4096)
self.assertEqual(db.test.options(), {"capped": True, 'size': 4096}) result = db.test.options()
# mongos 2.2.x adds an $auth field when auth is enabled.
result.pop('$auth', None)
self.assertEqual(result, {"capped": True, 'size': 4096})
else: else:
db.create_collection("test", capped=True) db.create_collection("test", capped=True)
self.assertEqual(db.test.options(), {"capped": True}) self.assertEqual(db.test.options(), {"capped": True})
@ -853,10 +862,15 @@ class TestCollection(unittest.TestCase):
) )
# Misconfigured value for safe # Misconfigured value for safe
self.assertRaises( ctx = catch_warnings()
TypeError, try:
lambda: db.test.insert([{'i': 2}] * 2, safe=1), warnings.simplefilter("ignore", DeprecationWarning)
) self.assertRaises(
TypeError,
lambda: db.test.insert([{'i': 2}] * 2, safe=1),
)
finally:
ctx.exit()
def test_insert_iterables(self): def test_insert_iterables(self):
db = self.db db = self.db
@ -979,10 +993,16 @@ class TestCollection(unittest.TestCase):
db.test.insert({"_id": 2, "x": 2}) db.test.insert({"_id": 2, "x": 2})
# No error # No error
db.test.insert({"_id": 1, "x": 1}, safe=False) ctx = catch_warnings()
db.test.save({"_id": 1, "x": 1}, safe=False) try:
db.test.insert({"_id": 2, "x": 2}, safe=False) warnings.simplefilter("ignore", DeprecationWarning)
db.test.save({"_id": 2, "x": 2}, safe=False) db.test.insert({"_id": 1, "x": 1}, safe=False)
db.test.save({"_id": 1, "x": 1}, safe=False)
db.test.insert({"_id": 2, "x": 2}, safe=False)
db.test.save({"_id": 2, "x": 2}, safe=False)
finally:
ctx.exit()
db.test.insert({"_id": 1, "x": 1}, w=0) db.test.insert({"_id": 1, "x": 1}, w=0)
db.test.save({"_id": 1, "x": 1}, w=0) db.test.save({"_id": 1, "x": 1}, w=0)
db.test.insert({"_id": 2, "x": 2}, w=0) db.test.insert({"_id": 2, "x": 2}, w=0)
@ -1330,8 +1350,9 @@ class TestCollection(unittest.TestCase):
self.fail("WTimeoutError was not raised") self.fail("WTimeoutError was not raised")
# can't use fsync and j options together # can't use fsync and j options together
self.assertRaises(OperationFailure, self.db.test.insert, if version.at_least(self.client, (1, 8, 2)):
{"_id": 1}, j=True, fsync=True) self.assertRaises(OperationFailure, self.db.test.insert,
{"_id": 1}, j=True, fsync=True)
def test_manual_last_error(self): def test_manual_last_error(self):
self.db.test.save({"x": 1}, w=0) self.db.test.save({"x": 1}, w=0)
@ -2064,8 +2085,17 @@ class TestCollection(unittest.TestCase):
# (Shame on me) # (Shame on me)
def test_bad_encode(self): def test_bad_encode(self):
c = self.db.test c = self.db.test
c.drop()
self.assertRaises(InvalidDocument, c.save, {"x": c}) self.assertRaises(InvalidDocument, c.save, {"x": c})
class BadGetAttr(dict):
def __getattr__(self, name):
pass
bad = BadGetAttr([('foo', 'bar')])
c.insert({'bad': bad})
self.assertEqual('bar', c.find_one()['bad']['foo'])
def test_bad_dbref(self): def test_bad_dbref(self):
c = self.db.test c = self.db.test
c.drop() c.drop()
@ -2194,53 +2224,71 @@ class TestCollection(unittest.TestCase):
as_class=ExtendedDict) as_class=ExtendedDict)
self.assertTrue(isinstance(result, ExtendedDict)) self.assertTrue(isinstance(result, ExtendedDict))
def test_update_backward_compat(self):
# MongoDB versions >= 2.6.0 don't return the updatedExisting field
# and return upsert _id in an array subdocument. This test should
# pass regardless of server version or type (mongod/s).
c = self.db.test
c.drop()
oid = ObjectId()
res = c.update({'_id': oid}, {'$set': {'a': 'a'}}, upsert=True)
self.assertFalse(res.get('updatedExisting'))
self.assertEqual(oid, res.get('upserted'))
res = c.update({'_id': oid}, {'$set': {'b': 'b'}})
self.assertTrue(res.get('updatedExisting'))
def test_find_and_modify_with_sort(self): def test_find_and_modify_with_sort(self):
c = self.db.test c = self.db.test
c.drop() c.drop()
for j in xrange(5): for j in xrange(5):
c.insert({'j': j, 'i': 0}) c.insert({'j': j, 'i': 0})
sort={'j': DESCENDING} ctx = catch_warnings()
self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort={'j': ASCENDING}
self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=[('j', DESCENDING)]
self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=[('j', ASCENDING)]
self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=SON([('j', DESCENDING)])
self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=SON([('j', ASCENDING)])
self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
try: try:
from collections import OrderedDict warnings.simplefilter("ignore", DeprecationWarning)
sort=OrderedDict([('j', DESCENDING)]) sort={'j': DESCENDING}
self.assertEqual(4, c.find_and_modify({}, self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}}, {'$inc': {'i': 1}},
sort=sort)['j']) sort=sort)['j'])
sort=OrderedDict([('j', ASCENDING)]) sort={'j': ASCENDING}
self.assertEqual(0, c.find_and_modify({}, self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}}, {'$inc': {'i': 1}},
sort=sort)['j']) sort=sort)['j'])
except ImportError: sort=[('j', DESCENDING)]
pass self.assertEqual(4, c.find_and_modify({},
# Test that a standard dict with two keys is rejected. {'$inc': {'i': 1}},
sort={'j': DESCENDING, 'foo': DESCENDING} sort=sort)['j'])
self.assertRaises(TypeError, c.find_and_modify, {}, sort=[('j', ASCENDING)]
{'$inc': {'i': 1}}, self.assertEqual(0, c.find_and_modify({},
sort=sort) {'$inc': {'i': 1}},
sort=sort)['j'])
sort=SON([('j', DESCENDING)])
self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=SON([('j', ASCENDING)])
self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
try:
from collections import OrderedDict
sort=OrderedDict([('j', DESCENDING)])
self.assertEqual(4, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
sort=OrderedDict([('j', ASCENDING)])
self.assertEqual(0, c.find_and_modify({},
{'$inc': {'i': 1}},
sort=sort)['j'])
except ImportError:
pass
# Test that a standard dict with two keys is rejected.
sort={'j': DESCENDING, 'foo': DESCENDING}
self.assertRaises(TypeError, c.find_and_modify,
{}, {'$inc': {'i': 1}}, sort=sort)
finally:
ctx.exit()
def test_find_with_nested(self): def test_find_with_nested(self):
if not version.at_least(self.db.connection, (2, 0, 0)): if not version.at_least(self.db.connection, (2, 0, 0)):

View File

@ -31,7 +31,7 @@ from pymongo.mongo_client import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.errors import ConfigurationError, OperationFailure from pymongo.errors import ConfigurationError, OperationFailure
from test import host, port, pair, version from test import host, port, pair, version
from test.utils import drop_collections from test.utils import catch_warnings, drop_collections
have_uuid = True have_uuid = True
try: try:
@ -44,11 +44,9 @@ class TestCommon(unittest.TestCase):
def test_baseobject(self): def test_baseobject(self):
# In Python 2.6+ we could use the catch_warnings context ctx = catch_warnings()
# manager to test this warning nicely. As we can't do that
# we must test raising errors before the ignore filter is applied.
warnings.simplefilter("error", UserWarning)
try: try:
warnings.simplefilter("error", UserWarning)
self.assertRaises(UserWarning, lambda: self.assertRaises(UserWarning, lambda:
MongoClient(host, port, wtimeout=1000, w=0)) MongoClient(host, port, wtimeout=1000, w=0))
try: try:
@ -61,191 +59,195 @@ class TestCommon(unittest.TestCase):
except UserWarning: except UserWarning:
self.fail() self.fail()
finally: finally:
warnings.resetwarnings() ctx.exit()
warnings.simplefilter("ignore")
# Connection tests # Connection tests
c = Connection(pair) ctx = catch_warnings()
self.assertFalse(c.slave_okay) try:
self.assertFalse(c.safe) warnings.simplefilter("ignore", DeprecationWarning)
self.assertEqual({}, c.get_lasterror_options()) c = Connection(pair)
db = c.pymongo_test self.assertFalse(c.slave_okay)
db.drop_collection("test") self.assertFalse(c.safe)
self.assertFalse(db.slave_okay) self.assertEqual({}, c.get_lasterror_options())
self.assertFalse(db.safe) db = c.pymongo_test
self.assertEqual({}, db.get_lasterror_options()) db.drop_collection("test")
coll = db.test self.assertFalse(db.slave_okay)
self.assertFalse(coll.slave_okay) self.assertFalse(db.safe)
self.assertFalse(coll.safe) self.assertEqual({}, db.get_lasterror_options())
self.assertEqual({}, coll.get_lasterror_options()) coll = db.test
self.assertFalse(coll.slave_okay)
self.assertFalse(coll.safe)
self.assertEqual({}, coll.get_lasterror_options())
self.assertEqual((False, {}), coll._get_write_mode()) self.assertEqual((False, {}), coll._get_write_mode())
coll.safe = False coll.safe = False
coll.write_concern.update(w=1) coll.write_concern.update(w=1)
self.assertEqual((True, {}), coll._get_write_mode()) self.assertEqual((True, {}), coll._get_write_mode())
coll.write_concern.update(w=3) coll.write_concern.update(w=3)
self.assertEqual((True, {'w': 3}), coll._get_write_mode()) self.assertEqual((True, {'w': 3}), coll._get_write_mode())
coll.safe = True coll.safe = True
coll.write_concern.update(w=0) coll.write_concern.update(w=0)
self.assertEqual((False, {}), coll._get_write_mode()) self.assertEqual((False, {}), coll._get_write_mode())
coll = db.test coll = db.test
cursor = coll.find() cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
cursor = coll.find(slave_okay=True) cursor = coll.find(slave_okay=True)
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
# MongoClient test # MongoClient test
c = MongoClient(pair) c = MongoClient(pair)
self.assertFalse(c.slave_okay) self.assertFalse(c.slave_okay)
self.assertTrue(c.safe) self.assertTrue(c.safe)
self.assertEqual({}, c.get_lasterror_options()) self.assertEqual({}, c.get_lasterror_options())
db = c.pymongo_test db = c.pymongo_test
db.drop_collection("test") db.drop_collection("test")
self.assertFalse(db.slave_okay) self.assertFalse(db.slave_okay)
self.assertTrue(db.safe) self.assertTrue(db.safe)
self.assertEqual({}, db.get_lasterror_options()) self.assertEqual({}, db.get_lasterror_options())
coll = db.test coll = db.test
self.assertFalse(coll.slave_okay) self.assertFalse(coll.slave_okay)
self.assertTrue(coll.safe) self.assertTrue(coll.safe)
self.assertEqual({}, coll.get_lasterror_options()) self.assertEqual({}, coll.get_lasterror_options())
self.assertEqual((True, {}), coll._get_write_mode()) self.assertEqual((True, {}), coll._get_write_mode())
coll.safe = False coll.safe = False
coll.write_concern.update(w=1) coll.write_concern.update(w=1)
self.assertEqual((True, {}), coll._get_write_mode()) self.assertEqual((True, {}), coll._get_write_mode())
coll.write_concern.update(w=3) coll.write_concern.update(w=3)
self.assertEqual((True, {'w': 3}), coll._get_write_mode()) self.assertEqual((True, {'w': 3}), coll._get_write_mode())
coll.safe = True coll.safe = True
coll.write_concern.update(w=0) coll.write_concern.update(w=0)
self.assertEqual((False, {}), coll._get_write_mode()) self.assertEqual((False, {}), coll._get_write_mode())
coll = db.test coll = db.test
cursor = coll.find() cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
cursor = coll.find(slave_okay=True) cursor = coll.find(slave_okay=True)
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
# Setting any safe operations overrides explicit safe # Setting any safe operations overrides explicit safe
self.assertTrue(MongoClient(host, port, wtimeout=1000, safe=False).safe) self.assertTrue(MongoClient(host, port, wtimeout=1000, safe=False).safe)
c = MongoClient(pair, slaveok=True, w='majority', c = MongoClient(pair, slaveok=True, w='majority',
wtimeout=300, fsync=True, j=True) wtimeout=300, fsync=True, j=True)
self.assertTrue(c.slave_okay) self.assertTrue(c.slave_okay)
self.assertTrue(c.safe) self.assertTrue(c.safe)
d = {'w': 'majority', 'wtimeout': 300, 'fsync': True, 'j': True} d = {'w': 'majority', 'wtimeout': 300, 'fsync': True, 'j': True}
self.assertEqual(d, c.get_lasterror_options()) self.assertEqual(d, c.get_lasterror_options())
db = c.pymongo_test db = c.pymongo_test
self.assertTrue(db.slave_okay) self.assertTrue(db.slave_okay)
self.assertTrue(db.safe) self.assertTrue(db.safe)
self.assertEqual(d, db.get_lasterror_options()) self.assertEqual(d, db.get_lasterror_options())
coll = db.test coll = db.test
self.assertTrue(coll.slave_okay) self.assertTrue(coll.slave_okay)
self.assertTrue(coll.safe) self.assertTrue(coll.safe)
self.assertEqual(d, coll.get_lasterror_options()) self.assertEqual(d, coll.get_lasterror_options())
cursor = coll.find() cursor = coll.find()
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
cursor = coll.find(slave_okay=False) cursor = coll.find(slave_okay=False)
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
c = MongoClient('mongodb://%s/?' c = MongoClient('mongodb://%s/?'
'w=2;wtimeoutMS=300;fsync=true;' 'w=2;wtimeoutMS=300;fsync=true;'
'journal=true' % (pair,)) 'journal=true' % (pair,))
self.assertTrue(c.safe) self.assertTrue(c.safe)
d = {'w': 2, 'wtimeout': 300, 'fsync': True, 'j': True} d = {'w': 2, 'wtimeout': 300, 'fsync': True, 'j': True}
self.assertEqual(d, c.get_lasterror_options()) self.assertEqual(d, c.get_lasterror_options())
c = MongoClient('mongodb://%s/?' c = MongoClient('mongodb://%s/?'
'slaveok=true;w=1;wtimeout=300;' 'slaveok=true;w=1;wtimeout=300;'
'fsync=true;j=true' % (pair,)) 'fsync=true;j=true' % (pair,))
self.assertTrue(c.slave_okay) self.assertTrue(c.slave_okay)
self.assertTrue(c.safe) self.assertTrue(c.safe)
d = {'w': 1, 'wtimeout': 300, 'fsync': True, 'j': True} d = {'w': 1, 'wtimeout': 300, 'fsync': True, 'j': True}
self.assertEqual(d, c.get_lasterror_options()) self.assertEqual(d, c.get_lasterror_options())
self.assertEqual(d, c.write_concern) self.assertEqual(d, c.write_concern)
db = c.pymongo_test db = c.pymongo_test
self.assertTrue(db.slave_okay) self.assertTrue(db.slave_okay)
self.assertTrue(db.safe) self.assertTrue(db.safe)
self.assertEqual(d, db.get_lasterror_options()) self.assertEqual(d, db.get_lasterror_options())
self.assertEqual(d, db.write_concern) self.assertEqual(d, db.write_concern)
coll = db.test coll = db.test
self.assertTrue(coll.slave_okay) self.assertTrue(coll.slave_okay)
self.assertTrue(coll.safe) self.assertTrue(coll.safe)
self.assertEqual(d, coll.get_lasterror_options()) self.assertEqual(d, coll.get_lasterror_options())
self.assertEqual(d, coll.write_concern) self.assertEqual(d, coll.write_concern)
cursor = coll.find() cursor = coll.find()
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
cursor = coll.find(slave_okay=False) cursor = coll.find(slave_okay=False)
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
c.unset_lasterror_options() c.unset_lasterror_options()
self.assertTrue(c.slave_okay) self.assertTrue(c.slave_okay)
self.assertTrue(c.safe) self.assertTrue(c.safe)
c.safe = False c.safe = False
self.assertFalse(c.safe) self.assertFalse(c.safe)
c.slave_okay = False c.slave_okay = False
self.assertFalse(c.slave_okay) self.assertFalse(c.slave_okay)
self.assertEqual({}, c.get_lasterror_options()) self.assertEqual({}, c.get_lasterror_options())
self.assertEqual({}, c.write_concern) self.assertEqual({}, c.write_concern)
db = c.pymongo_test db = c.pymongo_test
self.assertFalse(db.slave_okay) self.assertFalse(db.slave_okay)
self.assertFalse(db.safe) self.assertFalse(db.safe)
self.assertEqual({}, db.get_lasterror_options()) self.assertEqual({}, db.get_lasterror_options())
self.assertEqual({}, db.write_concern) self.assertEqual({}, db.write_concern)
coll = db.test coll = db.test
self.assertFalse(coll.slave_okay) self.assertFalse(coll.slave_okay)
self.assertFalse(coll.safe) self.assertFalse(coll.safe)
self.assertEqual({}, coll.get_lasterror_options()) self.assertEqual({}, coll.get_lasterror_options())
self.assertEqual({}, coll.write_concern) self.assertEqual({}, coll.write_concern)
cursor = coll.find() cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
cursor = coll.find(slave_okay=True) cursor = coll.find(slave_okay=True)
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
coll.set_lasterror_options(fsync=True) coll.set_lasterror_options(fsync=True)
self.assertEqual({'fsync': True}, coll.get_lasterror_options()) self.assertEqual({'fsync': True}, coll.get_lasterror_options())
self.assertEqual({'fsync': True}, coll.write_concern) self.assertEqual({'fsync': True}, coll.write_concern)
self.assertEqual({}, db.get_lasterror_options()) self.assertEqual({}, db.get_lasterror_options())
self.assertEqual({}, db.write_concern) self.assertEqual({}, db.write_concern)
self.assertFalse(db.safe) self.assertFalse(db.safe)
self.assertEqual({}, c.get_lasterror_options()) self.assertEqual({}, c.get_lasterror_options())
self.assertEqual({}, c.write_concern) self.assertEqual({}, c.write_concern)
self.assertFalse(c.safe) self.assertFalse(c.safe)
db.set_lasterror_options(w='majority') db.set_lasterror_options(w='majority')
self.assertEqual({'fsync': True}, coll.get_lasterror_options()) self.assertEqual({'fsync': True}, coll.get_lasterror_options())
self.assertEqual({'fsync': True}, coll.write_concern) self.assertEqual({'fsync': True}, coll.write_concern)
self.assertEqual({'w': 'majority'}, db.get_lasterror_options()) self.assertEqual({'w': 'majority'}, db.get_lasterror_options())
self.assertEqual({'w': 'majority'}, db.write_concern) self.assertEqual({'w': 'majority'}, db.write_concern)
self.assertEqual({}, c.get_lasterror_options()) self.assertEqual({}, c.get_lasterror_options())
self.assertEqual({}, c.write_concern) self.assertEqual({}, c.write_concern)
self.assertFalse(c.safe) self.assertFalse(c.safe)
db.slave_okay = True db.slave_okay = True
self.assertTrue(db.slave_okay) self.assertTrue(db.slave_okay)
self.assertFalse(c.slave_okay) self.assertFalse(c.slave_okay)
self.assertFalse(coll.slave_okay) self.assertFalse(coll.slave_okay)
cursor = coll.find() cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
cursor = db.coll2.find() cursor = db.coll2.find()
self.assertTrue(cursor._Cursor__slave_okay) self.assertTrue(cursor._Cursor__slave_okay)
cursor = db.coll2.find(slave_okay=False) cursor = db.coll2.find(slave_okay=False)
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
self.assertRaises(ConfigurationError, coll.set_lasterror_options, foo=20) self.assertRaises(ConfigurationError, coll.set_lasterror_options, foo=20)
self.assertRaises(TypeError, coll._BaseObject__set_slave_okay, 20) self.assertRaises(TypeError, coll._BaseObject__set_slave_okay, 20)
self.assertRaises(TypeError, coll._BaseObject__set_safe, 20) self.assertRaises(TypeError, coll._BaseObject__set_safe, 20)
coll.remove() coll.remove()
self.assertEqual(None, coll.find_one(slave_okay=True)) self.assertEqual(None, coll.find_one(slave_okay=True))
coll.unset_lasterror_options() coll.unset_lasterror_options()
coll.set_lasterror_options(w=4, wtimeout=10) coll.set_lasterror_options(w=4, wtimeout=10)
# Fails if we don't have 4 active nodes or we don't have replication... # Fails if we don't have 4 active nodes or we don't have replication...
self.assertRaises(OperationFailure, coll.insert, {'foo': 'bar'}) self.assertRaises(OperationFailure, coll.insert, {'foo': 'bar'})
# Succeeds since we override the lasterror settings per query. # Succeeds since we override the lasterror settings per query.
self.assertTrue(coll.insert({'foo': 'bar'}, fsync=True)) self.assertTrue(coll.insert({'foo': 'bar'}, fsync=True))
drop_collections(db) drop_collections(db)
finally:
ctx.exit()
def test_uuid_subtype(self): def test_uuid_subtype(self):
if not have_uuid: if not have_uuid:
@ -439,30 +441,36 @@ class TestCommon(unittest.TestCase):
m = MongoClient(pair, w=0) m = MongoClient(pair, w=0)
coll = m.pymongo_test.write_concern_test coll = m.pymongo_test.write_concern_test
coll.drop() coll.drop()
doc = {"_id": ObjectId()}
coll.insert(doc)
self.assertTrue(coll.insert(doc, safe=False))
self.assertTrue(coll.insert(doc, w=0))
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoClient(pair) ctx = catch_warnings()
coll = m.pymongo_test.write_concern_test try:
self.assertTrue(coll.insert(doc, safe=False)) warnings.simplefilter("ignore", DeprecationWarning)
self.assertTrue(coll.insert(doc, w=0)) doc = {"_id": ObjectId()}
self.assertRaises(OperationFailure, coll.insert, doc) coll.insert(doc)
self.assertRaises(OperationFailure, coll.insert, doc, safe=True) self.assertTrue(coll.insert(doc, safe=False))
self.assertRaises(OperationFailure, coll.insert, doc, w=1) self.assertTrue(coll.insert(doc, w=0))
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoClient("mongodb://%s/" % (pair,)) m = MongoClient(pair)
self.assertTrue(m.safe) coll = m.pymongo_test.write_concern_test
coll = m.pymongo_test.write_concern_test self.assertTrue(coll.insert(doc, safe=False))
self.assertRaises(OperationFailure, coll.insert, doc) self.assertTrue(coll.insert(doc, w=0))
m = MongoClient("mongodb://%s/?w=0" % (pair,)) self.assertRaises(OperationFailure, coll.insert, doc)
self.assertFalse(m.safe) self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert, doc, w=1)
self.assertTrue(coll.insert(doc))
m = MongoClient("mongodb://%s/" % (pair,))
self.assertTrue(m.safe)
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert, doc)
m = MongoClient("mongodb://%s/?w=0" % (pair,))
self.assertFalse(m.safe)
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc))
finally:
ctx.exit()
# Equality tests # Equality tests
self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,))) self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,)))
@ -478,30 +486,36 @@ class TestCommon(unittest.TestCase):
m = MongoReplicaSetClient(pair, replicaSet=setname, w=0) m = MongoReplicaSetClient(pair, replicaSet=setname, w=0)
coll = m.pymongo_test.write_concern_test coll = m.pymongo_test.write_concern_test
coll.drop() coll.drop()
doc = {"_id": ObjectId()}
coll.insert(doc)
self.assertTrue(coll.insert(doc, safe=False))
self.assertTrue(coll.insert(doc, w=0))
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoReplicaSetClient(pair, replicaSet=setname) ctx = catch_warnings()
coll = m.pymongo_test.write_concern_test try:
self.assertTrue(coll.insert(doc, safe=False)) warnings.simplefilter("ignore", DeprecationWarning)
self.assertTrue(coll.insert(doc, w=0)) doc = {"_id": ObjectId()}
self.assertRaises(OperationFailure, coll.insert, doc) coll.insert(doc)
self.assertRaises(OperationFailure, coll.insert, doc, safe=True) self.assertTrue(coll.insert(doc, safe=False))
self.assertRaises(OperationFailure, coll.insert, doc, w=1) self.assertTrue(coll.insert(doc, w=0))
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s" % (pair, setname)) m = MongoReplicaSetClient(pair, replicaSet=setname)
self.assertTrue(m.safe) coll = m.pymongo_test.write_concern_test
coll = m.pymongo_test.write_concern_test self.assertTrue(coll.insert(doc, safe=False))
self.assertRaises(OperationFailure, coll.insert, doc) self.assertTrue(coll.insert(doc, w=0))
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname)) self.assertRaises(OperationFailure, coll.insert, doc)
self.assertFalse(m.safe) self.assertRaises(OperationFailure, coll.insert, doc, safe=True)
coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert, doc, w=1)
self.assertTrue(coll.insert(doc))
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s" % (pair, setname))
self.assertTrue(m.safe)
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert, doc)
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname))
self.assertFalse(m.safe)
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc))
finally:
ctx.exit()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -19,6 +19,8 @@ import random
import re import re
import sys import sys
import unittest import unittest
import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -37,7 +39,8 @@ from pymongo.errors import (InvalidOperation,
ExecutionTimeout) ExecutionTimeout)
from test import version from test import version
from test.test_client import get_client from test.test_client import get_client
from test.utils import is_mongos, get_command_line, server_started_with_auth from test.utils import (catch_warnings, is_mongos,
get_command_line, server_started_with_auth)
class TestCursor(unittest.TestCase): class TestCursor(unittest.TestCase):
@ -1100,8 +1103,11 @@ self.assertFalse(c2.alive)
pass pass
client = self.db.connection client = self.db.connection
ctx = catch_warnings()
try: try:
warnings.simplefilter("ignore", DeprecationWarning)
client.set_cursor_manager(CManager) client.set_cursor_manager(CManager)
docs = [] docs = []
cursor = self.db.test.find().batch_size(10) cursor = self.db.test.find().batch_size(10)
docs.append(cursor.next()) docs.append(cursor.next())
@ -1115,6 +1121,7 @@ self.assertFalse(c2.alive)
self.assertEqual(len(docs), 200) self.assertEqual(len(docs), 200)
finally: finally:
client.set_cursor_manager(CursorManager) client.set_cursor_manager(CursorManager)
ctx.exit()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -45,10 +45,11 @@ from pymongo.errors import (CollectionInvalid,
OperationFailure) OperationFailure)
from pymongo.son_manipulator import (AutoReference, from pymongo.son_manipulator import (AutoReference,
NamespaceInjector, NamespaceInjector,
SONManipulator,
ObjectIdShuffler) ObjectIdShuffler)
from test import version from test import version
from test.utils import (get_command_line, is_mongos, from test.utils import (catch_warnings, get_command_line,
remove_all_users, server_started_with_auth) is_mongos, remove_all_users, server_started_with_auth)
from test.test_client import get_client from test.test_client import get_client
@ -110,7 +111,14 @@ class TestDatabase(unittest.TestCase):
db.drop_collection("test.foo") db.drop_collection("test.foo")
db.create_collection("test.foo") db.create_collection("test.foo")
self.assertTrue(u"test.foo" in db.collection_names()) self.assertTrue(u"test.foo" in db.collection_names())
self.assertEqual(db.test.foo.options(), {}) expected = {}
if version.at_least(self.client, (2, 7, 0)):
# usePowerOf2Sizes server default
expected["flags"] = 1
result = db.test.foo.options()
# mongos 2.2.x adds an $auth field when auth is enabled.
result.pop('$auth', None)
self.assertEqual(result, expected)
self.assertRaises(CollectionInvalid, db.create_collection, "test.foo") self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
def test_collection_names(self): def test_collection_names(self):
@ -368,15 +376,15 @@ class TestDatabase(unittest.TestCase):
"user", 'password', True, roles=['read']) "user", 'password', True, roles=['read'])
if version.at_least(self.client, (2, 5, 3, -1)): if version.at_least(self.client, (2, 5, 3, -1)):
warnings.simplefilter("error", DeprecationWarning) ctx = catch_warnings()
try: try:
warnings.simplefilter("error", DeprecationWarning)
self.assertRaises(DeprecationWarning, db.add_user, self.assertRaises(DeprecationWarning, db.add_user,
"user", "password") "user", "password")
self.assertRaises(DeprecationWarning, db.add_user, self.assertRaises(DeprecationWarning, db.add_user,
"user", "password", True) "user", "password", True)
finally: finally:
warnings.resetwarnings() ctx.exit()
warnings.simplefilter("ignore")
self.assertRaises(ConfigurationError, db.add_user, self.assertRaises(ConfigurationError, db.add_user,
"user", "password", digestPassword=True) "user", "password", digestPassword=True)
@ -601,8 +609,8 @@ class TestDatabase(unittest.TestCase):
def test_authenticate_multiple(self): def test_authenticate_multiple(self):
client = get_client() client = get_client()
if (is_mongos(client) and not if (is_mongos(client) and not
version.at_least(self.client, (2, 0, 0))): version.at_least(self.client, (2, 2, 0))):
raise SkipTest("Auth with sharding requires MongoDB >= 2.0.0") raise SkipTest("Need mongos >= 2.2.0")
if not server_started_with_auth(client): if not server_started_with_auth(client):
raise SkipTest("Authentication is not enabled on server") raise SkipTest("Authentication is not enabled on server")
@ -947,18 +955,18 @@ class TestDatabase(unittest.TestCase):
self.fail("_check_command_response didn't raise OperationFailure") self.fail("_check_command_response didn't raise OperationFailure")
def test_command_read_pref_warning(self): def test_command_read_pref_warning(self):
warnings.simplefilter("error", UserWarning) ctx = catch_warnings()
try: try:
warnings.simplefilter("error", UserWarning)
self.assertRaises(UserWarning, self.client.pymongo_test.command, self.assertRaises(UserWarning, self.client.pymongo_test.command,
'ping', read_preference=ReadPreference.SECONDARY) 'ping', read_preference=ReadPreference.SECONDARY)
try: try:
self.client.pymongo_test.command( self.client.pymongo_test.command('dbStats',
'dbStats', read_preference=ReadPreference.SECONDARY) read_preference=ReadPreference.SECONDARY_PREFERRED)
except UserWarning: except UserWarning:
self.fail("Shouldn't have raised UserWarning.") self.fail("Shouldn't have raised UserWarning.")
finally: finally:
warnings.resetwarnings() ctx.exit()
warnings.simplefilter("ignore")
def test_command_max_time_ms(self): def test_command_max_time_ms(self):
if not version.at_least(self.client, (2, 5, 3, -1)): if not version.at_least(self.client, (2, 5, 3, -1)):
@ -989,6 +997,29 @@ class TestDatabase(unittest.TestCase):
"maxTimeAlwaysTimeOut", "maxTimeAlwaysTimeOut",
mode="off") mode="off")
def test_object_to_dict_transformer(self):
# PYTHON-709: Some users rely on their custom SONManipulators to run
# before any other checks, so they can insert non-dict objects and
# have them dictified before the _id is inserted or any other
# processing.
class Thing(object):
def __init__(self, value):
self.value = value
class ThingTransformer(SONManipulator):
def transform_incoming(self, thing, collection):
return {'value': thing.value}
db = self.client.foo
db.add_son_manipulator(ThingTransformer())
t = Thing('value')
db.test.remove()
db.test.insert([t])
out = db.test.find_one()
self.assertEqual('value', out.get('value'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -16,25 +16,25 @@
"""Tests for the gridfs package. """Tests for the gridfs package.
""" """
import sys
sys.path[0:0] = [""]
from pymongo.mongo_client import MongoClient
from pymongo.errors import ConnectionFailure
from pymongo.read_preferences import ReadPreference
from test.test_replica_set_client import TestReplicaSetClientBase
import datetime import datetime
import unittest import sys
import threading import threading
import time import time
import unittest
import warnings
sys.path[0:0] = [""]
import gridfs import gridfs
from bson.py3compat import b, StringIO from bson.py3compat import b, StringIO
from gridfs.errors import (FileExists, from gridfs.errors import (FileExists, NoFile)
NoFile) from pymongo.errors import ConnectionFailure
from pymongo.mongo_client import MongoClient
from pymongo.read_preferences import ReadPreference
from test.test_client import get_client from test.test_client import get_client
from test.utils import joinall from test.test_replica_set_client import TestReplicaSetClientBase
from test.utils import catch_warnings, joinall
class JustWrite(threading.Thread): class JustWrite(threading.Thread):
@ -416,20 +416,25 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase):
primary_connection = MongoClient(primary_host, primary_port) primary_connection = MongoClient(primary_host, primary_port)
secondary_host, secondary_port = self.secondaries[0] secondary_host, secondary_port = self.secondaries[0]
for secondary_connection in [ ctx = catch_warnings()
MongoClient(secondary_host, secondary_port, slave_okay=True), try:
MongoClient(secondary_host, secondary_port, warnings.simplefilter("ignore", DeprecationWarning)
read_preference=ReadPreference.SECONDARY), for secondary_connection in [
]: MongoClient(secondary_host, secondary_port, slave_okay=True),
primary_connection.pymongo_test.drop_collection("fs.files") MongoClient(secondary_host, secondary_port,
primary_connection.pymongo_test.drop_collection("fs.chunks") read_preference=ReadPreference.SECONDARY),
]:
primary_connection.pymongo_test.drop_collection("fs.files")
primary_connection.pymongo_test.drop_collection("fs.chunks")
# Should detect it's connected to secondary and not attempt to # Should detect it's connected to secondary and not attempt to
# create index # create index
fs = gridfs.GridFS(secondary_connection.pymongo_test) fs = gridfs.GridFS(secondary_connection.pymongo_test)
# This won't detect secondary, raises error # This won't detect secondary, raises error
self.assertRaises(ConnectionFailure, fs.put, b('foo')) self.assertRaises(ConnectionFailure, fs.put, b('foo'))
finally:
ctx.exit()
def test_gridfs_secondary_lazy(self): def test_gridfs_secondary_lazy(self):
# Should detect it's connected to secondary and not attempt to # Should detect it's connected to secondary and not attempt to

View File

@ -17,6 +17,7 @@
import sys import sys
import unittest import unittest
import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -28,27 +29,33 @@ from pymongo.replica_set_connection import ReplicaSetConnection
from pymongo.errors import ConfigurationError from pymongo.errors import ConfigurationError
from test import host, port, pair from test import host, port, pair
from test.test_replica_set_client import TestReplicaSetClientBase from test.test_replica_set_client import TestReplicaSetClientBase
from test.utils import get_pool from test.utils import catch_warnings, get_pool
class TestConnection(unittest.TestCase): class TestConnection(unittest.TestCase):
def test_connection(self): def test_connection(self):
c = Connection(host, port) c = Connection(host, port)
self.assertTrue(c.auto_start_request)
self.assertEqual(None, c.max_pool_size)
self.assertFalse(c.slave_okay)
self.assertFalse(c.safe)
self.assertEqual({}, c.get_lasterror_options())
# Connection's writes are unacknowledged by default ctx = catch_warnings()
doc = {"_id": ObjectId()} try:
coll = c.pymongo_test.write_concern_test warnings.simplefilter("ignore", DeprecationWarning)
coll.drop() self.assertTrue(c.auto_start_request)
coll.insert(doc) self.assertEqual(None, c.max_pool_size)
coll.insert(doc) self.assertFalse(c.slave_okay)
self.assertFalse(c.safe)
self.assertEqual({}, c.get_lasterror_options())
c = Connection("mongodb://%s:%s/?safe=true" % (host, port)) # Connection's writes are unacknowledged by default
self.assertTrue(c.safe) doc = {"_id": ObjectId()}
coll = c.pymongo_test.write_concern_test
coll.drop()
coll.insert(doc)
coll.insert(doc)
c = Connection("mongodb://%s:%s/?safe=true" % (host, port))
self.assertTrue(c.safe)
finally:
ctx.exit()
# To preserve legacy Connection's behavior, max_size should be None. # To preserve legacy Connection's behavior, max_size should be None.
# Pool should handle this without error. # Pool should handle this without error.
@ -73,23 +80,29 @@ class TestConnection(unittest.TestCase):
class TestReplicaSetConnection(TestReplicaSetClientBase): class TestReplicaSetConnection(TestReplicaSetClientBase):
def test_replica_set_connection(self): def test_replica_set_connection(self):
c = ReplicaSetConnection(pair, replicaSet=self.name) c = ReplicaSetConnection(pair, replicaSet=self.name)
self.assertTrue(c.auto_start_request)
self.assertEqual(None, c.max_pool_size)
self.assertFalse(c.slave_okay)
self.assertFalse(c.safe)
self.assertEqual({}, c.get_lasterror_options())
# ReplicaSetConnection's writes are unacknowledged by default ctx = catch_warnings()
doc = {"_id": ObjectId()} try:
coll = c.pymongo_test.write_concern_test warnings.simplefilter("ignore", DeprecationWarning)
coll.drop() self.assertTrue(c.auto_start_request)
coll.insert(doc) self.assertEqual(None, c.max_pool_size)
coll.insert(doc) self.assertFalse(c.slave_okay)
self.assertFalse(c.safe)
self.assertEqual({}, c.get_lasterror_options())
c = ReplicaSetConnection("mongodb://%s:%s/?replicaSet=%s&safe=true" % ( # ReplicaSetConnection's writes are unacknowledged by default
host, port, self.name)) doc = {"_id": ObjectId()}
coll = c.pymongo_test.write_concern_test
coll.drop()
coll.insert(doc)
coll.insert(doc)
self.assertTrue(c.safe) c = ReplicaSetConnection("mongodb://%s:%s/?replicaSet=%s&safe=true" % (
host, port, self.name))
self.assertTrue(c.safe)
finally:
ctx.exit()
# To preserve legacy ReplicaSetConnection's behavior, max_size should # To preserve legacy ReplicaSetConnection's behavior, max_size should
# be None. Pool should handle this without error. # be None. Pool should handle this without error.

View File

@ -20,6 +20,8 @@ import sys
import threading import threading
import time import time
import unittest import unittest
import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -35,7 +37,7 @@ from pymongo.mongo_client import MongoClient
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.master_slave_connection import MasterSlaveConnection from pymongo.master_slave_connection import MasterSlaveConnection
from test import host, port, host2, port2, host3, port3 from test import host, port, host2, port2, host3, port3
from test.utils import TestRequestMixin, get_pool from test.utils import TestRequestMixin, catch_warnings, get_pool
class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin): class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin):
@ -363,7 +365,7 @@ class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin):
def test_kill_cursor_explicit(self): def test_kill_cursor_explicit(self):
c = self.client c = self.client
c.slave_okay = True c.read_preference = ReadPreference.SECONDARY_PREFERRED
db = c.pymongo_test db = c.pymongo_test
test = db.master_slave_test_kill_cursor_explicit test = db.master_slave_test_kill_cursor_explicit
@ -412,60 +414,65 @@ class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin):
self.assertRaises(OperationFailure, lambda: list(cursor2)) self.assertRaises(OperationFailure, lambda: list(cursor2))
def test_base_object(self): def test_base_object(self):
c = self.client ctx = catch_warnings()
self.assertFalse(c.slave_okay) try:
self.assertTrue(bool(c.read_preference)) warnings.simplefilter("ignore", DeprecationWarning)
self.assertTrue(c.safe) c = self.client
self.assertEqual({}, c.get_lasterror_options()) self.assertFalse(c.slave_okay)
db = c.pymongo_test self.assertTrue(bool(c.read_preference))
self.assertFalse(db.slave_okay) self.assertTrue(c.safe)
self.assertTrue(bool(c.read_preference)) self.assertEqual({}, c.get_lasterror_options())
self.assertTrue(db.safe) db = c.pymongo_test
self.assertEqual({}, db.get_lasterror_options()) self.assertFalse(db.slave_okay)
coll = db.test self.assertTrue(bool(c.read_preference))
coll.drop() self.assertTrue(db.safe)
self.assertFalse(coll.slave_okay) self.assertEqual({}, db.get_lasterror_options())
self.assertTrue(bool(c.read_preference)) coll = db.test
self.assertTrue(coll.safe) coll.drop()
self.assertEqual({}, coll.get_lasterror_options()) self.assertFalse(coll.slave_okay)
cursor = coll.find() self.assertTrue(bool(c.read_preference))
self.assertFalse(cursor._Cursor__slave_okay) self.assertTrue(coll.safe)
self.assertTrue(bool(cursor._Cursor__read_preference)) self.assertEqual({}, coll.get_lasterror_options())
cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay)
self.assertTrue(bool(cursor._Cursor__read_preference))
w = 1 + len(self.slaves) w = 1 + len(self.slaves)
wtimeout=10000 # Wait 10 seconds for replication to complete wtimeout=10000 # Wait 10 seconds for replication to complete
c.set_lasterror_options(w=w, wtimeout=wtimeout) c.set_lasterror_options(w=w, wtimeout=wtimeout)
self.assertFalse(c.slave_okay) self.assertFalse(c.slave_okay)
self.assertTrue(bool(c.read_preference)) self.assertTrue(bool(c.read_preference))
self.assertTrue(c.safe) self.assertTrue(c.safe)
self.assertEqual({'w': w, 'wtimeout': wtimeout}, c.get_lasterror_options()) self.assertEqual({'w': w, 'wtimeout': wtimeout}, c.get_lasterror_options())
db = c.pymongo_test db = c.pymongo_test
self.assertFalse(db.slave_okay) self.assertFalse(db.slave_okay)
self.assertTrue(bool(c.read_preference)) self.assertTrue(bool(c.read_preference))
self.assertTrue(db.safe) self.assertTrue(db.safe)
self.assertEqual({'w': w, 'wtimeout': wtimeout}, db.get_lasterror_options()) self.assertEqual({'w': w, 'wtimeout': wtimeout}, db.get_lasterror_options())
coll = db.test coll = db.test
self.assertFalse(coll.slave_okay) self.assertFalse(coll.slave_okay)
self.assertTrue(bool(c.read_preference)) self.assertTrue(bool(c.read_preference))
self.assertTrue(coll.safe) self.assertTrue(coll.safe)
self.assertEqual({'w': w, 'wtimeout': wtimeout}, self.assertEqual({'w': w, 'wtimeout': wtimeout},
coll.get_lasterror_options()) coll.get_lasterror_options())
cursor = coll.find() cursor = coll.find()
self.assertFalse(cursor._Cursor__slave_okay) self.assertFalse(cursor._Cursor__slave_okay)
self.assertTrue(bool(cursor._Cursor__read_preference)) self.assertTrue(bool(cursor._Cursor__read_preference))
coll.insert({'foo': 'bar'}) coll.insert({'foo': 'bar'})
self.assertEqual(1, coll.find({'foo': 'bar'}).count()) self.assertEqual(1, coll.find({'foo': 'bar'}).count())
self.assertTrue(coll.find({'foo': 'bar'})) self.assertTrue(coll.find({'foo': 'bar'}))
coll.remove({'foo': 'bar'}) coll.remove({'foo': 'bar'})
self.assertEqual(0, coll.find({'foo': 'bar'}).count()) self.assertEqual(0, coll.find({'foo': 'bar'}).count())
c.safe = False c.safe = False
c.unset_lasterror_options() c.unset_lasterror_options()
self.assertFalse(self.client.slave_okay) self.assertFalse(self.client.slave_okay)
self.assertTrue(bool(self.client.read_preference)) self.assertTrue(bool(self.client.read_preference))
self.assertFalse(self.client.safe) self.assertFalse(self.client.safe)
self.assertEqual({}, self.client.get_lasterror_options()) self.assertEqual({}, self.client.get_lasterror_options())
finally:
ctx.exit()
def test_document_class(self): def test_document_class(self):
c = MasterSlaveConnection(self.master, self.slaves) c = MasterSlaveConnection(self.master, self.slaves)

View File

@ -181,6 +181,7 @@ class TestObjectId(unittest.TestCase):
self.assertEqual(oid_1_9, oid_1_10) self.assertEqual(oid_1_9, oid_1_10)
def test_is_valid(self): def test_is_valid(self):
self.assertFalse(ObjectId.is_valid(None))
self.assertFalse(ObjectId.is_valid(4)) self.assertFalse(ObjectId.is_valid(4))
self.assertFalse(ObjectId.is_valid(175.0)) self.assertFalse(ObjectId.is_valid(175.0))
self.assertFalse(ObjectId.is_valid({"test": 4})) self.assertFalse(ObjectId.is_valid({"test": 4}))

View File

@ -14,9 +14,9 @@
"""Test the replica_set_connection module.""" """Test the replica_set_connection module."""
import random import random
import sys import sys
import unittest import unittest
import warnings
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -32,6 +32,7 @@ from pymongo.errors import ConfigurationError
from test.test_replica_set_client import TestReplicaSetClientBase from test.test_replica_set_client import TestReplicaSetClientBase
from test.test_client import get_client from test.test_client import get_client
from test import version, utils, host, port from test import version, utils, host, port
from test.utils import catch_warnings
class TestReadPreferencesBase(TestReplicaSetClientBase): class TestReadPreferencesBase(TestReplicaSetClientBase):
@ -278,8 +279,13 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Test generic 'command' method. Some commands obey read preference, # Test generic 'command' method. Some commands obey read preference,
# most don't. # most don't.
# Disobedient commands, always go to primary # Disobedient commands, always go to primary
self._test_fn(False, lambda: self.c.pymongo_test.command('ping')) ctx = catch_warnings()
self._test_fn(False, lambda: self.c.admin.command('buildinfo')) try:
warnings.simplefilter("ignore", UserWarning)
self._test_fn(False, lambda: self.c.pymongo_test.command('ping'))
self._test_fn(False, lambda: self.c.admin.command('buildinfo'))
finally:
ctx.exit()
# Obedient commands. # Obedient commands.
self._test_fn(True, lambda: self.c.pymongo_test.command('group', { self._test_fn(True, lambda: self.c.pymongo_test.command('group', {
@ -303,11 +309,11 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Distinct # Distinct
self._test_fn(True, lambda: self.c.pymongo_test.command( self._test_fn(True, lambda: self.c.pymongo_test.command(
'distinct', 'test', key={'a': 1})) 'distinct', 'test', key='a'))
self._test_fn(True, lambda: self.c.pymongo_test.command( self._test_fn(True, lambda: self.c.pymongo_test.command(
'distinct', 'test', key={'a': 1}, query={'a': 1})) 'distinct', 'test', key='a', query={'a': 1}))
self._test_fn(True, lambda: self.c.pymongo_test.command(SON([ self._test_fn(True, lambda: self.c.pymongo_test.command(SON([
('distinct', 'test'), ('key', {'a': 1}), ('query', {'a': 1})]))) ('distinct', 'test'), ('key', 'a'), ('query', {'a': 1})])))
# Geo stuff. Make sure a 2d index is created and replicated # Geo stuff. Make sure a 2d index is created and replicated
self.c.pymongo_test.system.indexes.insert({ self.c.pymongo_test.system.indexes.insert({
@ -342,7 +348,12 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Text search. # Text search.
if version.at_least(self.c, (2, 3, 2)): if version.at_least(self.c, (2, 3, 2)):
utils.enable_text_search(self.c) ctx = catch_warnings()
try:
warnings.simplefilter("ignore", UserWarning)
utils.enable_text_search(self.c)
finally:
ctx.exit()
db = self.c.pymongo_test db = self.c.pymongo_test
# Only way to create an index and wait for all members to build it. # Only way to create an index and wait for all members to build it.
@ -366,20 +377,25 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Non-inline mapreduce always goes to primary, doesn't obey read prefs. # Non-inline mapreduce always goes to primary, doesn't obey read prefs.
# Test with command in a SON and with kwargs # Test with command in a SON and with kwargs
self._test_fn(False, lambda: self.c.pymongo_test.command(SON([ ctx = catch_warnings()
('mapreduce', 'test'), try:
('map', 'function() { }'), warnings.simplefilter("ignore", UserWarning)
('reduce', 'function() { }'), self._test_fn(False, lambda: self.c.pymongo_test.command(SON([
('out', 'mr_out') ('mapreduce', 'test'),
]))) ('map', 'function() { }'),
('reduce', 'function() { }'),
('out', 'mr_out')
])))
self._test_fn(False, lambda: self.c.pymongo_test.command( self._test_fn(False, lambda: self.c.pymongo_test.command(
'mapreduce', 'test', map='function() { }', 'mapreduce', 'test', map='function() { }',
reduce='function() { }', out='mr_out')) reduce='function() { }', out='mr_out'))
self._test_fn(False, lambda: self.c.pymongo_test.command( self._test_fn(False, lambda: self.c.pymongo_test.command(
'mapreduce', 'test', map='function() { }', 'mapreduce', 'test', map='function() { }',
reduce='function() { }', out={'replace': 'some_collection'})) reduce='function() { }', out={'replace': 'some_collection'}))
finally:
ctx.exit()
# Inline mapreduce obeys read prefs # Inline mapreduce obeys read prefs
self._test_fn(True, lambda: self.c.pymongo_test.command( self._test_fn(True, lambda: self.c.pymongo_test.command(
@ -405,32 +421,47 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# Aggregate with $out always goes to primary, doesn't obey read prefs. # Aggregate with $out always goes to primary, doesn't obey read prefs.
# Test aggregate command sent directly to db.command. # Test aggregate command sent directly to db.command.
self._test_fn(False, lambda: self.c.pymongo_test.command( ctx = catch_warnings()
"aggregate", "test", try:
pipeline=[{"$match": {"x": 1}}, {"$out": "agg_out"}] warnings.simplefilter("ignore", UserWarning)
)) self._test_fn(False, lambda: self.c.pymongo_test.command(
"aggregate", "test",
pipeline=[{"$match": {"x": 1}}, {"$out": "agg_out"}]
))
# Test aggregate when sent through the collection aggregate function. # Test aggregate when sent through the collection aggregate function.
self._test_fn(False, lambda: self.c.pymongo_test.test.aggregate( self._test_fn(False, lambda: self.c.pymongo_test.test.aggregate(
[{"$match": {"x": 2}}, {"$out": "agg_out"}] [{"$match": {"x": 2}}, {"$out": "agg_out"}]
)) ))
finally:
ctx.exit()
self.c.pymongo_test.drop_collection("test") self.c.pymongo_test.drop_collection("test")
self.c.pymongo_test.drop_collection("agg_out") self.c.pymongo_test.drop_collection("agg_out")
def test_create_collection(self): def test_create_collection(self):
# Collections should be created on primary, obviously # Collections should be created on primary, obviously
self._test_fn(False, lambda: self.c.pymongo_test.command( ctx = catch_warnings()
'create', 'some_collection%s' % random.randint(0, sys.maxint))) try:
warnings.simplefilter("ignore", UserWarning)
self._test_fn(False, lambda: self.c.pymongo_test.command(
'create', 'some_collection%s' % random.randint(0, sys.maxint)))
self._test_fn(False, lambda: self.c.pymongo_test.create_collection( self._test_fn(False, lambda: self.c.pymongo_test.create_collection(
'some_collection%s' % random.randint(0, sys.maxint))) 'some_collection%s' % random.randint(0, sys.maxint)))
finally:
ctx.exit()
def test_drop_collection(self): def test_drop_collection(self):
self._test_fn(False, lambda: self.c.pymongo_test.drop_collection( ctx = catch_warnings()
'some_collection')) try:
warnings.simplefilter("ignore", UserWarning)
self._test_fn(False, lambda: self.c.pymongo_test.drop_collection(
'some_collection'))
self._test_fn(False, lambda: self.c.pymongo_test.some_collection.drop()) self._test_fn(False, lambda: self.c.pymongo_test.some_collection.drop())
finally:
ctx.exit()
def test_group(self): def test_group(self):
self._test_fn(True, lambda: self.c.pymongo_test.test.group( self._test_fn(True, lambda: self.c.pymongo_test.test.group(
@ -440,8 +471,13 @@ class TestCommandAndReadPreference(TestReplicaSetClientBase):
# mapreduce fails if no collection # mapreduce fails if no collection
self.c.pymongo_test.test.insert({}, w=self.w) self.c.pymongo_test.test.insert({}, w=self.w)
self._test_fn(False, lambda: self.c.pymongo_test.test.map_reduce( ctx = catch_warnings()
'function() { }', 'function() { }', 'mr_out')) try:
warnings.simplefilter("ignore", UserWarning)
self._test_fn(False, lambda: self.c.pymongo_test.test.map_reduce(
'function() { }', 'function() { }', 'mr_out'))
finally:
ctx.exit()
self._test_fn(True, lambda: self.c.pymongo_test.test.map_reduce( self._test_fn(True, lambda: self.c.pymongo_test.test.map_reduce(
'function() { }', 'function() { }', {'inline': 1})) 'function() { }', 'function() { }', {'inline': 1}))
@ -517,83 +553,93 @@ class TestMongosConnection(unittest.TestCase):
NEAREST = ReadPreference.NEAREST NEAREST = ReadPreference.NEAREST
SLAVE_OKAY = _QUERY_OPTIONS['slave_okay'] SLAVE_OKAY = _QUERY_OPTIONS['slave_okay']
# Test non-PRIMARY modes which can be combined with tags ctx = catch_warnings()
for kwarg, value, mongos_mode in ( try:
('read_preference', PRIMARY_PREFERRED, 'primaryPreferred'), warnings.simplefilter("ignore", DeprecationWarning)
('read_preference', SECONDARY, 'secondary'), # Test non-PRIMARY modes which can be combined with tags
('read_preference', SECONDARY_PREFERRED, 'secondaryPreferred'), for kwarg, value, mongos_mode in (
('read_preference', NEAREST, 'nearest'), ('read_preference', PRIMARY_PREFERRED, 'primaryPreferred'),
('slave_okay', True, 'secondaryPreferred'), ('read_preference', SECONDARY, 'secondary'),
('slave_okay', False, 'primary') ('read_preference', SECONDARY_PREFERRED, 'secondaryPreferred'),
): ('read_preference', NEAREST, 'nearest'),
for tag_sets in ( ('slave_okay', True, 'secondaryPreferred'),
None, [{}] ('slave_okay', False, 'primary')
): ):
# Create a client e.g. with read_preference=NEAREST or for tag_sets in (
# slave_okay=True None, [{}]
c = get_client(tag_sets=tag_sets, **{kwarg: value}) ):
# Create a client e.g. with read_preference=NEAREST or
# slave_okay=True
c = get_client(tag_sets=tag_sets, **{kwarg: value})
self.assertEqual(is_mongos, c.is_mongos) self.assertEqual(is_mongos, c.is_mongos)
cursor = c.pymongo_test.test.find() cursor = c.pymongo_test.test.find()
if is_mongos: if is_mongos:
# We don't set $readPreference for SECONDARY_PREFERRED # We don't set $readPreference for SECONDARY_PREFERRED
# unless tags are in use. slaveOkay has the same effect. # unless tags are in use. slaveOkay has the same effect.
if mongos_mode == 'secondaryPreferred': if mongos_mode == 'secondaryPreferred':
self.assertEqual( self.assertEqual(
None, None,
cursor._Cursor__query_spec().get('$readPreference')) cursor._Cursor__query_spec().get('$readPreference'))
self.assertTrue( self.assertTrue(
cursor._Cursor__query_options() & SLAVE_OKAY) cursor._Cursor__query_options() & SLAVE_OKAY)
# Don't send $readPreference for PRIMARY either # Don't send $readPreference for PRIMARY either
elif mongos_mode == 'primary': elif mongos_mode == 'primary':
self.assertEqual( self.assertEqual(
None, None,
cursor._Cursor__query_spec().get('$readPreference')) cursor._Cursor__query_spec().get('$readPreference'))
self.assertFalse( self.assertFalse(
cursor._Cursor__query_options() & SLAVE_OKAY) cursor._Cursor__query_options() & SLAVE_OKAY)
else:
self.assertEqual(
{'mode': mongos_mode},
cursor._Cursor__query_spec().get('$readPreference'))
self.assertTrue(
cursor._Cursor__query_options() & SLAVE_OKAY)
else: else:
self.assertFalse(
'$readPreference' in cursor._Cursor__query_spec())
for tag_sets in (
[{'dc': 'la'}],
[{'dc': 'la'}, {'dc': 'sf'}],
[{'dc': 'la'}, {'dc': 'sf'}, {}],
):
if kwarg == 'slave_okay':
# Can't use tags with slave_okay True or False, need a
# real read preference
self.assertRaises(
ConfigurationError,
get_client, tag_sets=tag_sets, **{kwarg: value})
continue
c = get_client(tag_sets=tag_sets, **{kwarg: value})
self.assertEqual(is_mongos, c.is_mongos)
cursor = c.pymongo_test.test.find()
if is_mongos:
self.assertEqual( self.assertEqual(
{'mode': mongos_mode}, {'mode': mongos_mode, 'tags': tag_sets},
cursor._Cursor__query_spec().get('$readPreference')) cursor._Cursor__query_spec().get('$readPreference'))
else:
self.assertTrue( self.assertFalse(
cursor._Cursor__query_options() & SLAVE_OKAY) '$readPreference' in cursor._Cursor__query_spec())
else: finally:
self.assertFalse( ctx.exit()
'$readPreference' in cursor._Cursor__query_spec())
for tag_sets in (
[{'dc': 'la'}],
[{'dc': 'la'}, {'dc': 'sf'}],
[{'dc': 'la'}, {'dc': 'sf'}, {}],
):
if kwarg == 'slave_okay':
# Can't use tags with slave_okay True or False, need a
# real read preference
self.assertRaises(
ConfigurationError,
get_client, tag_sets=tag_sets, **{kwarg: value})
continue
c = get_client(tag_sets=tag_sets, **{kwarg: value})
self.assertEqual(is_mongos, c.is_mongos)
cursor = c.pymongo_test.test.find()
if is_mongos:
self.assertEqual(
{'mode': mongos_mode, 'tags': tag_sets},
cursor._Cursor__query_spec().get('$readPreference'))
else:
self.assertFalse(
'$readPreference' in cursor._Cursor__query_spec())
def test_only_secondary_ok_commands_have_read_prefs(self): def test_only_secondary_ok_commands_have_read_prefs(self):
c = get_client(read_preference=ReadPreference.SECONDARY) c = get_client(read_preference=ReadPreference.SECONDARY)
is_mongos = utils.is_mongos(c) ctx = catch_warnings()
try:
warnings.simplefilter("ignore", UserWarning)
is_mongos = utils.is_mongos(c)
finally:
ctx.exit()
if not is_mongos: if not is_mongos:
raise SkipTest("Only mongos have read_prefs added to the spec") raise SkipTest("Only mongos have read_prefs added to the spec")

View File

@ -45,13 +45,14 @@ from pymongo.errors import (AutoReconnect,
ConnectionFailure, ConnectionFailure,
InvalidName, InvalidName,
OperationFailure, InvalidOperation) OperationFailure, InvalidOperation)
from pymongo import auth
from test import version, port, pair from test import version, port, pair
from test.pymongo_mocks import MockReplicaSetClient from test.pymongo_mocks import MockReplicaSetClient
from test.utils import ( from test.utils import (
delay, assertReadFrom, assertReadFromAll, read_from_which_host, delay, assertReadFrom, assertReadFromAll, read_from_which_host,
remove_all_users, assertRaisesExactly, TestRequestMixin, one, remove_all_users, assertRaisesExactly, TestRequestMixin, one,
server_started_with_auth, pools_from_rs_client, get_pool, server_started_with_auth, pools_from_rs_client, get_pool,
_TestLazyConnectMixin) _TestLazyConnectMixin, _TestExhaustCursorMixin)
class TestReplicaSetClientAgainstStandalone(unittest.TestCase): class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
@ -260,7 +261,6 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
read_preference=ReadPreference.SECONDARY, read_preference=ReadPreference.SECONDARY,
tag_sets=copy.deepcopy(tag_sets), tag_sets=copy.deepcopy(tag_sets),
secondary_acceptable_latency_ms=77) secondary_acceptable_latency_ms=77)
c.admin.command('ping')
self.assertEqual(c.primary, self.primary) self.assertEqual(c.primary, self.primary)
self.assertEqual(c.hosts, self.hosts) self.assertEqual(c.hosts, self.hosts)
self.assertEqual(c.arbiters, self.arbiters) self.assertEqual(c.arbiters, self.arbiters)
@ -1127,6 +1127,42 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertFalse(client.alive()) self.assertFalse(client.alive())
def test_auth_network_error(self):
# Make sure there's no semaphore leak if we get a network error
# when authenticating a new socket with cached credentials.
auth_client = self._get_client()
if not server_started_with_auth(auth_client):
raise SkipTest('Authentication is not enabled on server')
auth_client.admin.add_user('admin', 'password')
auth_client.admin.authenticate('admin', 'password')
try:
# Get a client with one socket so we detect if it's leaked.
c = self._get_client(max_pool_size=1, waitQueueTimeoutMS=1)
# Simulate an authenticate() call on a different socket.
credentials = auth._build_credentials_tuple(
'MONGODB-CR', 'admin',
unicode('admin'), unicode('password'),
{})
c._cache_credentials('test', credentials, connect=False)
# Cause a network error on the actual socket.
pool = get_pool(c)
socket_info = one(pool.sockets)
socket_info.sock.close()
# In __check_auth, the client authenticates its socket with the
# new credential, but gets a socket.error. Should be reraised as
# AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one)
# No semaphore leak, the pool is allowed to make a new socket.
c.test.collection.find_one()
finally:
remove_all_users(auth_client.admin)
class TestReplicaSetWireVersion(unittest.TestCase): class TestReplicaSetWireVersion(unittest.TestCase):
def test_wire_version(self): def test_wire_version(self):
@ -1238,5 +1274,12 @@ class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase):
self.assertEqual(c.max_write_batch_size, 2) self.assertEqual(c.max_write_batch_size, 2)
class TestReplicaSetClientExhaustCursor(
_TestExhaustCursorMixin,
TestReplicaSetClientBase):
# Base class implements _get_client already.
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -67,6 +67,15 @@ class TestSON(unittest.TestCase):
('mike', 'awesome'), ('mike', 'awesome'),
('hello', 'world')))) ('hello', 'world'))))
# Embedded SON.
d4 = SON([('blah', {'foo': SON()})])
self.assertEqual(d4, {'blah': {'foo': {}}})
self.assertEqual(d4, {'blah': {'foo': SON()}})
self.assertNotEqual(d4, {'blah': {'foo': []}})
# Original data unaffected.
self.assertEqual(SON, d4['blah']['foo'].__class__)
def test_to_dict(self): def test_to_dict(self):
a1 = SON() a1 = SON()
b2 = SON([("blah", SON())]) b2 = SON([("blah", SON())])
@ -81,6 +90,9 @@ class TestSON(unittest.TestCase):
self.assertEqual(dict, c3.to_dict()["blah"][0].__class__) self.assertEqual(dict, c3.to_dict()["blah"][0].__class__)
self.assertEqual(dict, d4.to_dict()["blah"]["foo"].__class__) self.assertEqual(dict, d4.to_dict()["blah"]["foo"].__class__)
# Original data unaffected.
self.assertEqual(SON, d4['blah']['foo'].__class__)
def test_pickle(self): def test_pickle(self):
simple_son = SON([]) simple_son = SON([])

View File

@ -15,14 +15,18 @@
"""Utilities for testing pymongo """Utilities for testing pymongo
""" """
import gc
import os import os
import struct import struct
import sys import sys
import threading import threading
import time
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from bson.son import SON
from pymongo import MongoClient, MongoReplicaSetClient from pymongo import MongoClient, MongoReplicaSetClient
from pymongo.errors import AutoReconnect from pymongo.errors import AutoReconnect, ConnectionFailure, OperationFailure
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
from test import host, port, version from test import host, port, version
@ -453,7 +457,7 @@ def lazy_client_trial(reset, target, test, get_client, use_greenlets):
# Make concurrency bugs more likely to manifest. # Make concurrency bugs more likely to manifest.
interval = None interval = None
if not sys.platform.startswith('java'): if not sys.platform.startswith('java'):
if sys.version_info >= (3, 2): if hasattr(sys, 'getswitchinterval'):
interval = sys.getswitchinterval() interval = sys.getswitchinterval()
sys.setswitchinterval(1e-6) sys.setswitchinterval(1e-6)
else: else:
@ -472,7 +476,7 @@ def lazy_client_trial(reset, target, test, get_client, use_greenlets):
finally: finally:
if not sys.platform.startswith('java'): if not sys.platform.startswith('java'):
if sys.version_info >= (3, 2): if hasattr(sys, 'setswitchinterval'):
sys.setswitchinterval(interval) sys.setswitchinterval(interval)
else: else:
sys.setcheckinterval(interval) sys.setcheckinterval(interval)
@ -584,3 +588,198 @@ class _TestLazyConnectMixin(object):
self.assertEqual( self.assertEqual(
ismaster['maxMessageSizeBytes'], ismaster['maxMessageSizeBytes'],
c.max_message_size) c.max_message_size)
class _TestExhaustCursorMixin(object):
"""Test that clients properly handle errors from exhaust cursors.
Inherit from this class and from unittest.TestCase, and override
_get_client(self, **kwargs).
"""
def test_exhaust_query_server_error(self):
# When doing an exhaust query, the socket stays checked out on success
# but must be checked in on error to avoid semaphore leaks.
client = self._get_client(max_pool_size=1)
if is_mongos(client):
raise SkipTest("Can't use exhaust cursors with mongos")
if not version.at_least(client, (2, 2, 0)):
raise SkipTest("mongod < 2.2.0 closes exhaust socket on error")
collection = client.pymongo_test.test
pool = get_pool(client)
sock_info = one(pool.sockets)
# This will cause OperationFailure in all mongo versions since
# the value for $orderby must be a document.
cursor = collection.find(
SON([('$query', {}), ('$orderby', True)]), exhaust=True)
self.assertRaises(OperationFailure, cursor.next)
self.assertFalse(sock_info.closed)
# The semaphore was decremented despite the error.
self.assertTrue(pool._socket_semaphore.acquire(blocking=False))
def test_exhaust_getmore_server_error(self):
# When doing a getmore on an exhaust cursor, the socket stays checked
# out on success but must be checked in on error to avoid semaphore
# leaks.
client = self._get_client(max_pool_size=1)
if is_mongos(client):
raise SkipTest("Can't use exhaust cursors with mongos")
# A separate client that doesn't affect the test client's pool.
client2 = self._get_client()
collection = client.pymongo_test.test
collection.remove()
# Enough data to ensure it streams down for a few milliseconds.
long_str = 'a' * (256 * 1024)
collection.insert([{'a': long_str} for _ in range(200)])
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
sock_info = one(pool.sockets)
cursor = collection.find(exhaust=True)
# Initial query succeeds.
cursor.next()
# Cause a server error on getmore.
client2.pymongo_test.test.drop()
self.assertRaises(OperationFailure, list, cursor)
# Make sure the socket is still valid
self.assertEqual(0, collection.count())
def test_exhaust_query_network_error(self):
# When doing an exhaust query, the socket stays checked out on success
# but must be checked in on error to avoid semaphore leaks.
client = self._get_client(max_pool_size=1)
if is_mongos(client):
raise SkipTest("Can't use exhaust cursors with mongos")
collection = client.pymongo_test.test
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
# Cause a network error.
sock_info = one(pool.sockets)
sock_info.sock.close()
cursor = collection.find(exhaust=True)
self.assertRaises(ConnectionFailure, cursor.next)
self.assertTrue(sock_info.closed)
# The semaphore was decremented despite the error.
self.assertTrue(pool._socket_semaphore.acquire(blocking=False))
def test_exhaust_getmore_network_error(self):
# When doing a getmore on an exhaust cursor, the socket stays checked
# out on success but must be checked in on error to avoid semaphore
# leaks.
client = self._get_client(max_pool_size=1)
if is_mongos(client):
raise SkipTest("Can't use exhaust cursors with mongos")
collection = client.pymongo_test.test
collection.remove()
collection.insert([{} for _ in range(200)]) # More than one batch.
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
cursor = collection.find(exhaust=True)
# Initial query succeeds.
cursor.next()
# Cause a network error.
sock_info = cursor._Cursor__exhaust_mgr.sock
sock_info.sock.close()
# A getmore fails.
self.assertRaises(ConnectionFailure, list, cursor)
self.assertTrue(sock_info.closed)
# The semaphore was decremented despite the error.
self.assertTrue(pool._socket_semaphore.acquire(blocking=False))
# Backport of WarningMessage from python 2.6, with fixed syntax for python 2.4.
class WarningMessage(object):
"""Holds the result of a single showwarning() call."""
_WARNING_DETAILS = ("message", "category", "filename", "lineno", "file",
"line")
def __init__(self, message, category,
filename, lineno, file=None, line=None):
local_values = locals()
for attr in self._WARNING_DETAILS:
setattr(self, attr, local_values[attr])
self._category_name = None
if category:
self._category_name = category.__name__
def __str__(self):
return ("{message : %r, category : %r, filename : %r, lineno : %s, "
"line : %r}" % (self.message, self._category_name,
self.filename, self.lineno, self.line))
# Rough backport of warnings.catch_warnings from python 2.6,
# with changes to support python 2.4.
class CatchWarnings(object):
"""A non-context manager version of warnings.catch_warnings.
The 'record' argument specifies whether warnings should be captured by a
custom implementation of warnings.showwarning() and be appended to a list
accessed through the `log` property. The objects appended to the list are
arguments whose attributes mirror the arguments to showwarning().
The 'module' argument is to specify an alternative module to the module
named 'warnings' and imported under that name. This argument is only useful
when testing the warnings module itself.
"""
def __init__(self, record=False, module=None):
self._record = record
if module is None:
self._module = sys.modules['warnings']
else:
self._module = module
# No __enter__ so do that work here
self._filters = self._module.filters
self._module.filters = self._filters[:]
self._showwarning = self._module.showwarning
self._log = []
if self._record:
def showwarning(*args, **kwargs):
self._log.append(WarningMessage(*args, **kwargs))
self._module.showwarning = showwarning
@property
def log(self):
"""A list of any warnings recorded when using record=True."""
return self._log
def __repr__(self):
args = []
if self._record:
args.append("record=True")
if self._module is not sys.modules['warnings']:
args.append("module=%r" % self._module)
name = type(self).__name__
return "%s(%s)" % (name, ", ".join(args))
def exit(self):
"""Revert changes to the warnings module."""
self._module.filters = self._filters
self._module.showwarning = self._showwarning
def catch_warnings(record=False, module=None):
"""Helper for use with CatchWarnings."""
return CatchWarnings(record, module)