Compare commits

...

24 Commits
master ... v2.6

Author SHA1 Message Date
Bernie Hackett
8d0167a5f2 BUMP 2.6.3 2013-10-11 10:20:38 -07:00
Bernie Hackett
74c94a4f14 Fix pooling tests to use imported host and port. 2013-10-10 16:09:55 -07:00
Bernie Hackett
1e5165afba Release notes for 2.6.3 2013-10-10 15:07:47 -07:00
Bernie Hackett
bad7e32674 Fix signed / unsigned comparison warnings 2013-10-10 12:58:27 -07:00
Bernie Hackett
17fd05f8fa Fix length check in python 3 PYTHON-571 2013-10-10 12:58:27 -07:00
behackett
cc942556b5 Better handling of corrupt/invalid BSON PYTHON-571 2013-10-10 12:58:27 -07:00
A. Jesse Jiryu Davis
db25c61156 Test refactoring: get_pool() and pools_from_rs_client().
Conflicts:
	test/test_replica_set_client.py
2013-10-10 15:32:02 -04:00
A. Jesse Jiryu Davis
bddee65ed8 Semaphore management during connection failure. PYTHON-580 2013-10-10 15:21:34 -04:00
A. Jesse Jiryu Davis
3db63c45a6 Pool must release its semaphore on connection failure. PYTHON-580 2013-10-10 15:21:26 -04:00
A. Jesse Jiryu Davis
a633a4111d Test that Pool releases its semaphore after connection error. PYTHON-580 2013-10-10 15:21:19 -04:00
A. Jesse Jiryu Davis
5f16f0b49d RS client now accepts waitQueueTimeoutMS and waitQueueMultiple. PYTHON-579
Conflicts:
	pymongo/mongo_replica_set_client.py
	pymongo/replica_set_connection.py
2013-10-10 15:20:30 -04:00
A. Jesse Jiryu Davis
ee67111dc1 Test that RS client accepts waitQueueTimeoutMS and waitQueueMultiple. PYTHON-579 2013-10-10 15:20:26 -04:00
A. Jesse Jiryu Davis
d38b0e1c47 Fix AttributeError in __del__ for Pool(use_greenlets=True) when Gevent is not installed. PYTHON-561 2013-10-10 15:14:24 -04:00
A. Jesse Jiryu Davis
7d0240a88a Test Pool(use_greenlets=True) when Gevent is not installed. PYTHON-561 2013-10-10 15:14:17 -04:00
Bernie Hackett
58b5ecec92 Version -> + 2013-10-08 08:42:34 -07:00
behackett
7fede98efd BUMP 2.6.2 2013-09-06 13:54:29 -07:00
behackett
b8a49d0f13 Release notes for 2.6.2 2013-09-06 13:46:01 -07:00
Bernie Hackett
acb30ec029 Fix tests that were failing with unstable MongoDB. 2013-09-06 12:16:17 -04:00
A. Jesse Jiryu Davis
0dc87faf2f Fix TypeError in Python 3 when max_pool_size is None. PYTHON-566 2013-09-05 21:32:53 -04:00
A. Jesse Jiryu Davis
4868ebd278 Test max_pool_size=None. PYTHON-566
Reproduces TypeError in Python 3 when max_pool_size is None.
2013-09-05 21:32:47 -04:00
behackett
42b73cdae5 Version -> + 2013-09-05 17:42:31 -07:00
behackett
43b30817f9 BUMP 2.6.1 2013-09-03 17:09:02 -07:00
behackett
cbfb243582 Release notes for 2.6.1 2013-09-03 16:23:52 -07:00
A. Jesse Jiryu Davis
d8faa7af00 Fix refleak in insert. PYTHON-564 2013-09-03 13:43:18 -04:00
17 changed files with 447 additions and 140 deletions

View File

@ -139,13 +139,19 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype):
def _get_string(data, position, as_class, tz_aware, uuid_subtype): def _get_string(data, position, as_class, tz_aware, uuid_subtype):
length = struct.unpack("<i", data[position:position + 4])[0] - 1 length = struct.unpack("<i", data[position:position + 4])[0]
if (len(data) - position - 4) < length:
raise InvalidBSON("invalid string length")
position += 4 position += 4
return _get_c_string(data, position, length) if data[position + length - 1:position + length] != ZERO:
raise InvalidBSON("invalid end of string")
return _get_c_string(data, position, length - 1)
def _get_object(data, position, as_class, tz_aware, uuid_subtype): def _get_object(data, position, as_class, tz_aware, uuid_subtype):
obj_size = struct.unpack("<i", data[position:position + 4])[0] obj_size = struct.unpack("<i", data[position:position + 4])[0]
if data[position + obj_size - 1:position + obj_size] != ZERO:
raise InvalidBSON("bad eoo")
encoded = data[position + 4:position + obj_size - 1] encoded = data[position + 4:position + obj_size - 1]
object = _elements_to_dict(encoded, as_class, tz_aware, uuid_subtype) object = _elements_to_dict(encoded, as_class, tz_aware, uuid_subtype)
position += obj_size position += obj_size

View File

