Added __ne__ for DBRefs PYTHON-440

This commit is contained in:
Ross Lawley 2012-11-22 11:12:36 +00:00
parent da6d73b77e
commit 29b9de45db
20 changed files with 126 additions and 35 deletions

View File

@ -154,7 +154,8 @@ class Binary(binary_type):
def __eq__(self, other):
if isinstance(other, Binary):
return (self.__subtype, binary_type(self)) == (other.subtype, binary_type(other))
return ((self.__subtype, binary_type(self)) ==
(other.subtype, binary_type(other)))
# We don't return NotImplemented here because if we did then
# Binary("foo") == "foo" would return True, since Binary is a
# subclass of str...

View File

@ -115,13 +115,16 @@ class DBRef(object):
def __eq__(self, other):
if isinstance(other, DBRef):
us = [self.__database, self.__collection,
self.__id, self.__kwargs]
them = [other.__database, other.__collection,
other.__id, other.__kwargs]
us = (self.__database, self.__collection,
self.__id, self.__kwargs)
them = (other.__database, other.__collection,
other.__id, other.__kwargs)
return us == them
return NotImplemented
def __ne__(self, other):
return not self == other
def __hash__(self):
"""Get a hash value for this :class:`DBRef`.

View File

@ -254,7 +254,7 @@ class ObjectId(object):
return self.__id == other.__id
return NotImplemented
def __ne__(self,other):
def __ne__(self, other):
if isinstance(other, ObjectId):
return self.__id != other.__id
return NotImplemented

View File

@ -198,6 +198,9 @@ class SON(dict):
dict(self.items()) == dict(other.items()))
return dict(self.items()) == other
def __ne__(self, other):
return not self == other
def __len__(self):
return len(self.keys())

View File

@ -150,6 +150,9 @@ class Collection(common.BaseObject):
return us == them
return NotImplemented
def __ne__(self, other):
return not self == other
@property
def full_name(self):
"""The full name of this :class:`Collection`.

View File

@ -183,6 +183,9 @@ class Database(common.BaseObject):
return us == them
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
return "Database(%r, %r)" % (self.__connection, self.__name)

View File

@ -220,6 +220,9 @@ class MasterSlaveConnection(BaseObject):
return us == them
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
return "MasterSlaveConnection(%r, %r)" % (self.__master, self.__slaves)

View File

@ -994,6 +994,9 @@ class MongoClient(common.BaseObject):
return us == them
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
if len(self.__nodes) == 1:
return "MongoClient(%r, %r)" % (self.__host, self.__port)

View File

@ -1264,6 +1264,9 @@ class MongoReplicaSetClient(common.BaseObject):
# XXX: Implement this?
return NotImplemented
def __ne__(self, other):
return NotImplemented
def __repr__(self):
return "MongoReplicaSetClient(%r)" % (["%s:%d" % n
for n in self.__hosts],)

View File

@ -86,6 +86,9 @@ class SocketInfo(object):
# if its sock is the same as ours
return hasattr(other, 'sock') and self.sock == other.sock
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash(self.sock)

View File

