diff --git a/bson/son.py b/bson/son.py index dd4ba979f..520cdd263 100644 --- a/bson/son.py +++ b/bson/son.py @@ -124,14 +124,20 @@ class SON(dict): # efficient. # second level definitions support higher levels def __iter__(self): - for k in self.keys(): + """ + Cannot remove nor add keys while iterating + """ + key_len = len(self.__keys) + for k in self.__keys: + if len(self.__keys) != key_len: + raise RuntimeError("son changed length during iteration") yield k def has_key(self, key): - return key in self.keys() + return key in self.__keys def __contains__(self, key): - return key in self.keys() + return key in self.__keys # third level takes advantage of second level definitions def iteritems(self): @@ -153,8 +159,8 @@ class SON(dict): return [(key, self[key]) for key in self] def clear(self): - for key in self.keys(): - del self[key] + self.__keys = [] + super(SON, self).clear() def setdefault(self, key, default=None): try: diff --git a/doc/contributors.rst b/doc/contributors.rst index 3488fc8cd..5c1ec55b0 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -70,3 +70,4 @@ The following is a list of people who have contributed to - Kyle Erf (3rf) - Luke Lovett (lovett89) - Jaroslav Semančík (girogiro) +- Don Mitchell (dmitchell) diff --git a/test/test_son.py b/test/test_son.py index 02312ac59..8e101e97e 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -154,6 +154,52 @@ class TestSON(unittest.TestCase): self.assertEqual(list(reflexive_son), list(reflexive_son1)) self.assertEqual(id(reflexive_son1), id(reflexive_son1["reflexive"])) + def test_iteration(self): + """ + Test __iter__ + """ + # test success case + test_son = SON([(1, 100), (2, 200), (3, 300)]) + for ele in test_son: + self.assertEqual(ele * 100, test_son[ele]) + # test failure case + def break_iter(): + for ele in test_son: + del test_son[ele] + self.assertRaises(RuntimeError, break_iter) + + def test_contains_has(self): + """ + has_key and __contains__ + """ + test_son = SON([(1, 100), (2, 200), (3, 300)]) + self.assertIn(1, test_son) + self.assertTrue(2 in test_son, "in failed") + self.assertFalse(22 in test_son, "in succeeded when it shouldn't") + self.assertTrue(test_son.has_key(2), "has_key failed") + self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") + + def test_clears(self): + """ + Test clear() + """ + test_son = SON([(1, 100), (2, 200), (3, 300)]) + test_son.clear() + self.assertNotIn(1, test_son) + self.assertEqual(0, len(test_son)) + self.assertEqual(0, len(test_son.keys())) + self.assertEqual({}, test_son.to_dict()) + + def test_len(self): + """ + Test len + """ + test_son = SON() + self.assertEqual(0, len(test_son)) + test_son = SON([(1, 100), (2, 200), (3, 300)]) + self.assertEqual(3, len(test_son)) + test_son.popitem() + self.assertEqual(2, len(test_son)) if __name__ == "__main__": unittest.main()