diff --git a/bson/binary.py b/bson/binary.py index d894708d9..48cbe5d75 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -154,7 +154,8 @@ class Binary(binary_type): def __eq__(self, other): if isinstance(other, Binary): - return (self.__subtype, binary_type(self)) == (other.subtype, binary_type(other)) + return ((self.__subtype, binary_type(self)) == + (other.subtype, binary_type(other))) # We don't return NotImplemented here because if we did then # Binary("foo") == "foo" would return True, since Binary is a # subclass of str... diff --git a/bson/dbref.py b/bson/dbref.py index 85828d623..5b15bbd86 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -115,13 +115,16 @@ class DBRef(object): def __eq__(self, other): if isinstance(other, DBRef): - us = [self.__database, self.__collection, - self.__id, self.__kwargs] - them = [other.__database, other.__collection, - other.__id, other.__kwargs] + us = (self.__database, self.__collection, + self.__id, self.__kwargs) + them = (other.__database, other.__collection, + other.__id, other.__kwargs) return us == them return NotImplemented + def __ne__(self, other): + return not self == other + def __hash__(self): """Get a hash value for this :class:`DBRef`. diff --git a/bson/objectid.py b/bson/objectid.py index 0cefd2026..d21fd837f 100644 --- a/bson/objectid.py +++ b/bson/objectid.py @@ -254,7 +254,7 @@ class ObjectId(object): return self.__id == other.__id return NotImplemented - def __ne__(self,other): + def __ne__(self, other): if isinstance(other, ObjectId): return self.__id != other.__id return NotImplemented diff --git a/bson/son.py b/bson/son.py index 4afc4c892..55b5b9e76 100644 --- a/bson/son.py +++ b/bson/son.py @@ -198,6 +198,9 @@ class SON(dict): dict(self.items()) == dict(other.items())) return dict(self.items()) == other + def __ne__(self, other): + return not self == other + def __len__(self): return len(self.keys()) diff --git a/pymongo/collection.py b/pymongo/collection.py index ae5d82ea2..eaef05424 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -150,6 +150,9 @@ class Collection(common.BaseObject): return us == them return NotImplemented + def __ne__(self, other): + return not self == other + @property def full_name(self): """The full name of this :class:`Collection`. diff --git a/pymongo/database.py b/pymongo/database.py index f69053612..a9caff6be 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -183,6 +183,9 @@ class Database(common.BaseObject): return us == them return NotImplemented + def __ne__(self, other): + return not self == other + def __repr__(self): return "Database(%r, %r)" % (self.__connection, self.__name) diff --git a/pymongo/master_slave_connection.py b/pymongo/master_slave_connection.py index f55032159..abb0f718a 100644 --- a/pymongo/master_slave_connection.py +++ b/pymongo/master_slave_connection.py @@ -220,6 +220,9 @@ class MasterSlaveConnection(BaseObject): return us == them return NotImplemented + def __ne__(self, other): + return not self == other + def __repr__(self): return "MasterSlaveConnection(%r, %r)" % (self.__master, self.__slaves) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 61bc3748e..f7c0ff5b3 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -994,6 +994,9 @@ class MongoClient(common.BaseObject): return us == them return NotImplemented + def __ne__(self, other): + return not self == other + def __repr__(self): if len(self.__nodes) == 1: return "MongoClient(%r, %r)" % (self.__host, self.__port) diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index d357768c0..ca1cebebf 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -1264,6 +1264,9 @@ class MongoReplicaSetClient(common.BaseObject): # XXX: Implement this? return NotImplemented + def __ne__(self, other): + return NotImplemented + def __repr__(self): return "MongoReplicaSetClient(%r)" % (["%s:%d" % n for n in self.__hosts],) diff --git a/pymongo/pool.py b/pymongo/pool.py index 3e25d1e24..a196679a7 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -86,6 +86,9 @@ class SocketInfo(object): # if its sock is the same as ours return hasattr(other, 'sock') and self.sock == other.sock + def __ne__(self, other): + return not self == other + def __hash__(self): return hash(self.sock) diff --git a/test/test_binary.py b/test/test_binary.py index 2d4578713..c9743eb68 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -78,6 +78,10 @@ class TestBinary(unittest.TestCase): self.assertNotEqual(two, Binary(b("hello "))) self.assertNotEqual(b("hello"), Binary(b("hello"))) + # Explicitly test inequality + self.assertFalse(three != Binary(b("hello"), 100)) + self.assertFalse(two != Binary(b("hello"))) + def test_repr(self): one = Binary(b("hello world")) self.assertEqual(repr(one), diff --git a/test/test_code.py b/test/test_code.py index 194257cad..fe5853f3d 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -75,6 +75,11 @@ class TestCode(unittest.TestCase): self.assertNotEqual(b, Code("hello ")) self.assertNotEqual("hello", Code("hello")) + # Explicitly test inequality + self.assertFalse(c != Code("hello", {"foo": 5})) + self.assertFalse(b != Code("hello")) + self.assertFalse(b != Code("hello", {})) + def test_scope_preserved(self): a = Code("hello") b = Code("hello", {"foo": 5}) diff --git a/test/test_common.py b/test/test_common.py index 21a8d70ad..ab7b5cc6b 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -275,6 +275,10 @@ class TestCommon(unittest.TestCase): coll = c.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert, doc) + # Equality tests + self.assertEqual(c, Connection("mongodb://%s/?safe=true" % (pair,))) + self.assertFalse(c != Connection("mongodb://%s/?safe=true" % (pair,))) + def test_mongo_client(self): m = MongoClient(pair, w=0) coll = m.pymongo_test.write_concern_test @@ -304,6 +308,10 @@ class TestCommon(unittest.TestCase): coll = m.pymongo_test.write_concern_test self.assertTrue(coll.insert(doc)) + # Equality tests + self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,))) + self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,))) + def test_replica_set_connection(self): c = Connection(pair) ismaster = c.admin.command('ismaster') @@ -339,6 +347,10 @@ class TestCommon(unittest.TestCase): coll = c.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert, doc) + # Equality tests + self.assertEqual(c, ReplicaSetConnection("mongodb://%s/?replicaSet=%s;safe=true" % (pair, setname))) + self.assertFalse(c != ReplicaSetConnection("mongodb://%s/?replicaSet=%s;safe=true" % (pair, setname))) + def test_mongo_replica_set_client(self): c = Connection(pair) ismaster = c.admin.command('ismaster') @@ -374,6 +386,10 @@ class TestCommon(unittest.TestCase): coll = m.pymongo_test.write_concern_test self.assertTrue(coll.insert(doc)) + # Equality tests + self.assertEqual(m, MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname))) + self.assertFalse(m != MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname))) + if __name__ == "__main__": unittest.main() diff --git a/test/test_connection.py b/test/test_connection.py index 8eed62567..188f7dea0 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -92,6 +92,12 @@ class TestConnection(unittest.TestCase, TestRequestMixin): self.assertTrue(Connection(self.host, self.port)) + def test_equality(self): + connection = Connection(self.host, self.port) + self.assertEqual(connection, Connection(self.host, self.port)) + # Explicity test inequality + self.assertFalse(connection != Connection(self.host, self.port)) + def test_host_w_port(self): self.assertTrue(Connection("%s:%d" % (self.host, self.port))) assertRaisesExactly(ConnectionFailure, Connection, diff --git a/test/test_database.py b/test/test_database.py index 44f41f21b..ca3add89f 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -58,12 +58,16 @@ class TestDatabase(unittest.TestCase): self.connection, u"my\u0000db") self.assertEqual("name", Database(self.connection, "name").name) - def test_cmp(self): + def test_equality(self): self.assertNotEqual(Database(self.connection, "test"), Database(self.connection, "mike")) self.assertEqual(Database(self.connection, "test"), Database(self.connection, "test")) + # Explicitly test inequality + self.assertFalse(Database(self.connection, "test") != + Database(self.connection, "test")) + def test_repr(self): self.assertEqual(repr(Database(self.connection, "pymongo_test")), "Database(%r, %s)" % (self.connection, diff --git a/test/test_dbref.py b/test/test_dbref.py index e8b97f2d0..4fe680304 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -83,29 +83,27 @@ class TestDBRef(unittest.TestCase): self.assertEqual(repr(DBRef("coll", 5, "baz", foo="bar", baz=4)), "DBRef('coll', 5, 'baz', foo='bar', baz=4)") - def test_cmp(self): - self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), - DBRef(u"coll", ObjectId("1234567890abcdef12345678"))) - self.assertNotEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678")), - DBRef(u"coll", - ObjectId("1234567890abcdef12345678"), "foo")) - self.assertNotEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678")), - DBRef("col", ObjectId("1234567890abcdef12345678"))) - self.assertNotEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678")), + def test_equality(self): + obj_id = ObjectId("1234567890abcdef12345678") + + self.assertEqual(DBRef('foo', 5), DBRef('foo', 5)) + self.assertEqual(DBRef("coll", obj_id), DBRef(u"coll", obj_id)) + self.assertNotEqual(DBRef("coll", obj_id), + DBRef(u"coll", obj_id, "foo")) + self.assertNotEqual(DBRef("coll", obj_id), DBRef("col", obj_id)) + self.assertNotEqual(DBRef("coll", obj_id), DBRef("coll", ObjectId(b("123456789011")))) - self.assertNotEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678")), 4) - self.assertEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678"), "foo"), - DBRef(u"coll", - ObjectId("1234567890abcdef12345678"), "foo")) - self.assertNotEqual(DBRef("coll", - ObjectId("1234567890abcdef12345678"), "foo"), - DBRef(u"coll", - ObjectId("1234567890abcdef12345678"), "bar")) + self.assertNotEqual(DBRef("coll", obj_id), 4) + self.assertEqual(DBRef("coll", obj_id, "foo"), + DBRef(u"coll", obj_id, "foo")) + self.assertNotEqual(DBRef("coll", obj_id, "foo"), + DBRef(u"coll", obj_id, "bar")) + + # Explicitly test inequality + self.assertFalse(DBRef('foo', 5) != DBRef('foo', 5)) + self.assertFalse(DBRef("coll", obj_id) != DBRef(u"coll", obj_id)) + self.assertFalse(DBRef("coll", obj_id, "foo") != + DBRef(u"coll", obj_id, "foo")) def test_kwargs(self): self.assertEqual(DBRef("coll", 5, foo="bar"), diff --git a/test/test_master_slave_connection.py b/test/test_master_slave_connection.py index d431113c2..f81f3a1e8 100644 --- a/test/test_master_slave_connection.py +++ b/test/test_master_slave_connection.py @@ -196,7 +196,7 @@ class TestMasterSlaveConnection(unittest.TestCase): self.assertRaises(TypeError, self.connection.drop_database, None) raise SkipTest("This test often fails due to SERVER-2329") - + self.connection.pymongo_test.test.save({"dummy": u"object"}, safe=True) dbs = self.connection.database_names() self.assertTrue("pymongo_test" in dbs) diff --git a/test/test_objectid.py b/test/test_objectid.py index bbd5fb8c0..d3247c8e0 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -79,13 +79,19 @@ class TestObjectId(unittest.TestCase): self.assertEqual(str(ObjectId(b('\x124Vx\x90\xab\xcd\xef\x124Vx'))), "1234567890abcdef12345678") - def test_cmp(self): + def test_equality(self): a = ObjectId() self.assertEqual(a, ObjectId(a)) - self.assertEqual(ObjectId(b("123456789012")), ObjectId(b("123456789012"))) + self.assertEqual(ObjectId(b("123456789012")), + ObjectId(b("123456789012"))) self.assertNotEqual(ObjectId(), ObjectId()) self.assertNotEqual(ObjectId(b("123456789012")), b("123456789012")) + # Explicitly test inequality + self.assertFalse(a != ObjectId(a)) + self.assertFalse(ObjectId(b("123456789012")) != + ObjectId(b("123456789012"))) + def test_binary_str_equivalence(self): a = ObjectId() self.assertEqual(a, ObjectId(a.binary)) diff --git a/test/test_son.py b/test/test_son.py index b50fadd01..7e2eff2c4 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -39,9 +39,33 @@ class TestSON(unittest.TestCase): ("mike", "awesome"), ("hello_", "mike")]) - b = SON({"hello": "world"}) - self.assertEqual(b["hello"], "world") - self.assertRaises(KeyError, lambda: b["goodbye"]) + c = SON({"hello": "world"}) + self.assertEqual(c["hello"], "world") + self.assertRaises(KeyError, lambda: c["goodbye"]) + + def test_equality(self): + a = SON({"hello": "world"}) + c = SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike'))) + + self.assertEqual(a, SON({"hello": "world"})) + self.assertEqual(c, SON((('hello', 'world'), + ('mike', 'awesome'), + ('hello_', 'mike')))) + self.assertEqual(c, SON((('hello', 'world'), + ('hello_', 'mike'), + ('mike', 'awesome')))) + + self.assertNotEqual(c, a) + + # Explicitly test inequality + self.assertFalse(a != SON({"hello": "world"})) + self.assertFalse(c != SON((('hello', 'world'), + ('mike', 'awesome'), + ('hello_', 'mike')))) + self.assertFalse(c != SON((('hello', 'world'), + ('hello_', 'mike'), + ('mike', 'awesome')))) + def test_to_dict(self): a = SON() diff --git a/test/test_timestamp.py b/test/test_timestamp.py index b98ba44e1..d8d4cb8ad 100644 --- a/test/test_timestamp.py +++ b/test/test_timestamp.py @@ -70,6 +70,9 @@ class TestTimestamp(unittest.TestCase): self.assertNotEqual(t, Timestamp(1, 0)) self.assertEqual(t, Timestamp(1, 1)) + # Explicitly test inequality + self.assertFalse(t != Timestamp(1, 1)) + def test_repr(self): t = Timestamp(0, 0) self.assertEqual(repr(t), "Timestamp(0, 0)")