@ -78,6 +78,10 @@ class TestBinary(unittest.TestCase):
self.assertNotEqual(two, Binary(b("hello ")))
self.assertNotEqual(b("hello"), Binary(b("hello")))
# Explicitly test inequality
self.assertFalse(three != Binary(b("hello"), 100))
self.assertFalse(two != Binary(b("hello")))
def test_repr(self):
one = Binary(b("hello world"))
self.assertEqual(repr(one),

View File

@ -75,6 +75,11 @@ class TestCode(unittest.TestCase):
self.assertNotEqual(b, Code("hello "))
self.assertNotEqual("hello", Code("hello"))
# Explicitly test inequality
self.assertFalse(c != Code("hello", {"foo": 5}))
self.assertFalse(b != Code("hello"))
self.assertFalse(b != Code("hello", {}))
def test_scope_preserved(self):
a = Code("hello")
b = Code("hello", {"foo": 5})

View File

@ -275,6 +275,10 @@ class TestCommon(unittest.TestCase):
coll = c.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert, doc)
# Equality tests
self.assertEqual(c, Connection("mongodb://%s/?safe=true" % (pair,)))
self.assertFalse(c != Connection("mongodb://%s/?safe=true" % (pair,)))
def test_mongo_client(self):
m = MongoClient(pair, w=0)
coll = m.pymongo_test.write_concern_test
@ -304,6 +308,10 @@ class TestCommon(unittest.TestCase):
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc))
# Equality tests
self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,)))
self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,)))
def test_replica_set_connection(self):
c = Connection(pair)
ismaster = c.admin.command('ismaster')
@ -339,6 +347,10 @@ class TestCommon(unittest.TestCase):
coll = c.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert, doc)
# Equality tests
self.assertEqual(c, ReplicaSetConnection("mongodb://%s/?replicaSet=%s;safe=true" % (pair, setname)))
self.assertFalse(c != ReplicaSetConnection("mongodb://%s/?replicaSet=%s;safe=true" % (pair, setname)))
def test_mongo_replica_set_client(self):
c = Connection(pair)
ismaster = c.admin.command('ismaster')
@ -374,6 +386,10 @@ class TestCommon(unittest.TestCase):
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc))
# Equality tests
self.assertEqual(m, MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname)))
self.assertFalse(m != MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname)))
if __name__ == "__main__":
unittest.main()

View File

