diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 5c5d721ec..f00c6f3c5 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -1424,17 +1424,34 @@ int write_dict(PyObject* self, buffer_t buffer, } /* Write _id first if this is a top level doc. */ - if (top_level && PyMapping_HasKeyString(dict, "_id")) { - PyObject* _id = PyMapping_GetItemString(dict, "_id"); - if (!_id) { - return 0; - } - if (!write_pair(self, buffer, "_id", 3, - _id, check_keys, options, 1)) { + if (top_level) { + /* + * If "dict" is a defaultdict we don't want to call + * PyMapping_GetItemString on it. That would **create** + * an _id where one didn't previously exist (PYTHON-871). + */ + if (PyDict_Check(dict)) { + /* PyDict_GetItemString returns a borrowed reference. */ + PyObject* _id = PyDict_GetItemString(dict, "_id"); + if (_id) { + if (!write_pair(self, buffer, "_id", 3, + _id, check_keys, options, 1)) { + return 0; + } + } + } else if (PyMapping_HasKeyString(dict, "_id")) { + PyObject* _id = PyMapping_GetItemString(dict, "_id"); + if (!_id) { + return 0; + } + if (!write_pair(self, buffer, "_id", 3, + _id, check_keys, options, 1)) { + Py_DECREF(_id); + return 0; + } + /* PyMapping_GetItemString returns a new reference. */ Py_DECREF(_id); - return 0; } - Py_DECREF(_id); } iter = PyObject_GetIter(dict); diff --git a/test/test_bson.py b/test/test_bson.py index 81ba22fce..b4b676c5c 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -143,6 +143,11 @@ class TestBSON(unittest.TestCase): def test_encode_then_decode_any_mapping(self): self.check_encode_then_decode(doc_class=NotADict) + def test_encoding_defaultdict(self): + dct = collections.defaultdict(dict, [('foo', 'bar')]) + BSON.encode(dct) + self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')])) + def test_basic_validation(self): self.assertRaises(TypeError, is_valid, 100) self.assertRaises(TypeError, is_valid, u("test")) diff --git a/test/test_collection.py b/test/test_collection.py index ebae73f54..7303e7b58 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -20,6 +20,8 @@ import re import sys import threading +from collections import defaultdict + sys.path[0:0] = [""] from bson.regex import Regex @@ -657,6 +659,13 @@ class TestCollection(IntegrationTest): self.assertFalse(result.acknowledged) wait_until(lambda: 0 == db.test.count(), 'delete 2 documents') + def test_find_by_default_dct(self): + db = self.db + db.test.insert_one({'foo': 'bar'}) + dct = defaultdict(dict, [('foo', 'bar')]) + self.assertIsNotNone(db.test.find_one(dct)) + self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')])) + def test_find_w_fields(self): db = self.db db.test.delete_many({})