diff --git a/pymongo/son.py b/pymongo/son.py index 40c275d18..903561527 100644 --- a/pymongo/son.py +++ b/pymongo/son.py @@ -22,7 +22,6 @@ import datetime import re import binascii import base64 -from UserDict import DictMixin import elementtree.ElementTree as ET @@ -32,7 +31,7 @@ from objectid import ObjectId from dbref import DBRef from errors import UnsupportedTag -class SON(DictMixin): +class SON(dict): """SON data. A subclass of dict that maintains ordering of keys and provides a few extra @@ -69,63 +68,130 @@ class SON(DictMixin): """ def __init__(self, data=None, **kwargs): self.__keys = [] - self.__data = {} - if data is not None: - if hasattr(data, 'items'): - items = data.iteritems() - else: - items = list(data) - for item in items: - if len(item) != 2: - raise ValueError("sequence elements must have length 2") - self.__keys.append(item[0]) - self.__data[item[0]] = item[1] - if kwargs: - self.__merge_keys(kwargs.iterkeys()) - self.update(kwargs) - - def __merge_keys(self, keys): - self.__keys.extend(keys) - newkeys = {} - self.__keys = [newkeys.setdefault(x, x) for x in self.__keys if x not in newkeys] + dict.__init__(self) + self.update(data) + self.update(kwargs) def __repr__(self): result = [] for key in self.__keys: - result.append("(%r, %r)" % (key, self.__data[key])) + result.append("(%r, %r)" % (key, self[key])) return "SON([%s])" % ", ".join(result) - def update(self, data=None, **kwargs): - if data is not None: - if hasattr(data, "iterkeys"): - self.__merge_keys(data.iterkeys()) - else: - self.__merge_keys(data.keys()) - self.__data.update(data) - if kwargs: - self.update(kwargs) - - def __getitem__(self, key): - return self.__data[key] - def __setitem__(self, key, value): - if key not in self.__data: + if key not in self: self.__keys.append(key) - self.__data[key] = value + dict.__setitem__(self, key, value) def __delitem__(self, key): - del self.__data[key] self.__keys.remove(key) + dict.__delitem__(self, key) def keys(self): return list(self.__keys) def copy(self): other = SON() - other.__data = self.__data.copy() - other.__keys = self.__keys[:] + other.update(self) return other + # TODO this is all from UserDict.DictMixin. it could probably be made more efficient. + # second level definitions support higher levels + def __iter__(self): + for k in self.keys(): + yield k + + def has_key(self, key): + try: + value = self[key] + except KeyError: + return False + return True + + def __contains__(self, key): + return self.has_key(key) + + # third level takes advantage of second level definitions + def iteritems(self): + for k in self: + yield (k, self[k]) + + def iterkeys(self): + return self.__iter__() + + # fourth level uses definitions from lower levels + def itervalues(self): + for _, v in self.iteritems(): + yield v + + def values(self): + return [v for _, v in self.iteritems()] + + def items(self): + return list(self.iteritems()) + + def clear(self): + for key in self.keys(): + del self[key] + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key, *args): + if len(args) > 1: + raise TypeError, "pop expected at most 2 arguments, got "\ + + repr(1 + len(args)) + try: + value = self[key] + except KeyError: + if args: + return args[0] + raise + del self[key] + return value + + def popitem(self): + try: + k, v = self.iteritems().next() + except StopIteration: + raise KeyError, 'container is empty' + del self[k] + return (k, v) + + def update(self, other=None, **kwargs): + # Make progressively weaker assumptions about "other" + if other is None: + pass + elif hasattr(other, 'iteritems'): # iteritems saves memory and lookups + for k, v in other.iteritems(): + self[k] = v + elif hasattr(other, 'keys'): + for k in other.keys(): + self[k] = other[k] + else: + for k, v in other: + self[k] = v + if kwargs: + self.update(kwargs) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __cmp__(self, other): + if isinstance(other, SON): + return cmp((dict(self.iteritems()), self.keys()), (dict(other.iteritems()), other.keys())) + return cmp(dict(self.iteritems()), other) + + def __len__(self): + return len(self.keys()) + @classmethod def from_xml(cls, xml): """Create an instance of SON from an xml document.