@ -92,6 +92,12 @@ class TestConnection(unittest.TestCase, TestRequestMixin):
self.assertTrue(Connection(self.host, self.port))
def test_equality(self):
connection = Connection(self.host, self.port)
self.assertEqual(connection, Connection(self.host, self.port))
# Explicity test inequality
self.assertFalse(connection != Connection(self.host, self.port))
def test_host_w_port(self):
self.assertTrue(Connection("%s:%d" % (self.host, self.port)))
assertRaisesExactly(ConnectionFailure, Connection,

View File

@ -58,12 +58,16 @@ class TestDatabase(unittest.TestCase):
self.connection, u"my\u0000db")
self.assertEqual("name", Database(self.connection, "name").name)
def test_cmp(self):
def test_equality(self):
self.assertNotEqual(Database(self.connection, "test"),
Database(self.connection, "mike"))
self.assertEqual(Database(self.connection, "test"),
Database(self.connection, "test"))
# Explicitly test inequality
self.assertFalse(Database(self.connection, "test") !=
Database(self.connection, "test"))
def test_repr(self):
self.assertEqual(repr(Database(self.connection, "pymongo_test")),
"Database(%r, %s)" % (self.connection,

View File

@ -83,29 +83,27 @@ class TestDBRef(unittest.TestCase):
self.assertEqual(repr(DBRef("coll", 5, "baz", foo="bar", baz=4)),
"DBRef('coll', 5, 'baz', foo='bar', baz=4)")
def test_cmp(self):
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
DBRef(u"coll", ObjectId("1234567890abcdef12345678")))
self.assertNotEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678")),
DBRef(u"coll",
ObjectId("1234567890abcdef12345678"), "foo"))
self.assertNotEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678")),
DBRef("col", ObjectId("1234567890abcdef12345678")))
self.assertNotEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678")),
def test_equality(self):
obj_id = ObjectId("1234567890abcdef12345678")
self.assertEqual(DBRef('foo', 5), DBRef('foo', 5))
self.assertEqual(DBRef("coll", obj_id), DBRef(u"coll", obj_id))
self.assertNotEqual(DBRef("coll", obj_id),
DBRef(u"coll", obj_id, "foo"))
self.assertNotEqual(DBRef("coll", obj_id), DBRef("col", obj_id))
self.assertNotEqual(DBRef("coll", obj_id),
DBRef("coll", ObjectId(b("123456789011"))))
self.assertNotEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678")), 4)
self.assertEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678"), "foo"),
DBRef(u"coll",
ObjectId("1234567890abcdef12345678"), "foo"))
self.assertNotEqual(DBRef("coll",
ObjectId("1234567890abcdef12345678"), "foo"),
DBRef(u"coll",
ObjectId("1234567890abcdef12345678"), "bar"))
self.assertNotEqual(DBRef("coll", obj_id), 4)
self.assertEqual(DBRef("coll", obj_id, "foo"),
DBRef(u"coll", obj_id, "foo"))
self.assertNotEqual(DBRef("coll", obj_id, "foo"),
DBRef(u"coll", obj_id, "bar"))
# Explicitly test inequality
self.assertFalse(DBRef('foo', 5) != DBRef('foo', 5))
self.assertFalse(DBRef("coll", obj_id) != DBRef(u"coll", obj_id))
self.assertFalse(DBRef("coll", obj_id, "foo") !=
DBRef(u"coll", obj_id, "foo"))
def test_kwargs(self):
self.assertEqual(DBRef("coll", 5, foo="bar"),

View File

@ -196,7 +196,7 @@ class TestMasterSlaveConnection(unittest.TestCase):
self.assertRaises(TypeError, self.connection.drop_database, None)
raise SkipTest("This test often fails due to SERVER-2329")
self.connection.pymongo_test.test.save({"dummy": u"object"}, safe=True)
dbs = self.connection.database_names()
self.assertTrue("pymongo_test" in dbs)

View File

@ -79,13 +79,19 @@ class TestObjectId(unittest.TestCase):
self.assertEqual(str(ObjectId(b('\x124Vx\x90\xab\xcd\xef\x124Vx'))),
"1234567890abcdef12345678")
def test_cmp(self):
def test_equality(self):
a = ObjectId()
self.assertEqual(a, ObjectId(a))
self.assertEqual(ObjectId(b("123456789012")), ObjectId(b("123456789012")))
self.assertEqual(ObjectId(b("123456789012")),
ObjectId(b("123456789012")))
self.assertNotEqual(ObjectId(), ObjectId())
self.assertNotEqual(ObjectId(b("123456789012")), b("123456789012"))
# Explicitly test inequality
self.assertFalse(a != ObjectId(a))
self.assertFalse(ObjectId(b("123456789012")) !=
ObjectId(b("123456789012")))
def test_binary_str_equivalence(self):
a = ObjectId()
self.assertEqual(a, ObjectId(a.binary))

View File

@ -39,9 +39,33 @@ class TestSON(unittest.TestCase):
("mike", "awesome"),
("hello_", "mike")])
b = SON({"hello": "world"})
self.assertEqual(b["hello"], "world")
self.assertRaises(KeyError, lambda: b["goodbye"])
c = SON({"hello": "world"})
self.assertEqual(c["hello"], "world")
self.assertRaises(KeyError, lambda: c["goodbye"])
def test_equality(self):
a = SON({"hello": "world"})
c = SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike')))
self.assertEqual(a, SON({"hello": "world"}))
self.assertEqual(c, SON((('hello', 'world'),
('mike', 'awesome'),
('hello_', 'mike'))))
self.assertEqual(c, SON((('hello', 'world'),
('hello_', 'mike'),
('mike', 'awesome'))))
self.assertNotEqual(c, a)
# Explicitly test inequality
self.assertFalse(a != SON({"hello": "world"}))
self.assertFalse(c != SON((('hello', 'world'),
('mike', 'awesome'),
('hello_', 'mike'))))
self.assertFalse(c != SON((('hello', 'world'),
('hello_', 'mike'),
('mike', 'awesome'))))
def test_to_dict(self):
a = SON()

View File

@ -70,6 +70,9 @@ class TestTimestamp(unittest.TestCase):
self.assertNotEqual(t, Timestamp(1, 0))
self.assertEqual(t, Timestamp(1, 1))
# Explicitly test inequality
self.assertFalse(t != Timestamp(1, 1))
def test_repr(self):
t = Timestamp(0, 0)
self.assertEqual(repr(t), "Timestamp(0, 0)")