@ -105,6 +105,8 @@ static struct module_state _state;
#define JAVA_LEGACY 5 #define JAVA_LEGACY 5
#define CSHARP_LEGACY 6 #define CSHARP_LEGACY 6
#define BSON_MAX_SIZE 2147483647 #define BSON_MAX_SIZE 2147483647
/* The smallest possible BSON document, i.e. "{}" */
#define BSON_MIN_SIZE 5
/* Get an error class from the bson.errors module. /* Get an error class from the bson.errors module.
* *
@ -136,8 +138,9 @@ _downcast_and_check(Py_ssize_t size, int extra) {
return (int)size + extra; return (int)size + extra;
} }
static PyObject* elements_to_dict(PyObject* self, const char* string, int max, static PyObject* elements_to_dict(PyObject* self, const char* string,
PyObject* as_class, unsigned char tz_aware, unsigned max, PyObject* as_class,
unsigned char tz_aware,
unsigned char uuid_subtype); unsigned char uuid_subtype);
static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte,
@ -1331,8 +1334,8 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
return result; return result;
} }
static PyObject* get_value(PyObject* self, const char* buffer, int* position, static PyObject* get_value(PyObject* self, const char* buffer, unsigned* position,
int type, int max, PyObject* as_class, int type, unsigned max, PyObject* as_class,
unsigned char tz_aware, unsigned char uuid_subtype) { unsigned char tz_aware, unsigned char uuid_subtype) {
struct module_state *state = GETSTATE(self); struct module_state *state = GETSTATE(self);
@ -1356,24 +1359,40 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
case 2: case 2:
case 14: case 14:
{ {
int value_length = ((int*)(buffer + *position))[0] - 1; unsigned value_length;
if (max < value_length) { if (max < 4) {
goto invalid;
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
goto invalid; goto invalid;
} }
*position += 4; *position += 4;
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict"); /* Strings must end in \0 */
if (buffer[*position + value_length - 1]) {
goto invalid;
}
value = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
if (!value) { if (!value) {
return NULL; return NULL;
} }
*position += value_length + 1; *position += value_length;
break; break;
} }
case 3: case 3:
{ {
PyObject* collection; PyObject* collection;
int size; unsigned size;
if (max < 4) {
goto invalid;
}
memcpy(&size, buffer + *position, 4); memcpy(&size, buffer + *position, 4);
if (size < 0 || max < size) { if (size < BSON_MIN_SIZE || max < size) {
goto invalid;
}
/* Check for bad eoo */
if (buffer[*position + size - 1]) {
goto invalid; goto invalid;
} }
value = elements_to_dict(self, buffer + *position + 4, value = elements_to_dict(self, buffer + *position + 4,
@ -1427,14 +1446,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
} }
case 4: case 4:
{ {
int size, unsigned size, end;
end;
if (max < 4) {
goto invalid;
}
memcpy(&size, buffer + *position, 4); memcpy(&size, buffer + *position, 4);
if (max < size) { if (max < size) {
goto invalid; goto invalid;
} }
end = *position + size - 1; end = *position + size - 1;
/* Check for bad eoo */
if (buffer[end]) {
goto invalid;
}
*position += 4; *position += 4;
value = PyList_New(0); value = PyList_New(0);
@ -1446,14 +1471,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
int bson_type = (int)buffer[(*position)++]; int bson_type = (int)buffer[(*position)++];
size_t key_size = strlen(buffer + *position); size_t key_size = strlen(buffer + *position);
if (key_size > BSON_MAX_SIZE) { if (max < key_size) {
Py_DECREF(value); Py_DECREF(value);
goto invalid; goto invalid;
} }
/* just skip the key, they're in order. */ /* just skip the key, they're in order. */
*position += (int)key_size + 1; *position += (unsigned)key_size + 1;
if (Py_EnterRecursiveCall(" while decoding a list value")) {
Py_DECREF(value);
return NULL;
}
to_append = get_value(self, buffer, position, bson_type, to_append = get_value(self, buffer, position, bson_type,
max - (int)key_size, as_class, tz_aware, uuid_subtype); max - (unsigned)key_size,
as_class, tz_aware, uuid_subtype);
Py_LeaveRecursiveCall();
if (!to_append) { if (!to_append) {
Py_DECREF(value); Py_DECREF(value);
return NULL; return NULL;
@ -1468,8 +1499,11 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
{ {
PyObject* data; PyObject* data;
PyObject* st; PyObject* st;
int length, subtype; unsigned length, subtype;
if (max < 4) {
goto invalid;
}
memcpy(&length, buffer + *position, 4); memcpy(&length, buffer + *position, 4);
if (max < length) { if (max < length) {
goto invalid; goto invalid;
@ -1656,19 +1690,21 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
int flags; int flags;
size_t flags_length, i; size_t flags_length, i;
size_t pattern_length = strlen(buffer + *position); size_t pattern_length = strlen(buffer + *position);
if (pattern_length > BSON_MAX_SIZE || max < (int)pattern_length) { if (pattern_length > BSON_MAX_SIZE || max < pattern_length) {
goto invalid; goto invalid;
} }
pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict"); pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
if (!pattern) { if (!pattern) {
return NULL; return NULL;
} }
*position += (int)pattern_length + 1; *position += (unsigned)pattern_length + 1;
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) { flags_length = strlen(buffer + *position);
if (flags_length > BSON_MAX_SIZE ||
(BSON_MAX_SIZE - pattern_length) < flags_length) {
Py_DECREF(pattern); Py_DECREF(pattern);
goto invalid; goto invalid;
} }
if (max < (int)(pattern_length + flags_length)) { if (max < pattern_length + flags_length) {
Py_DECREF(pattern); Py_DECREF(pattern);
goto invalid; goto invalid;
} }
@ -1688,28 +1724,37 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
flags |= 64; flags |= 64;
} }
} }
*position += (int)flags_length + 1; *position += (unsigned)flags_length + 1;
value = PyObject_CallFunction(state->RECompile, "Oi", pattern, flags); value = PyObject_CallFunction(state->RECompile, "Oi", pattern, flags);
Py_DECREF(pattern); Py_DECREF(pattern);
break; break;
} }
case 12: case 12:
{ {
size_t coll_length; unsigned coll_length;
PyObject* collection; PyObject* collection;
PyObject* id; PyObject* id;
*position += 4; if (max < 4) {
coll_length = strlen(buffer + *position);
if (coll_length > BSON_MAX_SIZE || max < (int)coll_length + 12) {
goto invalid; goto invalid;
} }
memcpy(&coll_length, buffer + *position, 4);
/* Encoded string length + string + 12 byte ObjectId */
if (max < 4 + coll_length + 12) {
goto invalid;
}
*position += 4;
/* Strings must end in \0 */
if (buffer[*position + coll_length - 1]) {
goto invalid;
}
collection = PyUnicode_DecodeUTF8(buffer + *position, collection = PyUnicode_DecodeUTF8(buffer + *position,
coll_length, "strict"); coll_length - 1, "strict");
if (!collection) { if (!collection) {
return NULL; return NULL;
} }
*position += (int)coll_length + 1; *position += coll_length;
id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12); id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
if (!id) { if (!id) {
@ -1725,41 +1770,82 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
case 13: case 13:
{ {
PyObject* code; PyObject* code;
int value_length = ((int*)(buffer + *position))[0] - 1; unsigned value_length;
if (max < value_length) { if (max < 4) {
goto invalid;
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
goto invalid; goto invalid;
} }
*position += 4; *position += 4;
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict"); /* Strings must end in \0 */
if (buffer[*position + value_length - 1]) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
if (!code) { if (!code) {
return NULL; return NULL;
} }
*position += value_length + 1; *position += value_length;
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL); value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
Py_DECREF(code); Py_DECREF(code);
break; break;
} }
case 15: case 15:
{ {
size_t code_length; unsigned c_w_s_size;
int scope_size; unsigned code_size;
unsigned scope_size;
PyObject* code; PyObject* code;
PyObject* scope; PyObject* scope;
*position += 8; if (max < 8) {
code_length = strlen(buffer + *position);
if (code_length > BSON_MAX_SIZE || max < 8 + (int)code_length) {
goto invalid; goto invalid;
} }
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
memcpy(&c_w_s_size, buffer + *position, 4);
*position += 4;
if (max < c_w_s_size) {
goto invalid;
}
memcpy(&code_size, buffer + *position, 4);
/* code_w_scope length + code length + code + scope length */
if (max < 4 + 4 + code_size + 4) {
goto invalid;
}
*position += 4;
/* Strings must end in \0 */
if (buffer[*position + code_size - 1]) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, code_size - 1, "strict");
if (!code) { if (!code) {
return NULL; return NULL;
} }
*position += (int)code_length + 1; *position += code_size;
memcpy(&scope_size, buffer + *position, 4); memcpy(&scope_size, buffer + *position, 4);
scope = elements_to_dict(self, buffer + *position + 4, scope_size - 5, if (scope_size < BSON_MIN_SIZE) {
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype); Py_DECREF(code);
goto invalid;
}
/* code length + code + scope length + scope */
if ((4 + code_size + 4 + scope_size) != c_w_s_size) {
Py_DECREF(code);
goto invalid;
}
/* Check for bad eoo */
if (buffer[*position + scope_size - 1]) {
goto invalid;
}
scope = elements_to_dict(self, buffer + *position + 4,
scope_size - 5, (PyObject*)&PyDict_Type,
tz_aware, uuid_subtype);
if (!scope) { if (!scope) {
Py_DECREF(code); Py_DECREF(code);
return NULL; return NULL;
@ -1845,16 +1931,18 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
error = _error("InvalidBSON"); error = _error("InvalidBSON");
if (error) { if (error) {
PyErr_SetNone(error); PyErr_SetString(error,
"invalid length or type code");
Py_DECREF(error); Py_DECREF(error);
} }
return NULL; return NULL;
} }
static PyObject* elements_to_dict(PyObject* self, const char* string, int max, static PyObject* _elements_to_dict(PyObject* self, const char* string,
PyObject* as_class, unsigned char tz_aware, unsigned max, PyObject* as_class,
unsigned char uuid_subtype) { unsigned char tz_aware,
int position = 0; unsigned char uuid_subtype) {
unsigned position = 0;
PyObject* dict = PyObject_CallObject(as_class, NULL); PyObject* dict = PyObject_CallObject(as_class, NULL);
if (!dict) { if (!dict) {
return NULL; return NULL;
@ -1864,7 +1952,7 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
PyObject* value; PyObject* value;
int type = (int)string[position++]; int type = (int)string[position++];
size_t name_length = strlen(string + position); size_t name_length = strlen(string + position);
if (name_length > BSON_MAX_SIZE || position + (int)name_length >= max) { if (name_length > BSON_MAX_SIZE || position + name_length >= max) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetNone(InvalidBSON); PyErr_SetNone(InvalidBSON);
@ -1878,7 +1966,7 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
Py_DECREF(dict); Py_DECREF(dict);
return NULL; return NULL;
} }
position += (int)name_length + 1; position += (unsigned)name_length + 1;
value = get_value(self, string, &position, type, value = get_value(self, string, &position, type,
max - position, as_class, tz_aware, uuid_subtype); max - position, as_class, tz_aware, uuid_subtype);
if (!value) { if (!value) {
@ -1894,6 +1982,19 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
return dict; return dict;
} }
static PyObject* elements_to_dict(PyObject* self, const char* string,
unsigned max, PyObject* as_class,
unsigned char tz_aware,
unsigned char uuid_subtype) {
PyObject* result;
if (Py_EnterRecursiveCall(" while decoding a BSON document"))
return NULL;
result = _elements_to_dict(self, string, max,
as_class, tz_aware, uuid_subtype);
Py_LeaveRecursiveCall();
return result;
}
static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
int size; int size;
Py_ssize_t total_size; Py_ssize_t total_size;
@ -1924,7 +2025,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
#else #else
total_size = PyString_Size(bson); total_size = PyString_Size(bson);
#endif #endif
if (total_size < 5) { if (total_size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetString(InvalidBSON, PyErr_SetString(InvalidBSON,
@ -1944,7 +2045,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
} }
memcpy(&size, string, 4); memcpy(&size, string, 4);
if (size < 0) { if (size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "invalid message size"); PyErr_SetString(InvalidBSON, "invalid message size");
@ -1953,7 +2054,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
return NULL; return NULL;
} }
if (total_size < size) { if (total_size < size || total_size > BSON_MAX_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "objsize too large"); PyErr_SetString(InvalidBSON, "objsize too large");
@ -1971,7 +2072,8 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
return NULL; return NULL;
} }
dict = elements_to_dict(self, string + 4, size - 5, as_class, tz_aware, uuid_subtype); dict = elements_to_dict(self, string + 4, (unsigned)size - 5,
as_class, tz_aware, uuid_subtype);
if (!dict) { if (!dict) {
return NULL; return NULL;
} }
@ -2029,7 +2131,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
return NULL; return NULL;
while (total_size > 0) { while (total_size > 0) {
if (total_size < 5) { if (total_size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetString(InvalidBSON, PyErr_SetString(InvalidBSON,
@ -2041,7 +2143,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
} }
memcpy(&size, string, 4); memcpy(&size, string, 4);
if (size < 0) { if (size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON"); PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) { if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "invalid message size"); PyErr_SetString(InvalidBSON, "invalid message size");
@ -2071,7 +2173,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
return NULL; return NULL;
} }
dict = elements_to_dict(self, string + 4, size - 5, dict = elements_to_dict(self, string + 4, (unsigned)size - 5,
as_class, tz_aware, uuid_subtype); as_class, tz_aware, uuid_subtype);
if (!dict) { if (!dict) {
Py_DECREF(result); Py_DECREF(result);

View File

@ -1,6 +1,48 @@
Changelog Changelog
========= =========
Changes in Version 2.6.3
------------------------
Version 2.6.3 fixes issues reported since the release of 2.6.2, most
importantly a semaphore leak when a connection to the server fails.
Issues Resolved
...............
See the `PyMongo 2.6.3 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 2.6.3 release notes in JIRA: https://jira.mongodb.org/browse/PYTHON/fixforversion/13098
Changes in Version 2.6.2
------------------------
Version 2.6.2 fixes a :exc:`TypeError` problem when max_pool_size=None
is used in Python 3.
Issues Resolved
...............
See the `PyMongo 2.6.2 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 2.6.2 release notes in JIRA: https://jira.mongodb.org/browse/PYTHON/fixforversion/12910
Changes in Version 2.6.1
------------------------
Version 2.6.1 fixes a reference leak in
the :meth:`~pymongo.collection.Collection.insert` method.
Issues Resolved
...............
See the `PyMongo 2.6.1 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 2.6.1 release notes in JIRA: https://jira.mongodb.org/browse/PYTHON/fixforversion/12905
Changes in Version 2.6 Changes in Version 2.6
---------------------- ----------------------

View File

@ -67,7 +67,7 @@ SLOW_ONLY = 1
ALL = 2 ALL = 2
"""Profile all operations.""" """Profile all operations."""
version_tuple = (2, 6, '+') version_tuple = (2, 6, 3)
def get_version_string(): def get_version_string():
if isinstance(version_tuple[-1], basestring): if isinstance(version_tuple[-1], basestring):

View File

@ -555,6 +555,9 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
PyObject* client; PyObject* client;
PyObject* last_error_args; PyObject* last_error_args;
PyObject* result; PyObject* result;
PyObject* max_bson_size_obj;
PyObject* max_message_size_obj;
PyObject* send_message_result;
unsigned char check_keys; unsigned char check_keys;
unsigned char safe; unsigned char safe;
unsigned char continue_on_error; unsigned char continue_on_error;
@ -578,24 +581,25 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
options += 1; options += 1;
} }
max_bson_size_obj = PyObject_GetAttrString(client, "max_bson_size");
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
max_bson_size = PyLong_AsLong( max_bson_size = PyLong_AsLong(max_bson_size_obj);
PyObject_GetAttrString(client, "max_bson_size"));
#else #else
max_bson_size = PyInt_AsLong( max_bson_size = PyInt_AsLong(max_bson_size_obj);
PyObject_GetAttrString(client, "max_bson_size"));
#endif #endif
Py_XDECREF(max_bson_size_obj);
if (max_bson_size == -1) { if (max_bson_size == -1) {
PyMem_Free(collection_name); PyMem_Free(collection_name);
return NULL; return NULL;
} }
max_message_size_obj = PyObject_GetAttrString(client, "max_message_size");
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
max_message_size = PyLong_AsLong( max_message_size = PyLong_AsLong(max_message_size_obj);
PyObject_GetAttrString(client, "max_message_size"));
#else #else
max_message_size = PyInt_AsLong( max_message_size = PyInt_AsLong(max_message_size_obj);
PyObject_GetAttrString(client, "max_message_size"));
#endif #endif
Py_XDECREF(max_message_size_obj);
if (max_message_size == -1) { if (max_message_size == -1) {
PyMem_Free(collection_name); PyMem_Free(collection_name);
return NULL; return NULL;
@ -707,8 +711,10 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
request_id = new_request_id; request_id = new_request_id;
length_location = message_start; length_location = message_start;
if (!PyObject_CallMethod(client, send_message_result = PyObject_CallMethod(client, "_send_message",
"_send_message", "NO", result, send_gle)) { "NO", result, send_gle);
if (!send_message_result) {
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject* OperationFailure; PyObject* OperationFailure;
PyErr_Fetch(&etype, &evalue, &etrace); PyErr_Fetch(&etype, &evalue, &etrace);
@ -746,6 +752,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
* acknowledged writes. Re-raise immediately. */ * acknowledged writes. Re-raise immediately. */
PyErr_Restore(etype, evalue, etrace); PyErr_Restore(etype, evalue, etrace);
goto iterfail; goto iterfail;
} else {
Py_DECREF(send_message_result);
} }
} }
} }
@ -783,12 +791,17 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
buffer_free(buffer); buffer_free(buffer);
/* Send the last (or only) batch */ /* Send the last (or only) batch */
if (!PyObject_CallMethod(client, "_send_message", "NN", send_message_result = PyObject_CallMethod(client, "_send_message", "NN",
result, PyBool_FromLong((long)safe))) { result,
PyBool_FromLong((long)safe));
if (!send_message_result) {
Py_XDECREF(exc_type); Py_XDECREF(exc_type);
Py_XDECREF(exc_value); Py_XDECREF(exc_value);
Py_XDECREF(exc_trace); Py_XDECREF(exc_trace);
return NULL; return NULL;
} else {
Py_DECREF(send_message_result);
} }
if (exc_type) { if (exc_type) {

View File

@ -552,6 +552,12 @@ class MongoReplicaSetClient(common.BaseObject):
receive on a socket can take before timing out. receive on a socket can take before timing out.
- `connectTimeoutMS`: (integer) How long (in milliseconds) a - `connectTimeoutMS`: (integer) How long (in milliseconds) a
connection can take to be opened before timing out. connection can take to be opened before timing out.
- `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a
thread will wait for a socket from the pool if the pool has no
free sockets. Defaults to ``None`` (no timeout).
- `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give
the number of threads allowed to wait for a socket at one time.
Defaults to ``None`` (no waiters).
- `auto_start_request`: If ``True``, each thread that accesses - `auto_start_request`: If ``True``, each thread that accesses
this :class:`MongoReplicaSetClient` has a socket allocated to it this :class:`MongoReplicaSetClient` has a socket allocated to it
for the thread's lifetime, for each member of the set. For for the thread's lifetime, for each member of the set. For
@ -694,6 +700,8 @@ class MongoReplicaSetClient(common.BaseObject):
self.__net_timeout = self.__opts.get('sockettimeoutms') self.__net_timeout = self.__opts.get('sockettimeoutms')
self.__conn_timeout = self.__opts.get('connecttimeoutms') self.__conn_timeout = self.__opts.get('connecttimeoutms')
self.__wait_queue_timeout = self.__opts.get('waitqueuetimeoutms')
self.__wait_queue_multiple = self.__opts.get('waitqueuemultiple')
self.__use_ssl = self.__opts.get('ssl', None) self.__use_ssl = self.__opts.get('ssl', None)
self.__ssl_keyfile = self.__opts.get('ssl_keyfile', None) self.__ssl_keyfile = self.__opts.get('ssl_keyfile', None)
self.__ssl_certfile = self.__opts.get('ssl_certfile', None) self.__ssl_certfile = self.__opts.get('ssl_certfile', None)
@ -1034,6 +1042,8 @@ class MongoReplicaSetClient(common.BaseObject):
self.__net_timeout, self.__net_timeout,
self.__conn_timeout, self.__conn_timeout,
self.__use_ssl, self.__use_ssl,
wait_queue_timeout=self.__wait_queue_timeout,
wait_queue_multiple=self.__wait_queue_multiple,
use_greenlets=self.__use_greenlets, use_greenlets=self.__use_greenlets,
ssl_keyfile=self.__ssl_keyfile, ssl_keyfile=self.__ssl_keyfile,
ssl_certfile=self.__ssl_certfile, ssl_certfile=self.__ssl_certfile,

View File

@ -141,12 +141,6 @@ class Pool:
# Can override for testing: 0 to always check, None to never check. # Can override for testing: 0 to always check, None to never check.
self._check_interval_seconds = 1 self._check_interval_seconds = 1
if use_greenlets and not thread_util.have_gevent:
raise ConfigurationError(
"The Gevent module is not available. "
"Install the gevent package from PyPI."
)
self.sockets = set() self.sockets = set()
self.lock = threading.Lock() self.lock = threading.Lock()
@ -169,11 +163,17 @@ class Pool:
if HAS_SSL and use_ssl and not ssl_cert_reqs: if HAS_SSL and use_ssl and not ssl_cert_reqs:
self.ssl_cert_reqs = ssl.CERT_NONE self.ssl_cert_reqs = ssl.CERT_NONE
self._ident = thread_util.create_ident(use_greenlets)
# Map self._ident.get() -> request socket # Map self._ident.get() -> request socket
self._tid_to_sock = {} self._tid_to_sock = {}
if use_greenlets and not thread_util.have_gevent:
raise ConfigurationError(
"The Gevent module is not available. "
"Install the gevent package from PyPI."
)
self._ident = thread_util.create_ident(use_greenlets)
# Count the number of calls to start_request() per thread or greenlet # Count the number of calls to start_request() per thread or greenlet
self._request_counter = thread_util.Counter(use_greenlets) self._request_counter = thread_util.Counter(use_greenlets)
@ -324,28 +324,34 @@ class Pool:
elif not self._socket_semaphore.acquire(True, self.wait_queue_timeout): elif not self._socket_semaphore.acquire(True, self.wait_queue_timeout):
self._raise_wait_queue_timeout() self._raise_wait_queue_timeout()
sock_info, from_pool = None, None # We've now acquired the semaphore and must release it on error.
try: try:
sock_info, from_pool = None, None
try: try:
# set.pop() isn't atomic in Jython less than 2.7, see try:
# http://bugs.jython.org/issue1854 # set.pop() isn't atomic in Jython less than 2.7, see
self.lock.acquire() # http://bugs.jython.org/issue1854
sock_info, from_pool = self.sockets.pop(), True self.lock.acquire()
finally: sock_info, from_pool = self.sockets.pop(), True
self.lock.release() finally:
except KeyError: self.lock.release()
sock_info, from_pool = self.connect(pair), False except KeyError:
sock_info, from_pool = self.connect(pair), False
if from_pool: if from_pool:
sock_info = self._check(sock_info, pair) sock_info = self._check(sock_info, pair)
sock_info.forced = forced sock_info.forced = forced
if req_state == NO_SOCKET_YET: if req_state == NO_SOCKET_YET:
# start_request has been called but we haven't assigned a socket to # start_request has been called but we haven't assigned a
# the request yet. Let's use this socket for this request until # socket to the request yet. Let's use this socket for this
# end_request. # request until end_request.
self._set_request_state(sock_info) self._set_request_state(sock_info)
except:
if not forced:
self._socket_semaphore.release()
raise
sock_info.last_checkout = time.time() sock_info.last_checkout = time.time()
return sock_info return sock_info
@ -410,8 +416,10 @@ class Pool:
""" """
try: try:
self.lock.acquire() self.lock.acquire()
if (len(self.sockets) < self.max_size too_many_sockets = (self.max_size is not None
and sock_info.pool_id == self.pool_id): and len(self.sockets) >= self.max_size)
if not too_many_sockets and sock_info.pool_id == self.pool_id:
self.sockets.add(sock_info) self.sockets.add(sock_info)
else: else:
sock_info.close() sock_info.close()

View File

@ -110,6 +110,12 @@ class ReplicaSetConnection(MongoReplicaSetClient):
receive on a socket can take before timing out. receive on a socket can take before timing out.
- `connectTimeoutMS`: (integer) How long (in milliseconds) a - `connectTimeoutMS`: (integer) How long (in milliseconds) a
connection can take to be opened before timing out. connection can take to be opened before timing out.
- `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a
thread will wait for a socket from the pool if the pool has no
free sockets. Defaults to ``None`` (no timeout).
- `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give
the number of threads allowed to wait for a socket at one time.
Defaults to ``None`` (no waiters).
- `auto_start_request`: If ``True`` (the default), each thread that - `auto_start_request`: If ``True`` (the default), each thread that
accesses this :class:`ReplicaSetConnection` has a socket allocated accesses this :class:`ReplicaSetConnection` has a socket allocated
to it for the thread's lifetime, for each member of the set. For to it for the thread's lifetime, for each member of the set. For

View File

@ -31,7 +31,7 @@ from distutils.errors import CCompilerError
from distutils.errors import DistutilsPlatformError, DistutilsExecError from distutils.errors import DistutilsPlatformError, DistutilsExecError
from distutils.core import Extension from distutils.core import Extension
version = "2.6+" version = "2.6.3"
f = open("README.rst") f = open("README.rst")
try: try:

View File

@ -67,6 +67,8 @@ class TestBSON(unittest.TestCase):
# the simplest valid BSON document # the simplest valid BSON document
self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00"))) self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00")))
self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00")))) self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00"))))
# failure cases
self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00"))) self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00")))
self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01"))) self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01")))
self.assertFalse(is_valid(b("\x05\x00\x00\x00"))) self.assertFalse(is_valid(b("\x05\x00\x00\x00")))
@ -74,6 +76,17 @@ class TestBSON(unittest.TestCase):
self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12"))) self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12")))
self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00"))) self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00")))
self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"))) self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")))
self.assertFalse(is_valid(b("\x13\x00\x00\x00\x02foo\x00"
"\x04\x00\x00\x00bar\x00\x00")))
self.assertFalse(is_valid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00"
"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00")))
self.assertFalse(is_valid(b("\x15\x00\x00\x00\x03foo\x00\x0c"
"\x00\x00\x00\x08bar\x00\x01\x00\x00")))
self.assertFalse(is_valid(b("\x1c\x00\x00\x00\x03foo\x00"
"\x12\x00\x00\x00\x02bar\x00"
"\x05\x00\x00\x00baz\x00\x00\x00")))
self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00"
"\x04\x00\x00\x00abc\xff\x00")))
def test_random_data_is_not_bson(self): def test_random_data_is_not_bson(self):
qcheck.check_unittest(self, qcheck.isnt(is_valid), qcheck.check_unittest(self, qcheck.isnt(is_valid),

View File

@ -1004,7 +1004,8 @@ class TestCollection(unittest.TestCase):
self.fail() self.fail()
except OperationFailure, e: except OperationFailure, e:
if version.at_least(self.db.connection, (1, 3)): if version.at_least(self.db.connection, (1, 3)):
self.assertEqual(10147, e.code) if e.code not in (10147, 17009):
self.fail()
def test_index_on_subfield(self): def test_index_on_subfield(self):
db = self.db db = self.db
@ -1890,14 +1891,19 @@ class TestCollection(unittest.TestCase):
ref_only = {'ref': {'$ref': 'collection'}} ref_only = {'ref': {'$ref': 'collection'}}
id_only = {'ref': {'$id': ObjectId()}} id_only = {'ref': {'$id': ObjectId()}}
# Force insert of ref without $id. # Starting with MongoDB 2.5.2 this is no longer possible
c.insert(ref_only, check_keys=False) # from insert, update, or findAndModify.
self.assertEqual(DBRef('collection', id=None), c.find_one()['ref']) if not version.at_least(self.db.connection, (2, 5, 2)):
c.drop() # Force insert of ref without $id.
c.insert(ref_only, check_keys=False)
self.assertEqual(DBRef('collection', id=None),
c.find_one()['ref'])
# DBRef without $ref is decoded as normal subdocument. c.drop()
c.insert(id_only, check_keys=False)
self.assertEqual(id_only, c.find_one()) # DBRef without $ref is decoded as normal subdocument.
c.insert(id_only, check_keys=False)
self.assertEqual(id_only, c.find_one())
def test_as_class(self): def test_as_class(self):
c = self.db.test c = self.db.test

View File

@ -28,6 +28,7 @@ from pymongo.replica_set_connection import ReplicaSetConnection
from pymongo.errors import ConfigurationError from pymongo.errors import ConfigurationError
from test import host, port, pair from test import host, port, pair
from test.test_replica_set_client import TestReplicaSetClientBase from test.test_replica_set_client import TestReplicaSetClientBase
from test.utils import get_pool
class TestConnection(unittest.TestCase): class TestConnection(unittest.TestCase):
@ -49,6 +50,11 @@ class TestConnection(unittest.TestCase):
c = Connection("mongodb://%s:%s/?safe=true" % (host, port)) c = Connection("mongodb://%s:%s/?safe=true" % (host, port))
self.assertTrue(c.safe) self.assertTrue(c.safe)
# To preserve legacy Connection's behavior, max_size should be None.
# Pool should handle this without error.
self.assertEqual(None, c._MongoClient__pool.max_size)
c.end_request()
# Connection's network_timeout argument is translated into # Connection's network_timeout argument is translated into
# socketTimeoutMS # socketTimeoutMS
self.assertEqual(123, Connection( self.assertEqual(123, Connection(
@ -85,6 +91,12 @@ class TestReplicaSetConnection(TestReplicaSetClientBase):
self.assertTrue(c.safe) self.assertTrue(c.safe)
# To preserve legacy ReplicaSetConnection's behavior, max_size should
# be None. Pool should handle this without error.
pool = get_pool(c)
self.assertEqual(None, pool.max_size)
c.end_request()
# ReplicaSetConnection's network_timeout argument is translated into # ReplicaSetConnection's network_timeout argument is translated into
# socketTimeoutMS # socketTimeoutMS
self.assertEqual(123, ReplicaSetConnection( self.assertEqual(123, ReplicaSetConnection(
@ -92,7 +104,8 @@ class TestReplicaSetConnection(TestReplicaSetClientBase):
)._MongoReplicaSetClient__net_timeout) )._MongoReplicaSetClient__net_timeout)
for network_timeout in 'foo', 0, -1: for network_timeout in 'foo', 0, -1:
self.assertRaises(ConfigurationError, self.assertRaises(
ConfigurationError,
ReplicaSetConnection, pair, replicaSet=self.name, ReplicaSetConnection, pair, replicaSet=self.name,
network_timeout=network_timeout) network_timeout=network_timeout)

View File

@ -788,7 +788,8 @@ class _TestMaxPoolSize(_TestPoolingBase):
recent Gevent development. recent Gevent development.
""" """
if start_request: if start_request:
assert max_pool_size >= nthreads, "Deadlock" if max_pool_size is not None and max_pool_size < nthreads:
raise AssertionError("Deadlock")
c = self.get_client( c = self.get_client(
max_pool_size=max_pool_size, auto_start_request=False) max_pool_size=max_pool_size, auto_start_request=False)
@ -870,7 +871,11 @@ class _TestMaxPoolSize(_TestPoolingBase):
self.sleep(0.1) self.sleep(0.1)
cx_pool._ident.get() cx_pool._ident.get()
self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter) if max_pool_size is not None:
self.assertEqual(
max_pool_size,
cx_pool._socket_semaphore.counter)
self.assertEqual(0, len(cx_pool._tid_to_sock)) self.assertEqual(0, len(cx_pool._tid_to_sock))
def _test_max_pool_size_no_rendezvous(self, start_request, end_request): def _test_max_pool_size_no_rendezvous(self, start_request, end_request):
@ -954,6 +959,10 @@ class _TestMaxPoolSize(_TestPoolingBase):
self._test_max_pool_size( self._test_max_pool_size(
start_request=0, end_request=0, nthreads=10, max_pool_size=4) start_request=0, end_request=0, nthreads=10, max_pool_size=4)
def test_max_pool_size_none(self):
self._test_max_pool_size(
start_request=0, end_request=0, nthreads=10, max_pool_size=None)
def test_max_pool_size_with_request(self): def test_max_pool_size_with_request(self):
self._test_max_pool_size( self._test_max_pool_size(
start_request=1, end_request=1, nthreads=10, max_pool_size=10) start_request=1, end_request=1, nthreads=10, max_pool_size=10)
@ -989,6 +998,28 @@ class _TestMaxPoolSize(_TestPoolingBase):
# Call end_request() but not start_request() # Call end_request() but not start_request()
self._test_max_pool_size(0, 1) self._test_max_pool_size(0, 1)
def test_max_pool_size_with_connection_failure(self):
# The pool acquires its semaphore before attempting to connect; ensure
# it releases the semaphore on connection failure.
class TestPool(Pool):
def connect(self, pair):
raise socket.error()
test_pool = TestPool(
pair=('example.com', 27017),
max_size=1,
net_timeout=1,
conn_timeout=1,
use_ssl=False,
wait_queue_timeout=1,
use_greenlets=self.use_greenlets)
# First call to get_socket fails; if pool doesn't release its semaphore
# then the second call raises "ConnectionFailure: Timed out waiting for
# socket from pool" instead of the socket.error.
for i in range(2):
self.assertRaises(socket.error, test_pool.get_socket)
class SocketGetter(MongoThread): class SocketGetter(MongoThread):
"""Utility for _TestMaxOpenSockets and _TestWaitQueueMultiple""" """Utility for _TestMaxOpenSockets and _TestWaitQueueMultiple"""
@ -1009,7 +1040,7 @@ class _TestMaxOpenSockets(_TestPoolingBase):
To be run both with threads and with greenlets. To be run both with threads and with greenlets.
""" """
def get_pool_with_wait_queue_timeout(self, wait_queue_timeout): def get_pool_with_wait_queue_timeout(self, wait_queue_timeout):
return self.get_pool(('127.0.0.1', 27017), return self.get_pool((host, port),
1, None, None, 1, None, None,
False, False,
wait_queue_timeout=wait_queue_timeout, wait_queue_timeout=wait_queue_timeout,
@ -1057,7 +1088,7 @@ class _TestWaitQueueMultiple(_TestPoolingBase):
To be run both with threads and with greenlets. To be run both with threads and with greenlets.
""" """
def get_pool_with_wait_queue_multiple(self, wait_queue_multiple): def get_pool_with_wait_queue_multiple(self, wait_queue_multiple):
return self.get_pool(('127.0.0.1', 27017), return self.get_pool((host, port),
2, None, None, 2, None, None,
False, False,
wait_queue_timeout=None, wait_queue_timeout=None,

View File

@ -14,16 +14,19 @@
"""Tests for connection-pooling with greenlets and Gevent""" """Tests for connection-pooling with greenlets and Gevent"""
import gc
import time
import unittest import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from pymongo import pool from pymongo import pool
from pymongo.errors import ConfigurationError
from test import host, port from test import host, port
from test.utils import looplet from test.utils import looplet
from test.test_pooling_base import ( from test.test_pooling_base import (
_TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets, _TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets,
_TestPoolSocketSharing, _TestWaitQueueMultiple) _TestPoolSocketSharing, _TestWaitQueueMultiple, has_gevent)
class TestPoolingGevent(_TestPooling, unittest.TestCase): class TestPoolingGevent(_TestPooling, unittest.TestCase):
@ -189,5 +192,49 @@ class TestWaitQueueMultipleGevent(_TestWaitQueueMultiple, unittest.TestCase):
use_greenlets = True use_greenlets = True
class TestUseGreenletsWithoutGevent(unittest.TestCase):
def test_use_greenlets_without_gevent(self):
# Verify that Pool(use_greenlets=True) raises ConfigurationError if
# Gevent is not installed, and that its destructor runs without error.
if has_gevent:
raise SkipTest(
"Gevent is installed, can't test what happens calling "
"Pool(use_greenlets=True) when Gevent is unavailable")
# Possible outcomes of __del__.
DID_NOT_RUN, RAISED, SUCCESS = range(3)
outcome = [DID_NOT_RUN]
class TestPool(pool.Pool):
def __del__(self):
try:
pool.Pool.__del__(self) # Pool is old-style, no super()
outcome[0] = SUCCESS
except:
outcome[0] = RAISED
# Pool raises ConfigurationError, "The Gevent module is not available".
self.assertRaises(
ConfigurationError,
TestPool,
pair=(host, port),
max_size=10,
net_timeout=1000,
conn_timeout=1000,
use_ssl=False,
use_greenlets=True)
# Convince Jython or PyPy to call __del__.
for _ in range(10):
if outcome[0] == DID_NOT_RUN:
gc.collect()
time.sleep(0.1)
if outcome[0] == DID_NOT_RUN:
self.fail("Pool.__del__ didn't run")
elif outcome[0] == RAISED:
self.fail("Pool.__del__ raised exception")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -49,7 +49,8 @@ from pymongo.errors import (AutoReconnect,
from test import version, port, pair from test import version, port, pair
from test.utils import ( from test.utils import (
delay, assertReadFrom, assertReadFromAll, read_from_which_host, delay, assertReadFrom, assertReadFromAll, read_from_which_host,
assertRaisesExactly, TestRequestMixin, one, server_started_with_auth) assertRaisesExactly, TestRequestMixin, one, server_started_with_auth,
pools_from_rs_client, get_pool)
class TestReplicaSetClientAgainstStandalone(unittest.TestCase): class TestReplicaSetClientAgainstStandalone(unittest.TestCase):
@ -699,8 +700,8 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
previous_writer = c._MongoReplicaSetClient__rs_state.writer previous_writer = c._MongoReplicaSetClient__rs_state.writer
def kill_sockets(): def kill_sockets():
for member in c._MongoReplicaSetClient__rs_state.members: for pool in pools_from_rs_client(c):
for socket_info in member.pool.sockets: for socket_info in pool.sockets:
socket_info.sock.close() socket_info.sock.close()
kill_sockets() kill_sockets()
@ -762,6 +763,17 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue(rs_state.get(secondary_host).up) self.assertTrue(rs_state.get(secondary_host).up)
collection.find_one(read_preference=SECONDARY) # No error. collection.find_one(read_preference=SECONDARY) # No error.
def test_waitQueueTimeoutMS(self):
client = self._get_client(waitQueueTimeoutMS=2000)
pool = get_pool(client)
self.assertEqual(pool.wait_queue_timeout, 2)
def test_waitQueueMultiple(self):
client = self._get_client(max_pool_size=3, waitQueueMultiple=2)
pool = get_pool(client)
self.assertEqual(pool.wait_queue_multiple, 2)
self.assertEqual(pool._socket_semaphore.waiter_semaphore.counter, 6)
def test_tz_aware(self): def test_tz_aware(self):
self.assertRaises(ConfigurationError, MongoReplicaSetClient, self.assertRaises(ConfigurationError, MongoReplicaSetClient,
tz_aware='foo', replicaSet=self.name) tz_aware='foo', replicaSet=self.name)
@ -923,7 +935,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
# Ensure MongoReplicaSetClient doesn't close socket after it gets an # Ensure MongoReplicaSetClient doesn't close socket after it gets an
# error response to getLastError. PYTHON-395. # error response to getLastError. PYTHON-395.
c = self._get_client(auto_start_request=False) c = self._get_client(auto_start_request=False)
pool = c._MongoReplicaSetClient__rs_state.get(self.primary).pool pool = get_pool(c)
self.assertEqual(1, len(pool.sockets)) self.assertEqual(1, len(pool.sockets))
old_sock_info = iter(pool.sockets).next() old_sock_info = iter(pool.sockets).next()
c.pymongo_test.test.drop() c.pymongo_test.test.drop()
@ -943,7 +955,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
# error response to getLastError. PYTHON-395. # error response to getLastError. PYTHON-395.
c = self._get_client(auto_start_request=True) c = self._get_client(auto_start_request=True)
c.pymongo_test.test.find_one() c.pymongo_test.test.find_one()
pool = c._MongoReplicaSetClient__rs_state.get(self.primary).pool pool = get_pool(c)
# Client reserved a socket for this thread # Client reserved a socket for this thread
self.assertTrue(isinstance(pool._get_request_state(), SocketInfo)) self.assertTrue(isinstance(pool._get_request_state(), SocketInfo))
@ -968,13 +980,10 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
client = self._get_client(auto_start_request=True) client = self._get_client(auto_start_request=True)
self.assertTrue(client.auto_start_request) self.assertTrue(client.auto_start_request)
pools = [member.pool for member in pools = pools_from_rs_client(client)
client._MongoReplicaSetClient__rs_state.members]
self.assertInRequestAndSameSock(client, pools) self.assertInRequestAndSameSock(client, pools)
primary_pool = \ primary_pool = get_pool(client)
client._MongoReplicaSetClient__rs_state.get(client.primary).pool
# Trigger the RSC to actually start a request on primary pool # Trigger the RSC to actually start a request on primary pool
client.pymongo_test.test.find_one() client.pymongo_test.test.find_one()
@ -1003,9 +1012,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
client.close() client.close()
client = self._get_client() client = self._get_client()
pools = [mongo.pool for mongo in pools = pools_from_rs_client(client)
client._MongoReplicaSetClient__rs_state.members]
self.assertNotInRequestAndDifferentSock(client, pools) self.assertNotInRequestAndDifferentSock(client, pools)
client.start_request() client.start_request()
self.assertInRequestAndSameSock(client, pools) self.assertInRequestAndSameSock(client, pools)
@ -1016,8 +1023,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
def test_nested_request(self): def test_nested_request(self):
client = self._get_client(auto_start_request=True) client = self._get_client(auto_start_request=True)
try: try:
pools = [member.pool for member in pools = pools_from_rs_client(client)
client._MongoReplicaSetClient__rs_state.members]
self.assertTrue(client.in_request()) self.assertTrue(client.in_request())
# Start and end request - we're still in "outer" original request # Start and end request - we're still in "outer" original request
@ -1059,8 +1065,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
def test_request_threads(self): def test_request_threads(self):
client = self._get_client() client = self._get_client()
try: try:
pools = [member.pool for member in pools = pools_from_rs_client(client)
client._MongoReplicaSetClient__rs_state.members]
self.assertNotInRequestAndDifferentSock(client, pools) self.assertNotInRequestAndDifferentSock(client, pools)
started_request, ended_request = threading.Event(), threading.Event() started_request, ended_request = threading.Event(), threading.Event()

View File

@ -22,22 +22,11 @@ from nose.plugins.skip import SkipTest
from test.utils import server_started_with_auth, joinall, RendezvousThread from test.utils import server_started_with_auth, joinall, RendezvousThread
from test.test_client import get_client from test.test_client import get_client
from pymongo.mongo_client import MongoClient from test.utils import get_pool
from pymongo.replica_set_connection import MongoReplicaSetClient
from pymongo.pool import SocketInfo, _closed from pymongo.pool import SocketInfo, _closed
from pymongo.errors import AutoReconnect, OperationFailure from pymongo.errors import AutoReconnect, OperationFailure
def get_pool(client):
if isinstance(client, MongoClient):
return client._MongoClient__pool
elif isinstance(client, MongoReplicaSetClient):
rs_state = client._MongoReplicaSetClient__rs_state
return rs_state[rs_state.writer].pool
else:
raise TypeError(str(client))
class AutoAuthenticateThreads(threading.Thread): class AutoAuthenticateThreads(threading.Thread):
def __init__(self, collection, num): def __init__(self, collection, num):

View File

@ -263,6 +263,22 @@ def assertReadFromAll(testcase, rsc, members, *args, **kwargs):
testcase.assertEqual(members, used) testcase.assertEqual(members, used)
def get_pool(client):
if isinstance(client, MongoClient):
return client._MongoClient__pool
elif isinstance(client, MongoReplicaSetClient):
rs_state = client._MongoReplicaSetClient__rs_state
return rs_state.primary_member.pool
else:
raise TypeError(str(client))
def pools_from_rs_client(client):
"""Get Pool instances from a MongoReplicaSetClient or ReplicaSetConnection.
"""
return [
member.pool for member in
client._MongoReplicaSetClient__rs_state.members]
class TestRequestMixin(object): class TestRequestMixin(object):
"""Inherit from this class and from unittest.TestCase to get some """Inherit from this class and from unittest.TestCase to get some
convenient methods for testing connection pools and requests convenient methods for testing connection pools and requests