PYTHON-871 - Fix encoding of defaultdict.

This commit is contained in:
Bernie Hackett 2015-03-27 14:35:09 -07:00
parent 5df17c2c63
commit 3716b449cc
3 changed files with 40 additions and 9 deletions

View File

@ -1424,17 +1424,34 @@ int write_dict(PyObject* self, buffer_t buffer,
}
/* Write _id first if this is a top level doc. */
if (top_level && PyMapping_HasKeyString(dict, "_id")) {
PyObject* _id = PyMapping_GetItemString(dict, "_id");
if (!_id) {
return 0;
}
if (!write_pair(self, buffer, "_id", 3,
_id, check_keys, options, 1)) {
if (top_level) {
/*
* If "dict" is a defaultdict we don't want to call
* PyMapping_GetItemString on it. That would **create**
* an _id where one didn't previously exist (PYTHON-871).
*/
if (PyDict_Check(dict)) {
/* PyDict_GetItemString returns a borrowed reference. */
PyObject* _id = PyDict_GetItemString(dict, "_id");
if (_id) {
if (!write_pair(self, buffer, "_id", 3,
_id, check_keys, options, 1)) {
return 0;
}
}
} else if (PyMapping_HasKeyString(dict, "_id")) {
PyObject* _id = PyMapping_GetItemString(dict, "_id");
if (!_id) {
return 0;
}
if (!write_pair(self, buffer, "_id", 3,
_id, check_keys, options, 1)) {
Py_DECREF(_id);
return 0;
}
/* PyMapping_GetItemString returns a new reference. */
Py_DECREF(_id);
return 0;
}
Py_DECREF(_id);
}
iter = PyObject_GetIter(dict);

View File

@ -143,6 +143,11 @@ class TestBSON(unittest.TestCase):
def test_encode_then_decode_any_mapping(self):
self.check_encode_then_decode(doc_class=NotADict)
def test_encoding_defaultdict(self):
dct = collections.defaultdict(dict, [('foo', 'bar')])
BSON.encode(dct)
self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
def test_basic_validation(self):
self.assertRaises(TypeError, is_valid, 100)
self.assertRaises(TypeError, is_valid, u("test"))

View File

@ -20,6 +20,8 @@ import re
import sys
import threading
from collections import defaultdict
sys.path[0:0] = [""]
from bson.regex import Regex
@ -657,6 +659,13 @@ class TestCollection(IntegrationTest):
self.assertFalse(result.acknowledged)
wait_until(lambda: 0 == db.test.count(), 'delete 2 documents')
def test_find_by_default_dct(self):
db = self.db
db.test.insert_one({'foo': 'bar'})
dct = defaultdict(dict, [('foo', 'bar')])
self.assertIsNotNone(db.test.find_one(dct))
self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')]))
def test_find_w_fields(self):
db = self.db
db.test.delete_many({})