PYTHON-871 - Fix encoding of defaultdict.
This commit is contained in:
parent
5df17c2c63
commit
3716b449cc
@ -1424,17 +1424,34 @@ int write_dict(PyObject* self, buffer_t buffer,
|
||||
}
|
||||
|
||||
/* Write _id first if this is a top level doc. */
|
||||
if (top_level && PyMapping_HasKeyString(dict, "_id")) {
|
||||
PyObject* _id = PyMapping_GetItemString(dict, "_id");
|
||||
if (!_id) {
|
||||
return 0;
|
||||
}
|
||||
if (!write_pair(self, buffer, "_id", 3,
|
||||
_id, check_keys, options, 1)) {
|
||||
if (top_level) {
|
||||
/*
|
||||
* If "dict" is a defaultdict we don't want to call
|
||||
* PyMapping_GetItemString on it. That would **create**
|
||||
* an _id where one didn't previously exist (PYTHON-871).
|
||||
*/
|
||||
if (PyDict_Check(dict)) {
|
||||
/* PyDict_GetItemString returns a borrowed reference. */
|
||||
PyObject* _id = PyDict_GetItemString(dict, "_id");
|
||||
if (_id) {
|
||||
if (!write_pair(self, buffer, "_id", 3,
|
||||
_id, check_keys, options, 1)) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
} else if (PyMapping_HasKeyString(dict, "_id")) {
|
||||
PyObject* _id = PyMapping_GetItemString(dict, "_id");
|
||||
if (!_id) {
|
||||
return 0;
|
||||
}
|
||||
if (!write_pair(self, buffer, "_id", 3,
|
||||
_id, check_keys, options, 1)) {
|
||||
Py_DECREF(_id);
|
||||
return 0;
|
||||
}
|
||||
/* PyMapping_GetItemString returns a new reference. */
|
||||
Py_DECREF(_id);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(_id);
|
||||
}
|
||||
|
||||
iter = PyObject_GetIter(dict);
|
||||
|
||||
@ -143,6 +143,11 @@ class TestBSON(unittest.TestCase):
|
||||
def test_encode_then_decode_any_mapping(self):
|
||||
self.check_encode_then_decode(doc_class=NotADict)
|
||||
|
||||
def test_encoding_defaultdict(self):
|
||||
dct = collections.defaultdict(dict, [('foo', 'bar')])
|
||||
BSON.encode(dct)
|
||||
self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
|
||||
|
||||
def test_basic_validation(self):
|
||||
self.assertRaises(TypeError, is_valid, 100)
|
||||
self.assertRaises(TypeError, is_valid, u("test"))
|
||||
|
||||
@ -20,6 +20,8 @@ import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.regex import Regex
|
||||
@ -657,6 +659,13 @@ class TestCollection(IntegrationTest):
|
||||
self.assertFalse(result.acknowledged)
|
||||
wait_until(lambda: 0 == db.test.count(), 'delete 2 documents')
|
||||
|
||||
def test_find_by_default_dct(self):
|
||||
db = self.db
|
||||
db.test.insert_one({'foo': 'bar'})
|
||||
dct = defaultdict(dict, [('foo', 'bar')])
|
||||
self.assertIsNotNone(db.test.find_one(dct))
|
||||
self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')]))
|
||||
|
||||
def test_find_w_fields(self):
|
||||
db = self.db
|
||||
db.test.delete_many({})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user