diff --git a/objectid.py b/objectid.py index af981441d..269eda06b 100644 --- a/objectid.py +++ b/objectid.py @@ -2,11 +2,7 @@ import random import types -import unittest - -class InvalidId(ValueError): - """Raised when trying to create an ObjectId from invalid data. - """ +from errors import InvalidId class ObjectId(object): """A Mongo ObjectId. @@ -32,13 +28,13 @@ class ObjectId(object): """ # TODO for now, just generate 12 random bytes. this will change when we decide on an _id algorithm... self.__new = True - id = "" + oid = "" for _ in range(12): - id += chr(random.randint(0, 255)) + oid += chr(random.randint(0, 255)) - self.__id = id + self.__id = oid - def __validate(self, id): + def __validate(self, oid): """Validate and use the given id for this ObjectId. Raises TypeError if id is not an instance of (str, ObjectId) and @@ -48,15 +44,15 @@ class ObjectId(object): - `id`: a valid ObjectId """ self.__new = False - if isinstance(id, ObjectId): - self.__id = id.__id - elif isinstance(id, types.StringType): - if len(id) == 12: - self.__id = id + if isinstance(oid, ObjectId): + self.__id = oid.__id + elif isinstance(oid, types.StringType): + if len(oid) == 12: + self.__id = oid else: - raise InvalidId("%s is not a valid ObjectId" % id) + raise InvalidId("%s is not a valid ObjectId" % oid) else: - raise TypeError("id must be an instance of (str, ObjectId), not %s" % type(id)) + raise TypeError("id must be an instance of (str, ObjectId), not %s" % type(oid)) def __str__(self): return self.__id @@ -76,45 +72,3 @@ class ObjectId(object): """Return True if this ObjectId has not been used as an _id field. """ return self.__new - -class TestObjectId(unittest.TestCase): - def setUp(self): - pass - - def test_creation(self): - self.assertRaises(TypeError, ObjectId, 4) - self.assertRaises(TypeError, ObjectId, u"hello") - self.assertRaises(TypeError, ObjectId, 175.0) - self.assertRaises(TypeError, ObjectId, {"test": 4}) - self.assertRaises(TypeError, ObjectId, ["something"]) - self.assertRaises(InvalidId, ObjectId, "") - self.assertRaises(InvalidId, ObjectId, "12345678901") - self.assertRaises(InvalidId, ObjectId, "1234567890123") - self.assertTrue(ObjectId()) - self.assertTrue(ObjectId("123456789012")) - a = ObjectId() - self.assertTrue(ObjectId(a)) - - def test_repr_str(self): - self.assertEqual(repr(ObjectId("123456789012")), "ObjectId('123456789012')") - self.assertEqual(str(ObjectId("123456789012")), "123456789012") - - def test_cmp(self): - a = ObjectId() - self.assertEqual(a, ObjectId(a)) - self.assertEqual(ObjectId("123456789012"), ObjectId("123456789012")) - self.assertNotEqual(ObjectId(), ObjectId()) - self.assertNotEqual(ObjectId("123456789012"), "123456789012") - - def test_new(self): - a = ObjectId() - b = ObjectId("123456789012") - self.assertTrue(a.is_new()) - self.assertFalse(b.is_new()) - a._use() - b._use() - self.assertFalse(a.is_new()) - self.assertFalse(b.is_new()) - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_objectid.py b/test/test_objectid.py new file mode 100644 index 000000000..bd00fd154 --- /dev/null +++ b/test/test_objectid.py @@ -0,0 +1,48 @@ +"""Tests for the objectid module.""" + +import unittest + +from objectid import ObjectId +from errors import InvalidId + +class TestObjectId(unittest.TestCase): + def setUp(self): + pass + + def test_creation(self): + self.assertRaises(TypeError, ObjectId, 4) + self.assertRaises(TypeError, ObjectId, u"hello") + self.assertRaises(TypeError, ObjectId, 175.0) + self.assertRaises(TypeError, ObjectId, {"test": 4}) + self.assertRaises(TypeError, ObjectId, ["something"]) + self.assertRaises(InvalidId, ObjectId, "") + self.assertRaises(InvalidId, ObjectId, "12345678901") + self.assertRaises(InvalidId, ObjectId, "1234567890123") + self.assertTrue(ObjectId()) + self.assertTrue(ObjectId("123456789012")) + a = ObjectId() + self.assertTrue(ObjectId(a)) + + def test_repr_str(self): + self.assertEqual(repr(ObjectId("123456789012")), "ObjectId('123456789012')") + self.assertEqual(str(ObjectId("123456789012")), "123456789012") + + def test_cmp(self): + a = ObjectId() + self.assertEqual(a, ObjectId(a)) + self.assertEqual(ObjectId("123456789012"), ObjectId("123456789012")) + self.assertNotEqual(ObjectId(), ObjectId()) + self.assertNotEqual(ObjectId("123456789012"), "123456789012") + + def test_new(self): + a = ObjectId() + b = ObjectId("123456789012") + self.assertTrue(a.is_new()) + self.assertFalse(b.is_new()) + a._use() + b._use() + self.assertFalse(a.is_new()) + self.assertFalse(b.is_new()) + +if __name__ == "__main__": + unittest.main()